diff --git a/acmetxt.go b/acmetxt.go new file mode 100644 index 0000000..a4a23ed --- /dev/null +++ b/acmetxt.go @@ -0,0 +1,51 @@ +package main + +import ( + "encoding/json" + "net" + + "github.com/satori/go.uuid" +) + +// ACMETxt is the default structure for the user controlled record +type ACMETxt struct { + Username uuid.UUID + Password string + ACMETxtPost + LastActive int64 + AllowFrom cidrslice +} + +// ACMETxtPost holds the DNS part of the ACMETxt struct +type ACMETxtPost struct { + Subdomain string `json:"subdomain"` + Value string `json:"txt"` +} + +// cidrslice is a list of allowed cidr ranges +type cidrslice []string + +func (c *cidrslice) JSON() string { + ret, _ := json.Marshal(c.ValidEntries()) + return string(ret) +} + +func (c *cidrslice) ValidEntries() []string { + valid := []string{} + for _, v := range *c { + _, _, err := net.ParseCIDR(v) + if err == nil { + valid = append(valid, v) + } + } + return valid +} + +func newACMETxt() ACMETxt { + var a = ACMETxt{} + password := generatePassword(40) + a.Username = uuid.NewV4() + a.Password = password + a.Subdomain = uuid.NewV4().String() + return a +} diff --git a/api.go b/api.go index 7226ef9..5e78321 100644 --- a/api.go +++ b/api.go @@ -16,28 +16,36 @@ func (a authMiddleware) Serve(ctx *iris.Context) { username, err := getValidUsername(usernameStr) if err == nil && validKey(password) { au, err := DB.GetByUsername(username) - if err == nil && correctPassword(password, au.Password) { - // Password ok - if err := ctx.ReadJSON(&postData); err == nil { - // Check that the subdomain belongs to the user - if au.Subdomain == postData.Subdomain { - ctx.Next() + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("Error while trying to get user") + // To protect against timed side channel (never gonna give you up) + correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") + } else { + if correctPassword(password, au.Password) { + // Password ok + if err := ctx.ReadJSON(&postData); err == nil { + // Check that the subdomain belongs to the user + if au.Subdomain == postData.Subdomain { + ctx.Next() + return + } + } else { + // JSON error + ctx.JSON(iris.StatusBadRequest, iris.Map{"error": "bad data"}) return } } else { - ctx.JSON(iris.StatusBadRequest, iris.Map{"error": "bad data"}) - return + // Wrong password + log.WithFields(log.Fields{"username": username}).Warning("Failed password check") } } - // To protect against timed side channel (never gonna give you up) - correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") } ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"}) } func webRegisterPost(ctx *iris.Context) { // Create new user - nu, err := DB.Register() + nu, err := DB.Register(cidrslice{}) var regJSON iris.Map var regStatus int if err != nil { diff --git a/api_test.go b/api_test.go index b067f14..f69f6bf 100644 --- a/api_test.go +++ b/api_test.go @@ -90,7 +90,7 @@ func TestApiUpdateWithCredentials(t *testing.T) { "txt": ""} e := setupIris(t, false, false) - newUser, err := DB.Register() + newUser, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Could not create new user, got error [%v]", err) } @@ -146,7 +146,7 @@ func TestApiManyUpdateWithCredentials(t *testing.T) { "txt": ""} e := setupIris(t, false, false) - newUser, err := DB.Register() + newUser, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Could not create new user, got error [%v]", err) } @@ -164,6 +164,7 @@ func TestApiManyUpdateWithCredentials(t *testing.T) { {newUser.Username.String(), newUser.Password, newUser.Subdomain, "tooshortfortxt", 400}, {newUser.Username.String(), newUser.Password, newUser.Subdomain, 1234567890, 400}, {newUser.Username.String(), newUser.Password, newUser.Subdomain, validTxtData, 200}, + {newUser.Username.String(), "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", newUser.Subdomain, validTxtData, 401}, } { updateJSON = map[string]interface{}{ "subdomain": test.subdomain, diff --git a/db.go b/db.go index 868cf6e..3cac758 100644 --- a/db.go +++ b/db.go @@ -2,13 +2,16 @@ package main import ( "database/sql" + "encoding/json" "errors" + "regexp" + "time" + + log "github.com/Sirupsen/logrus" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/satori/go.uuid" "golang.org/x/crypto/bcrypt" - "regexp" - "time" ) var recordsTable = ` @@ -43,10 +46,11 @@ func (d *acmedb) Init(engine string, connection string) error { return nil } -func (d *acmedb) Register() (ACMETxt, error) { +func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) { d.Lock() defer d.Unlock() a := newACMETxt() + a.AllowFrom = cidrslice(afrom.ValidEntries()) passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10) timenow := time.Now().Unix() regSQL := ` @@ -63,10 +67,11 @@ func (d *acmedb) Register() (ACMETxt, error) { } sm, err := d.DB.Prepare(regSQL) if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare") return a, errors.New("SQL error") } defer sm.Close() - _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom) + _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON()) if err != nil { return a, err } @@ -173,13 +178,24 @@ func (d *acmedb) Update(a ACMETxt) error { func getModelFromRow(r *sql.Rows) (ACMETxt, error) { txt := ACMETxt{} + afrom := "" err := r.Scan( &txt.Username, &txt.Password, &txt.Subdomain, &txt.Value, &txt.LastActive, - &txt.AllowFrom) + &afrom) + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error") + } + + cslice := cidrslice{} + err = json.Unmarshal([]byte(afrom), &cslice) + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("JSON unmarshall error") + } + txt.AllowFrom = cslice return txt, err } diff --git a/db_test.go b/db_test.go index 665bed3..4580a34 100644 --- a/db_test.go +++ b/db_test.go @@ -41,17 +41,44 @@ func TestDBInit(t *testing.T) { errorDB.Close() } -func TestRegister(t *testing.T) { +func TestRegisterNoCIDR(t *testing.T) { // Register tests - _, err := DB.Register() + _, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } } +func TestRegisterMany(t *testing.T) { + for i, test := range []struct { + input cidrslice + output cidrslice + }{ + {cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}}, + {cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}}, + {cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}}, + } { + user, err := DB.Register(test.input) + if err != nil { + t.Errorf("Test %d: Got error from register method: [%v]", i, err) + } + res, err := DB.GetByUsername(user.Username) + if err != nil { + t.Errorf("Test %d: Got error when fetching username: [%v]", i, err) + } + if len(user.AllowFrom) != len(test.output) { + t.Errorf("Test %d: Expected to recieve struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(user.AllowFrom)) + } + if len(res.AllowFrom) != len(test.output) { + t.Errorf("Test %d: Expected to recieve struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(res.AllowFrom)) + } + + } +} + func TestGetByUsername(t *testing.T) { // Create reg to refer to - reg, err := DB.Register() + reg, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } @@ -76,7 +103,7 @@ func TestGetByUsername(t *testing.T) { } func TestPrepareErrors(t *testing.T) { - reg, _ := DB.Register() + reg, _ := DB.Register(cidrslice{}) tdb, err := sql.Open("testdb", "") if err != nil { t.Errorf("Got error: %v", err) @@ -98,7 +125,7 @@ func TestPrepareErrors(t *testing.T) { } func TestQueryExecErrors(t *testing.T) { - reg, _ := DB.Register() + reg, _ := DB.Register(cidrslice{}) testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { return testResult{1, 0}, errors.New("Prepared query error") }) @@ -129,7 +156,7 @@ func TestQueryExecErrors(t *testing.T) { t.Errorf("Expected error from exec in GetByDomain, but got none") } - _, err = DB.Register() + _, err = DB.Register(cidrslice{}) if err == nil { t.Errorf("Expected error from exec in Register, but got none") } @@ -142,7 +169,7 @@ func TestQueryExecErrors(t *testing.T) { } func TestQueryScanErrors(t *testing.T) { - reg, _ := DB.Register() + reg, _ := DB.Register(cidrslice{}) testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { return testResult{1, 0}, errors.New("Prepared query error") @@ -176,7 +203,7 @@ func TestQueryScanErrors(t *testing.T) { } func TestBadDBValues(t *testing.T) { - reg, _ := DB.Register() + reg, _ := DB.Register(cidrslice{}) testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} @@ -209,7 +236,7 @@ func TestGetByDomain(t *testing.T) { var regDomain = ACMETxt{} // Create reg to refer to - reg, err := DB.Register() + reg, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } @@ -246,7 +273,7 @@ func TestGetByDomain(t *testing.T) { func TestUpdate(t *testing.T) { // Create reg to refer to - reg, err := DB.Register() + reg, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } diff --git a/dns_test.go b/dns_test.go index e550ec8..63c2701 100644 --- a/dns_test.go +++ b/dns_test.go @@ -139,7 +139,7 @@ func TestResolveTXT(t *testing.T) { resolv := resolver{server: "0.0.0.0:15353"} validTXT := "______________valid_response_______________" - atxt, err := DB.Register() + atxt, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Could not initiate db record: [%v]", err) return diff --git a/types.go b/types.go index c6bdf46..0dc3c79 100644 --- a/types.go +++ b/types.go @@ -66,21 +66,6 @@ type logconfig struct { Format string `toml:"logformat"` } -// ACMETxt is the default structure for the user controlled record -type ACMETxt struct { - Username uuid.UUID - Password string - ACMETxtPost - LastActive int64 - AllowFrom string -} - -// ACMETxtPost holds the DNS part of the ACMETxt struct -type ACMETxtPost struct { - Subdomain string `json:"subdomain"` - Value string `json:"txt"` -} - type acmedb struct { sync.Mutex DB *sql.DB @@ -88,7 +73,7 @@ type acmedb struct { type database interface { Init(string, string) error - Register() (ACMETxt, error) + Register(cidrslice) (ACMETxt, error) GetByUsername(uuid.UUID) (ACMETxt, error) GetByDomain(string) ([]ACMETxt, error) Update(ACMETxt) error diff --git a/util.go b/util.go index 7f51d06..e3d76c0 100644 --- a/util.go +++ b/util.go @@ -5,7 +5,6 @@ import ( "github.com/BurntSushi/toml" log "github.com/Sirupsen/logrus" "github.com/miekg/dns" - "github.com/satori/go.uuid" "math/big" "regexp" "strings" @@ -45,15 +44,6 @@ func sanitizeDomainQuestion(d string) string { return dom } -func newACMETxt() ACMETxt { - var a = ACMETxt{} - password := generatePassword(40) - a.Username = uuid.NewV4() - a.Password = password - a.Subdomain = uuid.NewV4().String() - return a -} - func setupLogging(format string, level string) { if format == "json" { log.SetFormatter(&log.JSONFormatter{}) diff --git a/validation.go b/validation.go index 036b9bc..66e1626 100644 --- a/validation.go +++ b/validation.go @@ -1,9 +1,10 @@ package main import ( + "unicode/utf8" + "github.com/satori/go.uuid" "golang.org/x/crypto/bcrypt" - "unicode/utf8" ) func getValidUsername(u string) (uuid.UUID, error) { diff --git a/validation_test.go b/validation_test.go index 58801ca..e6e3589 100644 --- a/validation_test.go +++ b/validation_test.go @@ -106,3 +106,24 @@ func TestCorrectPassword(t *testing.T) { } } } + +func TestGetValidCIDRMasks(t *testing.T) { + for i, test := range []struct { + input cidrslice + output cidrslice + }{ + {cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}}, + {cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}}, + } { + ret := test.input.ValidEntries() + if len(ret) == len(test.output) { + for i, v := range ret { + if v != test.output[i] { + t.Errorf("Test %d: Expected %q but got %q", i, test.output, ret) + } + } + } else { + t.Errorf("Test %d: Expected %q but got %q", i, test.output, ret) + } + } +}