From 5072b231afc193ad40aab7ec1212ee6390d2d061 Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Sun, 27 Nov 2016 19:41:54 +0200 Subject: [PATCH] Refactored tests --- api_test.go | 50 ++++++++++++------------------------------ db.go | 14 +++++++++++- db_test.go | 61 ---------------------------------------------------- dns.go | 11 +++++----- dns_test.go | 36 ------------------------------- main_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++- util.go | 12 +++++------ 7 files changed, 88 insertions(+), 147 deletions(-) diff --git a/api_test.go b/api_test.go index 5571fd2..3e1b472 100644 --- a/api_test.go +++ b/api_test.go @@ -9,28 +9,8 @@ import ( "testing" ) -func SetupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect { +func setupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect { iris.ResetDefault() - var dbcfg = dbsettings{ - Engine: "sqlite3", - Connection: ":memory:"} - var httpapicfg = httpapi{ - Domain: "", - Port: "8080", - TLS: "none", - CorsOrigins: []string{"*"}, - } - var dnscfg = DNSConfig{ - API: httpapicfg, - Database: dbcfg, - } - DNSConf = dnscfg - // In memory logger - //logging.InitForTesting(logging.DEBUG) - err := DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection) - if err != nil { - panic(err) - } var ForceAuth = authMiddleware{} iris.Get("/register", webRegisterGet) iris.Post("/register", webRegisterPost) @@ -45,8 +25,7 @@ func SetupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect { } func TestApiRegister(t *testing.T) { - e := SetupIris(t, false, false) - defer DB.DB.Close() + e := setupIris(t, false, false) e.GET("/register").Expect(). Status(iris.StatusCreated). JSON().Object(). @@ -66,22 +45,22 @@ func TestApiRegister(t *testing.T) { } func TestApiRegisterWithMockDB(t *testing.T) { - e := SetupIris(t, false, false) - DB.DB.Close() + e := setupIris(t, false, false) + old_db := DB.DB db, mock, _ := sqlmock.New() DB.DB = db - defer DB.DB.Close() + defer db.Close() mock.ExpectBegin() mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error")) e.GET("/register").Expect(). Status(iris.StatusInternalServerError). JSON().Object(). ContainsKey("error") + DB.DB = old_db } func TestApiUpdateWithoutCredentials(t *testing.T) { - e := SetupIris(t, false, false) - defer DB.DB.Close() + e := setupIris(t, false, false) e.POST("/update").Expect(). Status(iris.StatusUnauthorized). JSON().Object(). @@ -96,8 +75,7 @@ func TestApiUpdateWithCredentials(t *testing.T) { "subdomain": "", "txt": ""} - e := SetupIris(t, false, false) - defer DB.DB.Close() + e := setupIris(t, false, false) newUser, err := DB.Register() if err != nil { t.Errorf("Could not create new user, got error [%v]", err) @@ -128,11 +106,11 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) { updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8" updateJSON["txt"] = validTxtData - e := SetupIris(t, false, true) - DB.DB.Close() + e := setupIris(t, false, true) + old_db := DB.DB db, mock, _ := sqlmock.New() DB.DB = db - defer DB.DB.Close() + defer db.Close() mock.ExpectBegin() mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error")) e.POST("/update"). @@ -141,19 +119,19 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) { Status(iris.StatusInternalServerError). JSON().Object(). ContainsKey("error") + DB.DB = old_db } func TestApiManyUpdateWithCredentials(t *testing.T) { // TODO: transfer to using httpexpect builder - // If test fails and more debug info is needed, use SetupIris(t, true, false) + // If test fails and more debug info is needed, use setupIris(t, true, false) validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" updateJSON := map[string]interface{}{ "subdomain": "", "txt": ""} - e := SetupIris(t, false, false) - defer DB.DB.Close() + e := setupIris(t, false, false) newUser, err := DB.Register() if err != nil { t.Errorf("Could not create new user, got error [%v]", err) diff --git a/db.go b/db.go index 563a2b2..99092df 100644 --- a/db.go +++ b/db.go @@ -9,10 +9,12 @@ import ( "github.com/satori/go.uuid" "golang.org/x/crypto/bcrypt" "regexp" + "sync" "time" ) type database struct { + sync.Mutex DB *sql.DB } @@ -36,12 +38,14 @@ func getSQLiteStmt(s string) string { } func (d *database) Init(engine string, connection string) error { + d.Lock() + defer d.Unlock() db, err := sql.Open(engine, connection) if err != nil { return err } d.DB = db - d.DB.SetMaxOpenConns(1) + //d.DB.SetMaxOpenConns(1) _, err = d.DB.Exec(recordsTable) if err != nil { return err @@ -50,6 +54,8 @@ func (d *database) Init(engine string, connection string) error { } func (d *database) Register() (ACMETxt, error) { + d.Lock() + defer d.Unlock() a, err := newACMETxt() if err != nil { return ACMETxt{}, err @@ -80,6 +86,8 @@ func (d *database) Register() (ACMETxt, error) { } func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) { + d.Lock() + defer d.Unlock() var results []ACMETxt getSQL := ` SELECT Username, Password, Subdomain, Value, LastActive @@ -122,6 +130,8 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) { } func (d *database) GetByDomain(domain string) ([]ACMETxt, error) { + d.Lock() + defer d.Unlock() domain = sanitizeString(domain) var a []ACMETxt getSQL := ` @@ -156,6 +166,8 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) { } func (d *database) Update(a ACMETxt) error { + d.Lock() + defer d.Unlock() // Data in a is already sanitized timenow := time.Now().Unix() updSQL := ` diff --git a/db_test.go b/db_test.go index 7e245ca..376a339 100644 --- a/db_test.go +++ b/db_test.go @@ -1,29 +1,10 @@ package main import ( - "flag" "testing" ) -var ( - postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL") -) - func TestRegister(t *testing.T) { - flag.Parse() - if *postgres { - DNSConf.Database.Engine = "postgres" - err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") - if err != nil { - t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") - return - } - } else { - DNSConf.Database.Engine = "sqlite3" - _ = DB.Init("sqlite3", ":memory:") - } - defer DB.DB.Close() - // Register tests _, err := DB.Register() if err != nil { @@ -32,20 +13,6 @@ func TestRegister(t *testing.T) { } func TestGetByUsername(t *testing.T) { - flag.Parse() - if *postgres { - DNSConf.Database.Engine = "postgres" - err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") - if err != nil { - t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") - return - } - } else { - DNSConf.Database.Engine = "sqlite3" - _ = DB.Init("sqlite3", ":memory:") - } - defer DB.DB.Close() - // Create reg to refer to reg, err := DB.Register() if err != nil { @@ -72,20 +39,6 @@ func TestGetByUsername(t *testing.T) { } func TestGetByDomain(t *testing.T) { - flag.Parse() - if *postgres { - DNSConf.Database.Engine = "postgres" - err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") - if err != nil { - t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") - return - } - } else { - DNSConf.Database.Engine = "sqlite3" - _ = DB.Init("sqlite3", ":memory:") - } - defer DB.DB.Close() - var regDomain = ACMETxt{} // Create reg to refer to @@ -125,20 +78,6 @@ func TestGetByDomain(t *testing.T) { } func TestUpdate(t *testing.T) { - flag.Parse() - if *postgres { - DNSConf.Database.Engine = "postgres" - err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") - if err != nil { - t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") - return - } - } else { - DNSConf.Database.Engine = "sqlite3" - _ = DB.Init("sqlite3", ":memory:") - } - defer DB.DB.Close() - // Create reg to refer to reg, err := DB.Register() if err != nil { diff --git a/dns.go b/dns.go index 4bb3994..7c7656b 100644 --- a/dns.go +++ b/dns.go @@ -21,10 +21,9 @@ func readQuery(m *dns.Msg) { func answerTXT(q dns.Question) ([]dns.RR, int, error) { var ra []dns.RR - var rcode = dns.RcodeNameError - var domain = strings.ToLower(q.Name) - - atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain)) + rcode := dns.RcodeNameError + subdomain := sanitizeDomainQuestion(q.Name) + atxt, err := DB.GetByDomain(subdomain) if err != nil { log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") return ra, dns.RcodeNameError, err @@ -32,14 +31,14 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) { for _, v := range atxt { if len(v.Value) > 0 { r := new(dns.TXT) - r.Hdr = dns.RR_Header{Name: domain, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} + r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} r.Txt = append(r.Txt, v.Value) ra = append(ra, r) rcode = dns.RcodeSuccess } } - log.WithFields(log.Fields{"domain": domain}).Info("Answering TXT question for domain") + log.WithFields(log.Fields{"domain": q.Name}).Info("Answering TXT question for domain") return ra, rcode, nil } diff --git a/dns_test.go b/dns_test.go index 71987d3..0d37975 100644 --- a/dns_test.go +++ b/dns_test.go @@ -2,7 +2,6 @@ package main import ( "errors" - "flag" "fmt" log "github.com/Sirupsen/logrus" "github.com/miekg/dns" @@ -75,27 +74,6 @@ func findRecordFromMemory(rrstr string, host string, qtype uint16) error { return errors.New(errmsg) } -func setupConfig() { - var dbcfg = dbsettings{ - Engine: "sqlite3", - Connection: ":memory:", - } - - var generalcfg = general{ - Domain: "auth.example.org", - Nsname: "ns1.auth.example.org", - Nsadmin: "admin.example.org", - Debug: false, - } - - var dnscfg = DNSConfig{ - Database: dbcfg, - General: generalcfg, - } - - DNSConf = dnscfg -} - func startDNSServer(addr string) (*dns.Server, resolver) { // DNS server part @@ -111,7 +89,6 @@ func startDNSServer(addr string) (*dns.Server, resolver) { } func TestResolveA(t *testing.T) { - RR.Parse(records) setupConfig() answer, err := resolv.lookup("auth.example.org", dns.TypeA) if err != nil { @@ -130,20 +107,7 @@ func TestResolveA(t *testing.T) { } func TestResolveTXT(t *testing.T) { - flag.Parse() setupConfig() - if *postgres { - DNSConf.Database.Engine = "postgres" - err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") - if err != nil { - t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") - return - } - } else { - DNSConf.Database.Engine = "sqlite3" - _ = DB.Init("sqlite3", ":memory:") - } - defer DB.DB.Close() validTXT := "______________valid_response_______________" diff --git a/main_test.go b/main_test.go index 3609754..795b5a0 100644 --- a/main_test.go +++ b/main_test.go @@ -2,15 +2,64 @@ package main import ( "flag" + "fmt" "os" "testing" ) +var ( + postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL") +) + func TestMain(m *testing.M) { - server, resolv = startDNSServer("0.0.0.0:15353") + setupConfig() + RR.Parse(records) flag.Parse() + + if *postgres { + DNSConf.Database.Engine = "postgres" + err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") + if err != nil { + fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") + os.Exit(1) + } + } else { + DNSConf.Database.Engine = "sqlite3" + _ = DB.Init("sqlite3", ":memory:") + } + + server, resolv = startDNSServer("0.0.0.0:15353") exitval := m.Run() server.Shutdown() DB.DB.Close() os.Exit(exitval) } + +func setupConfig() { + var dbcfg = dbsettings{ + Engine: "sqlite3", + Connection: ":memory:", + } + + var generalcfg = general{ + Domain: "auth.example.org", + Nsname: "ns1.auth.example.org", + Nsadmin: "admin.example.org", + Debug: false, + } + + var httpapicfg = httpapi{ + Domain: "", + Port: "8080", + TLS: "none", + CorsOrigins: []string{"*"}, + } + + var dnscfg = DNSConfig{ + Database: dbcfg, + General: generalcfg, + API: httpapicfg, + } + + DNSConf = dnscfg +} diff --git a/util.go b/util.go index 392495b..3a49e7b 100644 --- a/util.go +++ b/util.go @@ -3,6 +3,7 @@ package main import ( "crypto/rand" "errors" + "fmt" "github.com/BurntSushi/toml" log "github.com/Sirupsen/logrus" "github.com/satori/go.uuid" @@ -45,13 +46,12 @@ func generatePassword(length int) (string, error) { } func sanitizeDomainQuestion(d string) string { - var dom string - suffix := DNSConf.General.Domain + "." - if strings.HasSuffix(d, suffix) { - dom = d[0 : len(d)-len(suffix)] - } else { - dom = d + dom := strings.ToLower(d) + firstDot := strings.Index(d, ".") + if firstDot > 0 { + dom = dom[0:firstDot] } + fmt.Printf("%s\n", dom) return dom }