From 31a82e680f934d388a0508d715d97be6651ffaae Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Mon, 28 Nov 2016 01:55:57 +0200 Subject: [PATCH] Better DB test coverage --- db.go | 19 +----- db_test.go | 167 +++++++++++++++++++++++++++++++++++++++++++++++++++++ util.go | 4 +- 3 files changed, 172 insertions(+), 18 deletions(-) diff --git a/db.go b/db.go index 85348d8..aaa6cdb 100644 --- a/db.go +++ b/db.go @@ -3,7 +3,6 @@ package main import ( "database/sql" "errors" - log "github.com/Sirupsen/logrus" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/satori/go.uuid" @@ -23,11 +22,7 @@ var recordsTable = ` // getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?" func getSQLiteStmt(s string) string { - re, err := regexp.Compile("\\$[0-9]") - if err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Debug("Error in regexp") - return s - } + re, _ := regexp.Compile("\\$[0-9]") return re.ReplaceAllString(s, "?") } @@ -50,10 +45,7 @@ func (d *acmedb) Init(engine string, connection string) error { func (d *acmedb) Register() (ACMETxt, error) { d.Lock() defer d.Unlock() - a, err := newACMETxt() - if err != nil { - return ACMETxt{}, err - } + a := newACMETxt() passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10) timenow := time.Now().Unix() regSQL := ` @@ -106,12 +98,7 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) { // It will only be one row though for rows.Next() { a := ACMETxt{} - var uname string - err = rows.Scan(&uname, &a.Password, &a.Subdomain, &a.Value, &a.LastActive) - if err != nil { - return ACMETxt{}, err - } - a.Username, err = uuid.FromString(uname) + err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value, &a.LastActive) if err != nil { return ACMETxt{}, err } diff --git a/db_test.go b/db_test.go index 376a339..665bed3 100644 --- a/db_test.go +++ b/db_test.go @@ -1,9 +1,46 @@ package main import ( + "database/sql" + "database/sql/driver" + "errors" + "github.com/erikstmartin/go-testdb" "testing" ) +type testResult struct { + lastID int64 + affectedRows int64 +} + +func (r testResult) LastInsertId() (int64, error) { + return r.lastID, nil +} + +func (r testResult) RowsAffected() (int64, error) { + return r.affectedRows, nil +} + +func TestDBInit(t *testing.T) { + fakeDB := new(acmedb) + err := fakeDB.Init("notarealegine", "connectionstring") + if err == nil { + t.Errorf("Was expecting error, didn't get one.") + } + + testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { + return testResult{1, 0}, errors.New("Prepared query error") + }) + defer testdb.Reset() + + errorDB := new(acmedb) + err = errorDB.Init("testdb", "") + if err == nil { + t.Errorf("Was expecting DB initiation error but got none") + } + errorDB.Close() +} + func TestRegister(t *testing.T) { // Register tests _, err := DB.Register() @@ -38,6 +75,136 @@ func TestGetByUsername(t *testing.T) { } } +func TestPrepareErrors(t *testing.T) { + reg, _ := DB.Register() + tdb, err := sql.Open("testdb", "") + if err != nil { + t.Errorf("Got error: %v", err) + } + oldDb := DB.GetBackend() + DB.SetBackend(tdb) + defer DB.SetBackend(oldDb) + defer testdb.Reset() + + _, err = DB.GetByUsername(reg.Username) + if err == nil { + t.Errorf("Expected error, but didn't get one") + } + + _, err = DB.GetByDomain(reg.Subdomain) + if err == nil { + t.Errorf("Expected error, but didn't get one") + } +} + +func TestQueryExecErrors(t *testing.T) { + reg, _ := DB.Register() + testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { + return testResult{1, 0}, errors.New("Prepared query error") + }) + + testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { + columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} + return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error") + }) + + defer testdb.Reset() + + tdb, err := sql.Open("testdb", "") + if err != nil { + t.Errorf("Got error: %v", err) + } + oldDb := DB.GetBackend() + + DB.SetBackend(tdb) + defer DB.SetBackend(oldDb) + + _, err = DB.GetByUsername(reg.Username) + if err == nil { + t.Errorf("Expected error from exec, but got none") + } + + _, err = DB.GetByDomain(reg.Subdomain) + if err == nil { + t.Errorf("Expected error from exec in GetByDomain, but got none") + } + + _, err = DB.Register() + if err == nil { + t.Errorf("Expected error from exec in Register, but got none") + } + reg.Value = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + err = DB.Update(reg) + if err == nil { + t.Errorf("Expected error from exec in Update, but got none") + } + +} + +func TestQueryScanErrors(t *testing.T) { + reg, _ := DB.Register() + + testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { + return testResult{1, 0}, errors.New("Prepared query error") + }) + + testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { + columns := []string{"Only one"} + resultrows := "this value" + return testdb.RowsFromCSVString(columns, resultrows), nil + }) + + defer testdb.Reset() + tdb, err := sql.Open("testdb", "") + if err != nil { + t.Errorf("Got error: %v", err) + } + oldDb := DB.GetBackend() + + DB.SetBackend(tdb) + defer DB.SetBackend(oldDb) + + _, err = DB.GetByUsername(reg.Username) + if err == nil { + t.Errorf("Expected error from scan in, but got none") + } + + _, err = DB.GetByDomain(reg.Subdomain) + if err == nil { + t.Errorf("Expected error from scan in GetByDomain, but got none") + } +} + +func TestBadDBValues(t *testing.T) { + reg, _ := DB.Register() + + testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { + columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} + resultrows := "invalid,invalid,invalid,invalid," + return testdb.RowsFromCSVString(columns, resultrows), nil + }) + + defer testdb.Reset() + tdb, err := sql.Open("testdb", "") + if err != nil { + t.Errorf("Got error: %v", err) + } + oldDb := DB.GetBackend() + + DB.SetBackend(tdb) + defer DB.SetBackend(oldDb) + + _, err = DB.GetByUsername(reg.Username) + if err == nil { + t.Errorf("Expected error from scan in, but got none") + } + + _, err = DB.GetByDomain(reg.Subdomain) + if err == nil { + t.Errorf("Expected error from scan in GetByDomain, but got none") + } +} + func TestGetByDomain(t *testing.T) { var regDomain = ACMETxt{} diff --git a/util.go b/util.go index 200719d..4898d7d 100644 --- a/util.go +++ b/util.go @@ -50,13 +50,13 @@ func sanitizeDomainQuestion(d string) string { return dom } -func newACMETxt() (ACMETxt, error) { +func newACMETxt() ACMETxt { var a = ACMETxt{} password := generatePassword(40) a.Username = uuid.NewV4() a.Password = password a.Subdomain = uuid.NewV4().String() - return a, nil + return a } func setupLogging(format string, level string) {