Refactored tests
This commit is contained in:
parent
b4cc6b8e81
commit
5072b231af
50
api_test.go
50
api_test.go
@ -9,28 +9,8 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func SetupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect {
|
||||
func setupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect {
|
||||
iris.ResetDefault()
|
||||
var dbcfg = dbsettings{
|
||||
Engine: "sqlite3",
|
||||
Connection: ":memory:"}
|
||||
var httpapicfg = httpapi{
|
||||
Domain: "",
|
||||
Port: "8080",
|
||||
TLS: "none",
|
||||
CorsOrigins: []string{"*"},
|
||||
}
|
||||
var dnscfg = DNSConfig{
|
||||
API: httpapicfg,
|
||||
Database: dbcfg,
|
||||
}
|
||||
DNSConf = dnscfg
|
||||
// In memory logger
|
||||
//logging.InitForTesting(logging.DEBUG)
|
||||
err := DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var ForceAuth = authMiddleware{}
|
||||
iris.Get("/register", webRegisterGet)
|
||||
iris.Post("/register", webRegisterPost)
|
||||
@ -45,8 +25,7 @@ func SetupIris(t *testing.T, debug bool, noauth bool) *httpexpect.Expect {
|
||||
}
|
||||
|
||||
func TestApiRegister(t *testing.T) {
|
||||
e := SetupIris(t, false, false)
|
||||
defer DB.DB.Close()
|
||||
e := setupIris(t, false, false)
|
||||
e.GET("/register").Expect().
|
||||
Status(iris.StatusCreated).
|
||||
JSON().Object().
|
||||
@ -66,22 +45,22 @@ func TestApiRegister(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApiRegisterWithMockDB(t *testing.T) {
|
||||
e := SetupIris(t, false, false)
|
||||
DB.DB.Close()
|
||||
e := setupIris(t, false, false)
|
||||
old_db := DB.DB
|
||||
db, mock, _ := sqlmock.New()
|
||||
DB.DB = db
|
||||
defer DB.DB.Close()
|
||||
defer db.Close()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
|
||||
e.GET("/register").Expect().
|
||||
Status(iris.StatusInternalServerError).
|
||||
JSON().Object().
|
||||
ContainsKey("error")
|
||||
DB.DB = old_db
|
||||
}
|
||||
|
||||
func TestApiUpdateWithoutCredentials(t *testing.T) {
|
||||
e := SetupIris(t, false, false)
|
||||
defer DB.DB.Close()
|
||||
e := setupIris(t, false, false)
|
||||
e.POST("/update").Expect().
|
||||
Status(iris.StatusUnauthorized).
|
||||
JSON().Object().
|
||||
@ -96,8 +75,7 @@ func TestApiUpdateWithCredentials(t *testing.T) {
|
||||
"subdomain": "",
|
||||
"txt": ""}
|
||||
|
||||
e := SetupIris(t, false, false)
|
||||
defer DB.DB.Close()
|
||||
e := setupIris(t, false, false)
|
||||
newUser, err := DB.Register()
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
@ -128,11 +106,11 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
||||
updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8"
|
||||
updateJSON["txt"] = validTxtData
|
||||
|
||||
e := SetupIris(t, false, true)
|
||||
DB.DB.Close()
|
||||
e := setupIris(t, false, true)
|
||||
old_db := DB.DB
|
||||
db, mock, _ := sqlmock.New()
|
||||
DB.DB = db
|
||||
defer DB.DB.Close()
|
||||
defer db.Close()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
|
||||
e.POST("/update").
|
||||
@ -141,19 +119,19 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
||||
Status(iris.StatusInternalServerError).
|
||||
JSON().Object().
|
||||
ContainsKey("error")
|
||||
DB.DB = old_db
|
||||
}
|
||||
|
||||
func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||
// TODO: transfer to using httpexpect builder
|
||||
// If test fails and more debug info is needed, use SetupIris(t, true, false)
|
||||
// If test fails and more debug info is needed, use setupIris(t, true, false)
|
||||
validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
|
||||
updateJSON := map[string]interface{}{
|
||||
"subdomain": "",
|
||||
"txt": ""}
|
||||
|
||||
e := SetupIris(t, false, false)
|
||||
defer DB.DB.Close()
|
||||
e := setupIris(t, false, false)
|
||||
newUser, err := DB.Register()
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
|
||||
14
db.go
14
db.go
@ -9,10 +9,12 @@ import (
|
||||
"github.com/satori/go.uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type database struct {
|
||||
sync.Mutex
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
@ -36,12 +38,14 @@ func getSQLiteStmt(s string) string {
|
||||
}
|
||||
|
||||
func (d *database) Init(engine string, connection string) error {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
db, err := sql.Open(engine, connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.DB = db
|
||||
d.DB.SetMaxOpenConns(1)
|
||||
//d.DB.SetMaxOpenConns(1)
|
||||
_, err = d.DB.Exec(recordsTable)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -50,6 +54,8 @@ func (d *database) Init(engine string, connection string) error {
|
||||
}
|
||||
|
||||
func (d *database) Register() (ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
a, err := newACMETxt()
|
||||
if err != nil {
|
||||
return ACMETxt{}, err
|
||||
@ -80,6 +86,8 @@ func (d *database) Register() (ACMETxt, error) {
|
||||
}
|
||||
|
||||
func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
var results []ACMETxt
|
||||
getSQL := `
|
||||
SELECT Username, Password, Subdomain, Value, LastActive
|
||||
@ -122,6 +130,8 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
}
|
||||
|
||||
func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
domain = sanitizeString(domain)
|
||||
var a []ACMETxt
|
||||
getSQL := `
|
||||
@ -156,6 +166,8 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
}
|
||||
|
||||
func (d *database) Update(a ACMETxt) error {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
// Data in a is already sanitized
|
||||
timenow := time.Now().Unix()
|
||||
updSQL := `
|
||||
|
||||
61
db_test.go
61
db_test.go
@ -1,29 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL")
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
flag.Parse()
|
||||
if *postgres {
|
||||
DNSConf.Database.Engine = "postgres"
|
||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
if err != nil {
|
||||
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
DNSConf.Database.Engine = "sqlite3"
|
||||
_ = DB.Init("sqlite3", ":memory:")
|
||||
}
|
||||
defer DB.DB.Close()
|
||||
|
||||
// Register tests
|
||||
_, err := DB.Register()
|
||||
if err != nil {
|
||||
@ -32,20 +13,6 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetByUsername(t *testing.T) {
|
||||
flag.Parse()
|
||||
if *postgres {
|
||||
DNSConf.Database.Engine = "postgres"
|
||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
if err != nil {
|
||||
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
DNSConf.Database.Engine = "sqlite3"
|
||||
_ = DB.Init("sqlite3", ":memory:")
|
||||
}
|
||||
defer DB.DB.Close()
|
||||
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register()
|
||||
if err != nil {
|
||||
@ -72,20 +39,6 @@ func TestGetByUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetByDomain(t *testing.T) {
|
||||
flag.Parse()
|
||||
if *postgres {
|
||||
DNSConf.Database.Engine = "postgres"
|
||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
if err != nil {
|
||||
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
DNSConf.Database.Engine = "sqlite3"
|
||||
_ = DB.Init("sqlite3", ":memory:")
|
||||
}
|
||||
defer DB.DB.Close()
|
||||
|
||||
var regDomain = ACMETxt{}
|
||||
|
||||
// Create reg to refer to
|
||||
@ -125,20 +78,6 @@ func TestGetByDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
flag.Parse()
|
||||
if *postgres {
|
||||
DNSConf.Database.Engine = "postgres"
|
||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
if err != nil {
|
||||
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
DNSConf.Database.Engine = "sqlite3"
|
||||
_ = DB.Init("sqlite3", ":memory:")
|
||||
}
|
||||
defer DB.DB.Close()
|
||||
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register()
|
||||
if err != nil {
|
||||
|
||||
11
dns.go
11
dns.go
@ -21,10 +21,9 @@ func readQuery(m *dns.Msg) {
|
||||
|
||||
func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
||||
var ra []dns.RR
|
||||
var rcode = dns.RcodeNameError
|
||||
var domain = strings.ToLower(q.Name)
|
||||
|
||||
atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain))
|
||||
rcode := dns.RcodeNameError
|
||||
subdomain := sanitizeDomainQuestion(q.Name)
|
||||
atxt, err := DB.GetByDomain(subdomain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
||||
return ra, dns.RcodeNameError, err
|
||||
@ -32,14 +31,14 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
||||
for _, v := range atxt {
|
||||
if len(v.Value) > 0 {
|
||||
r := new(dns.TXT)
|
||||
r.Hdr = dns.RR_Header{Name: domain, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||
r.Txt = append(r.Txt, v.Value)
|
||||
ra = append(ra, r)
|
||||
rcode = dns.RcodeSuccess
|
||||
}
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{"domain": domain}).Info("Answering TXT question for domain")
|
||||
log.WithFields(log.Fields{"domain": q.Name}).Info("Answering TXT question for domain")
|
||||
return ra, rcode, nil
|
||||
}
|
||||
|
||||
|
||||
36
dns_test.go
36
dns_test.go
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/miekg/dns"
|
||||
@ -75,27 +74,6 @@ func findRecordFromMemory(rrstr string, host string, qtype uint16) error {
|
||||
return errors.New(errmsg)
|
||||
}
|
||||
|
||||
func setupConfig() {
|
||||
var dbcfg = dbsettings{
|
||||
Engine: "sqlite3",
|
||||
Connection: ":memory:",
|
||||
}
|
||||
|
||||
var generalcfg = general{
|
||||
Domain: "auth.example.org",
|
||||
Nsname: "ns1.auth.example.org",
|
||||
Nsadmin: "admin.example.org",
|
||||
Debug: false,
|
||||
}
|
||||
|
||||
var dnscfg = DNSConfig{
|
||||
Database: dbcfg,
|
||||
General: generalcfg,
|
||||
}
|
||||
|
||||
DNSConf = dnscfg
|
||||
}
|
||||
|
||||
func startDNSServer(addr string) (*dns.Server, resolver) {
|
||||
|
||||
// DNS server part
|
||||
@ -111,7 +89,6 @@ func startDNSServer(addr string) (*dns.Server, resolver) {
|
||||
}
|
||||
|
||||
func TestResolveA(t *testing.T) {
|
||||
RR.Parse(records)
|
||||
setupConfig()
|
||||
answer, err := resolv.lookup("auth.example.org", dns.TypeA)
|
||||
if err != nil {
|
||||
@ -130,20 +107,7 @@ func TestResolveA(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolveTXT(t *testing.T) {
|
||||
flag.Parse()
|
||||
setupConfig()
|
||||
if *postgres {
|
||||
DNSConf.Database.Engine = "postgres"
|
||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
if err != nil {
|
||||
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
DNSConf.Database.Engine = "sqlite3"
|
||||
_ = DB.Init("sqlite3", ":memory:")
|
||||
}
|
||||
defer DB.DB.Close()
|
||||
|
||||
validTXT := "______________valid_response_______________"
|
||||
|
||||
|
||||
51
main_test.go
51
main_test.go
@ -2,15 +2,64 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL")
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
server, resolv = startDNSServer("0.0.0.0:15353")
|
||||
setupConfig()
|
||||
RR.Parse(records)
|
||||
flag.Parse()
|
||||
|
||||
if *postgres {
|
||||
DNSConf.Database.Engine = "postgres"
|
||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
if err != nil {
|
||||
fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
DNSConf.Database.Engine = "sqlite3"
|
||||
_ = DB.Init("sqlite3", ":memory:")
|
||||
}
|
||||
|
||||
server, resolv = startDNSServer("0.0.0.0:15353")
|
||||
exitval := m.Run()
|
||||
server.Shutdown()
|
||||
DB.DB.Close()
|
||||
os.Exit(exitval)
|
||||
}
|
||||
|
||||
func setupConfig() {
|
||||
var dbcfg = dbsettings{
|
||||
Engine: "sqlite3",
|
||||
Connection: ":memory:",
|
||||
}
|
||||
|
||||
var generalcfg = general{
|
||||
Domain: "auth.example.org",
|
||||
Nsname: "ns1.auth.example.org",
|
||||
Nsadmin: "admin.example.org",
|
||||
Debug: false,
|
||||
}
|
||||
|
||||
var httpapicfg = httpapi{
|
||||
Domain: "",
|
||||
Port: "8080",
|
||||
TLS: "none",
|
||||
CorsOrigins: []string{"*"},
|
||||
}
|
||||
|
||||
var dnscfg = DNSConfig{
|
||||
Database: dbcfg,
|
||||
General: generalcfg,
|
||||
API: httpapicfg,
|
||||
}
|
||||
|
||||
DNSConf = dnscfg
|
||||
}
|
||||
|
||||
12
util.go
12
util.go
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/BurntSushi/toml"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/satori/go.uuid"
|
||||
@ -45,13 +46,12 @@ func generatePassword(length int) (string, error) {
|
||||
}
|
||||
|
||||
func sanitizeDomainQuestion(d string) string {
|
||||
var dom string
|
||||
suffix := DNSConf.General.Domain + "."
|
||||
if strings.HasSuffix(d, suffix) {
|
||||
dom = d[0 : len(d)-len(suffix)]
|
||||
} else {
|
||||
dom = d
|
||||
dom := strings.ToLower(d)
|
||||
firstDot := strings.Index(d, ".")
|
||||
if firstDot > 0 {
|
||||
dom = dom[0:firstDot]
|
||||
}
|
||||
fmt.Printf("%s\n", dom)
|
||||
return dom
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user