Refactored tests

This commit is contained in:
Joona Hoikkala 2016-11-27 19:41:54 +02:00
parent b4cc6b8e81
commit 5072b231af
No known key found for this signature in database
GPG Key ID: C14AAE0F5ADCB854
7 changed files with 88 additions and 147 deletions

View File

@ -9,28 +9,8 @@ import (
"testing" "testing"
) )
func SetupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect { func setupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect {
iris.ResetDefault() iris.ResetDefault()
var dbcfg = dbsettings{
Engine: "sqlite3",
Connection: ":memory:"}
var httpapicfg = httpapi{
Domain: "",
Port: "8080",
TLS: "none",
CorsOrigins: []string{"*"},
}
var dnscfg = DNSConfig{
API: httpapicfg,
Database: dbcfg,
}
DNSConf = dnscfg
// In memory logger
//logging.InitForTesting(logging.DEBUG)
err := DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
if err != nil {
panic(err)
}
var ForceAuth = authMiddleware{} var ForceAuth = authMiddleware{}
iris.Get("/register", webRegisterGet) iris.Get("/register", webRegisterGet)
iris.Post("/register", webRegisterPost) iris.Post("/register", webRegisterPost)
@ -45,8 +25,7 @@ func SetupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect {
} }
func TestApiRegister(t *testing.T) { func TestApiRegister(t *testing.T) {
e := SetupIris(t, false, false) e := setupIris(t, false, false)
defer DB.DB.Close()
e.GET("/register").Expect(). e.GET("/register").Expect().
Status(iris.StatusCreated). Status(iris.StatusCreated).
JSON().Object(). JSON().Object().
@ -66,22 +45,22 @@ func TestApiRegister(t *testing.T) {
} }
func TestApiRegisterWithMockDB(t *testing.T) { func TestApiRegisterWithMockDB(t *testing.T) {
e := SetupIris(t, false, false) e := setupIris(t, false, false)
DB.DB.Close() old_db := DB.DB
db, mock, _ := sqlmock.New() db, mock, _ := sqlmock.New()
DB.DB = db DB.DB = db
defer DB.DB.Close() defer db.Close()
mock.ExpectBegin() mock.ExpectBegin()
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error")) mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
e.GET("/register").Expect(). e.GET("/register").Expect().
Status(iris.StatusInternalServerError). Status(iris.StatusInternalServerError).
JSON().Object(). JSON().Object().
ContainsKey("error") ContainsKey("error")
DB.DB = old_db
} }
func TestApiUpdateWithoutCredentials(t *testing.T) { func TestApiUpdateWithoutCredentials(t *testing.T) {
e := SetupIris(t, false, false) e := setupIris(t, false, false)
defer DB.DB.Close()
e.POST("/update").Expect(). e.POST("/update").Expect().
Status(iris.StatusUnauthorized). Status(iris.StatusUnauthorized).
JSON().Object(). JSON().Object().
@ -96,8 +75,7 @@ func TestApiUpdateWithCredentials(t *testing.T) {
"subdomain": "", "subdomain": "",
"txt": ""} "txt": ""}
e := SetupIris(t, false, false) e := setupIris(t, false, false)
defer DB.DB.Close()
newUser, err := DB.Register() newUser, err := DB.Register()
if err != nil { if err != nil {
t.Errorf("Could not create new user, got error [%v]", err) t.Errorf("Could not create new user, got error [%v]", err)
@ -128,11 +106,11 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8" updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8"
updateJSON["txt"] = validTxtData updateJSON["txt"] = validTxtData
e := SetupIris(t, false, true) e := setupIris(t, false, true)
DB.DB.Close() old_db := DB.DB
db, mock, _ := sqlmock.New() db, mock, _ := sqlmock.New()
DB.DB = db DB.DB = db
defer DB.DB.Close() defer db.Close()
mock.ExpectBegin() mock.ExpectBegin()
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error")) mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
e.POST("/update"). e.POST("/update").
@ -141,19 +119,19 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
Status(iris.StatusInternalServerError). Status(iris.StatusInternalServerError).
JSON().Object(). JSON().Object().
ContainsKey("error") ContainsKey("error")
DB.DB = old_db
} }
func TestApiManyUpdateWithCredentials(t *testing.T) { func TestApiManyUpdateWithCredentials(t *testing.T) {
// TODO: transfer to using httpexpect builder // TODO: transfer to using httpexpect builder
// If test fails and more debug info is needed, use SetupIris(t, true, false) // If test fails and more debug info is needed, use setupIris(t, true, false)
validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
updateJSON := map[string]interface{}{ updateJSON := map[string]interface{}{
"subdomain": "", "subdomain": "",
"txt": ""} "txt": ""}
e := SetupIris(t, false, false) e := setupIris(t, false, false)
defer DB.DB.Close()
newUser, err := DB.Register() newUser, err := DB.Register()
if err != nil { if err != nil {
t.Errorf("Could not create new user, got error [%v]", err) t.Errorf("Could not create new user, got error [%v]", err)

14
db.go
View File

@ -9,10 +9,12 @@ import (
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"regexp" "regexp"
"sync"
"time" "time"
) )
type database struct { type database struct {
sync.Mutex
DB *sql.DB DB *sql.DB
} }
@ -36,12 +38,14 @@ func getSQLiteStmt(s string) string {
} }
func (d *database) Init(engine string, connection string) error { func (d *database) Init(engine string, connection string) error {
d.Lock()
defer d.Unlock()
db, err := sql.Open(engine, connection) db, err := sql.Open(engine, connection)
if err != nil { if err != nil {
return err return err
} }
d.DB = db d.DB = db
d.DB.SetMaxOpenConns(1) //d.DB.SetMaxOpenConns(1)
_, err = d.DB.Exec(recordsTable) _, err = d.DB.Exec(recordsTable)
if err != nil { if err != nil {
return err return err
@ -50,6 +54,8 @@ func (d *database) Init(engine string, connection string) error {
} }
func (d *database) Register() (ACMETxt, error) { func (d *database) Register() (ACMETxt, error) {
d.Lock()
defer d.Unlock()
a, err := newACMETxt() a, err := newACMETxt()
if err != nil { if err != nil {
return ACMETxt{}, err return ACMETxt{}, err
@ -80,6 +86,8 @@ func (d *database) Register() (ACMETxt, error) {
} }
func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) { func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
d.Lock()
defer d.Unlock()
var results []ACMETxt var results []ACMETxt
getSQL := ` getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive SELECT Username, Password, Subdomain, Value, LastActive
@ -122,6 +130,8 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
} }
func (d *database) GetByDomain(domain string) ([]ACMETxt, error) { func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
d.Lock()
defer d.Unlock()
domain = sanitizeString(domain) domain = sanitizeString(domain)
var a []ACMETxt var a []ACMETxt
getSQL := ` getSQL := `
@ -156,6 +166,8 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
} }
func (d *database) Update(a ACMETxt) error { func (d *database) Update(a ACMETxt) error {
d.Lock()
defer d.Unlock()
// Data in a is already sanitized // Data in a is already sanitized
timenow := time.Now().Unix() timenow := time.Now().Unix()
updSQL := ` updSQL := `

View File

@ -1,29 +1,10 @@
package main package main
import ( import (
"flag"
"testing" "testing"
) )
var (
postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL")
)
func TestRegister(t *testing.T) { func TestRegister(t *testing.T) {
flag.Parse()
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
return
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
}
defer DB.DB.Close()
// Register tests // Register tests
_, err := DB.Register() _, err := DB.Register()
if err != nil { if err != nil {
@ -32,20 +13,6 @@ func TestRegister(t *testing.T) {
} }
func TestGetByUsername(t *testing.T) { func TestGetByUsername(t *testing.T) {
flag.Parse()
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
return
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
}
defer DB.DB.Close()
// Create reg to refer to // Create reg to refer to
reg, err := DB.Register() reg, err := DB.Register()
if err != nil { if err != nil {
@ -72,20 +39,6 @@ func TestGetByUsername(t *testing.T) {
} }
func TestGetByDomain(t *testing.T) { func TestGetByDomain(t *testing.T) {
flag.Parse()
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
return
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
}
defer DB.DB.Close()
var regDomain = ACMETxt{} var regDomain = ACMETxt{}
// Create reg to refer to // Create reg to refer to
@ -125,20 +78,6 @@ func TestGetByDomain(t *testing.T) {
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
flag.Parse()
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
return
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
}
defer DB.DB.Close()
// Create reg to refer to // Create reg to refer to
reg, err := DB.Register() reg, err := DB.Register()
if err != nil { if err != nil {

11
dns.go
View File

@ -21,10 +21,9 @@ func readQuery(m *dns.Msg) {
func answerTXT(q dns.Question) ([]dns.RR, int, error) { func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var ra []dns.RR var ra []dns.RR
var rcode = dns.RcodeNameError rcode := dns.RcodeNameError
var domain = strings.ToLower(q.Name) subdomain := sanitizeDomainQuestion(q.Name)
atxt, err := DB.GetByDomain(subdomain)
atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain))
if err != nil { if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
return ra, dns.RcodeNameError, err return ra, dns.RcodeNameError, err
@ -32,14 +31,14 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
for _, v := range atxt { for _, v := range atxt {
if len(v.Value) > 0 { if len(v.Value) > 0 {
r := new(dns.TXT) r := new(dns.TXT)
r.Hdr = dns.RR_Header{Name: domain, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
r.Txt = append(r.Txt, v.Value) r.Txt = append(r.Txt, v.Value)
ra = append(ra, r) ra = append(ra, r)
rcode = dns.RcodeSuccess rcode = dns.RcodeSuccess
} }
} }
log.WithFields(log.Fields{"domain": domain}).Info("Answering TXT question for domain") log.WithFields(log.Fields{"domain": q.Name}).Info("Answering TXT question for domain")
return ra, rcode, nil return ra, rcode, nil
} }

View File

@ -2,7 +2,6 @@ package main
import ( import (
"errors" "errors"
"flag"
"fmt" "fmt"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -75,27 +74,6 @@ func findRecordFromMemory(rrstr string, host string, qtype uint16) error {
return errors.New(errmsg) return errors.New(errmsg)
} }
func setupConfig() {
var dbcfg = dbsettings{
Engine: "sqlite3",
Connection: ":memory:",
}
var generalcfg = general{
Domain: "auth.example.org",
Nsname: "ns1.auth.example.org",
Nsadmin: "admin.example.org",
Debug: false,
}
var dnscfg = DNSConfig{
Database: dbcfg,
General: generalcfg,
}
DNSConf = dnscfg
}
func startDNSServer(addr string) (*dns.Server, resolver) { func startDNSServer(addr string) (*dns.Server, resolver) {
// DNS server part // DNS server part
@ -111,7 +89,6 @@ func startDNSServer(addr string) (*dns.Server, resolver) {
} }
func TestResolveA(t *testing.T) { func TestResolveA(t *testing.T) {
RR.Parse(records)
setupConfig() setupConfig()
answer, err := resolv.lookup("auth.example.org", dns.TypeA) answer, err := resolv.lookup("auth.example.org", dns.TypeA)
if err != nil { if err != nil {
@ -130,20 +107,7 @@ func TestResolveA(t *testing.T) {
} }
func TestResolveTXT(t *testing.T) { func TestResolveTXT(t *testing.T) {
flag.Parse()
setupConfig() setupConfig()
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
return
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
}
defer DB.DB.Close()
validTXT := "______________valid_response_______________" validTXT := "______________valid_response_______________"

View File

@ -2,15 +2,64 @@ package main
import ( import (
"flag" "flag"
"fmt"
"os" "os"
"testing" "testing"
) )
var (
postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL")
)
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
server, resolv = startDNSServer("0.0.0.0:15353") setupConfig()
RR.Parse(records)
flag.Parse() flag.Parse()
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
os.Exit(1)
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
}
server, resolv = startDNSServer("0.0.0.0:15353")
exitval := m.Run() exitval := m.Run()
server.Shutdown() server.Shutdown()
DB.DB.Close() DB.DB.Close()
os.Exit(exitval) os.Exit(exitval)
} }
func setupConfig() {
var dbcfg = dbsettings{
Engine: "sqlite3",
Connection: ":memory:",
}
var generalcfg = general{
Domain: "auth.example.org",
Nsname: "ns1.auth.example.org",
Nsadmin: "admin.example.org",
Debug: false,
}
var httpapicfg = httpapi{
Domain: "",
Port: "8080",
TLS: "none",
CorsOrigins: []string{"*"},
}
var dnscfg = DNSConfig{
Database: dbcfg,
General: generalcfg,
API: httpapicfg,
}
DNSConf = dnscfg
}

12
util.go
View File

@ -3,6 +3,7 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
@ -45,13 +46,12 @@ func generatePassword(length int) (string, error) {
} }
func sanitizeDomainQuestion(d string) string { func sanitizeDomainQuestion(d string) string {
var dom string dom := strings.ToLower(d)
suffix := DNSConf.General.Domain + "." firstDot := strings.Index(d, ".")
if strings.HasSuffix(d, suffix) { if firstDot > 0 {
dom = d[0 : len(d)-len(suffix)] dom = dom[0:firstDot]
} else {
dom = d
} }
fmt.Printf("%s\n", dom)
return dom return dom
} }