Better DB test coverage
This commit is contained in:
parent
4615826267
commit
31a82e680f
19
db.go
19
db.go
@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
log "github.com/Sirupsen/logrus"
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
"github.com/satori/go.uuid"
|
"github.com/satori/go.uuid"
|
||||||
@ -23,11 +22,7 @@ var recordsTable = `
|
|||||||
|
|
||||||
// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?"
|
// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?"
|
||||||
func getSQLiteStmt(s string) string {
|
func getSQLiteStmt(s string) string {
|
||||||
re, err := regexp.Compile("\\$[0-9]")
|
re, _ := regexp.Compile("\\$[0-9]")
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error in regexp")
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return re.ReplaceAllString(s, "?")
|
return re.ReplaceAllString(s, "?")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,10 +45,7 @@ func (d *acmedb) Init(engine string, connection string) error {
|
|||||||
func (d *acmedb) Register() (ACMETxt, error) {
|
func (d *acmedb) Register() (ACMETxt, error) {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
a, err := newACMETxt()
|
a := newACMETxt()
|
||||||
if err != nil {
|
|
||||||
return ACMETxt{}, err
|
|
||||||
}
|
|
||||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
||||||
timenow := time.Now().Unix()
|
timenow := time.Now().Unix()
|
||||||
regSQL := `
|
regSQL := `
|
||||||
@ -106,12 +98,7 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
|||||||
// It will only be one row though
|
// It will only be one row though
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
a := ACMETxt{}
|
a := ACMETxt{}
|
||||||
var uname string
|
err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value, &a.LastActive)
|
||||||
err = rows.Scan(&uname, &a.Password, &a.Subdomain, &a.Value, &a.LastActive)
|
|
||||||
if err != nil {
|
|
||||||
return ACMETxt{}, err
|
|
||||||
}
|
|
||||||
a.Username, err = uuid.FromString(uname)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ACMETxt{}, err
|
return ACMETxt{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
167
db_test.go
167
db_test.go
@ -1,9 +1,46 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
|
"github.com/erikstmartin/go-testdb"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type testResult struct {
|
||||||
|
lastID int64
|
||||||
|
affectedRows int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r testResult) LastInsertId() (int64, error) {
|
||||||
|
return r.lastID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r testResult) RowsAffected() (int64, error) {
|
||||||
|
return r.affectedRows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDBInit(t *testing.T) {
|
||||||
|
fakeDB := new(acmedb)
|
||||||
|
err := fakeDB.Init("notarealegine", "connectionstring")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Was expecting error, didn't get one.")
|
||||||
|
}
|
||||||
|
|
||||||
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||||
|
return testResult{1, 0}, errors.New("Prepared query error")
|
||||||
|
})
|
||||||
|
defer testdb.Reset()
|
||||||
|
|
||||||
|
errorDB := new(acmedb)
|
||||||
|
err = errorDB.Init("testdb", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Was expecting DB initiation error but got none")
|
||||||
|
}
|
||||||
|
errorDB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func TestRegister(t *testing.T) {
|
func TestRegister(t *testing.T) {
|
||||||
// Register tests
|
// Register tests
|
||||||
_, err := DB.Register()
|
_, err := DB.Register()
|
||||||
@ -38,6 +75,136 @@ func TestGetByUsername(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPrepareErrors(t *testing.T) {
|
||||||
|
reg, _ := DB.Register()
|
||||||
|
tdb, err := sql.Open("testdb", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Got error: %v", err)
|
||||||
|
}
|
||||||
|
oldDb := DB.GetBackend()
|
||||||
|
DB.SetBackend(tdb)
|
||||||
|
defer DB.SetBackend(oldDb)
|
||||||
|
defer testdb.Reset()
|
||||||
|
|
||||||
|
_, err = DB.GetByUsername(reg.Username)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error, but didn't get one")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = DB.GetByDomain(reg.Subdomain)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error, but didn't get one")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryExecErrors(t *testing.T) {
|
||||||
|
reg, _ := DB.Register()
|
||||||
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||||
|
return testResult{1, 0}, errors.New("Prepared query error")
|
||||||
|
})
|
||||||
|
|
||||||
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||||
|
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||||
|
return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error")
|
||||||
|
})
|
||||||
|
|
||||||
|
defer testdb.Reset()
|
||||||
|
|
||||||
|
tdb, err := sql.Open("testdb", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Got error: %v", err)
|
||||||
|
}
|
||||||
|
oldDb := DB.GetBackend()
|
||||||
|
|
||||||
|
DB.SetBackend(tdb)
|
||||||
|
defer DB.SetBackend(oldDb)
|
||||||
|
|
||||||
|
_, err = DB.GetByUsername(reg.Username)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from exec, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = DB.GetByDomain(reg.Subdomain)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = DB.Register()
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from exec in Register, but got none")
|
||||||
|
}
|
||||||
|
reg.Value = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
|
err = DB.Update(reg)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from exec in Update, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryScanErrors(t *testing.T) {
|
||||||
|
reg, _ := DB.Register()
|
||||||
|
|
||||||
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||||
|
return testResult{1, 0}, errors.New("Prepared query error")
|
||||||
|
})
|
||||||
|
|
||||||
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||||
|
columns := []string{"Only one"}
|
||||||
|
resultrows := "this value"
|
||||||
|
return testdb.RowsFromCSVString(columns, resultrows), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
defer testdb.Reset()
|
||||||
|
tdb, err := sql.Open("testdb", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Got error: %v", err)
|
||||||
|
}
|
||||||
|
oldDb := DB.GetBackend()
|
||||||
|
|
||||||
|
DB.SetBackend(tdb)
|
||||||
|
defer DB.SetBackend(oldDb)
|
||||||
|
|
||||||
|
_, err = DB.GetByUsername(reg.Username)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from scan in, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = DB.GetByDomain(reg.Subdomain)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBadDBValues(t *testing.T) {
|
||||||
|
reg, _ := DB.Register()
|
||||||
|
|
||||||
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||||
|
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||||
|
resultrows := "invalid,invalid,invalid,invalid,"
|
||||||
|
return testdb.RowsFromCSVString(columns, resultrows), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
defer testdb.Reset()
|
||||||
|
tdb, err := sql.Open("testdb", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Got error: %v", err)
|
||||||
|
}
|
||||||
|
oldDb := DB.GetBackend()
|
||||||
|
|
||||||
|
DB.SetBackend(tdb)
|
||||||
|
defer DB.SetBackend(oldDb)
|
||||||
|
|
||||||
|
_, err = DB.GetByUsername(reg.Username)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from scan in, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = DB.GetByDomain(reg.Subdomain)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetByDomain(t *testing.T) {
|
func TestGetByDomain(t *testing.T) {
|
||||||
var regDomain = ACMETxt{}
|
var regDomain = ACMETxt{}
|
||||||
|
|
||||||
|
|||||||
4
util.go
4
util.go
@ -50,13 +50,13 @@ func sanitizeDomainQuestion(d string) string {
|
|||||||
return dom
|
return dom
|
||||||
}
|
}
|
||||||
|
|
||||||
func newACMETxt() (ACMETxt, error) {
|
func newACMETxt() ACMETxt {
|
||||||
var a = ACMETxt{}
|
var a = ACMETxt{}
|
||||||
password := generatePassword(40)
|
password := generatePassword(40)
|
||||||
a.Username = uuid.NewV4()
|
a.Username = uuid.NewV4()
|
||||||
a.Password = password
|
a.Password = password
|
||||||
a.Subdomain = uuid.NewV4().String()
|
a.Subdomain = uuid.NewV4().String()
|
||||||
return a, nil
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupLogging(format string, level string) {
|
func setupLogging(format string, level string) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user