diff --git a/api.go b/api.go index 01ffa92..1c7dec8 100644 --- a/api.go +++ b/api.go @@ -24,10 +24,10 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) { password := ctx.RequestHeader("X-Api-Key") postData := ACMETxt{} - username, err := GetValidUsername(usernameStr) - if err == nil && ValidKey(password) { + username, err := getValidUsername(usernameStr) + if err == nil && validKey(password) { au, err := DB.GetByUsername(username) - if err == nil && CorrectPassword(password, au.Password) { + if err == nil && correctPassword(password, au.Password) { // Password ok if err := ctx.ReadJSON(&postData); err == nil { // Check that the subdomain belongs to the user @@ -39,7 +39,7 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) { } } // To protect against timed side channel (never gonna give you up) - CorrectPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") + correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") } ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"}) } @@ -72,7 +72,7 @@ func WebUpdatePost(ctx *iris.Context) { // User auth done in middleware a := ACMETxt{} userStr := ctx.RequestHeader("X-API-User") - username, err := GetValidUsername(userStr) + username, err := getValidUsername(userStr) if err != nil { log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr) WebUpdatePostError(ctx, err, iris.StatusUnauthorized) @@ -86,7 +86,7 @@ func WebUpdatePost(ctx *iris.Context) { } a.Username = username // Do update - if ValidSubdomain(a.Subdomain) && ValidTXT(a.Value) { + if validSubdomain(a.Subdomain) && validTXT(a.Value) { err := DB.Update(a) if err != nil { log.Warningf("Error trying to update [%v]", err) diff --git a/db.go b/db.go index acf1db8..e41a232 100644 --- a/db.go +++ b/db.go @@ -49,7 +49,7 @@ func (d *Database) Init(engine string, connection string) error { } func (d *Database) Register() (ACMETxt, error) { - a, err := NewACMETxt() + a, err := newACMETxt() if err != nil { return ACMETxt{}, err } @@ -121,7 +121,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) { } func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { - domain = SanitizeString(domain) + domain = sanitizeString(domain) log.Debugf("Trying to select domain [%s] from table", domain) var a []ACMETxt getSQL := ` diff --git a/db_test.go b/db_test.go index a4da146..7e245ca 100644 --- a/db_test.go +++ b/db_test.go @@ -66,7 +66,7 @@ func TestGetByUsername(t *testing.T) { } // regUser password already is a bcrypt hash - if !CorrectPassword(reg.Password, regUser.Password) { + if !correctPassword(reg.Password, regUser.Password) { t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password) } } @@ -113,7 +113,7 @@ func TestGetByDomain(t *testing.T) { } // regDomain password already is a bcrypt hash - if !CorrectPassword(reg.Password, regDomain.Password) { + if !correctPassword(reg.Password, regDomain.Password) { t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password) } diff --git a/dns.go b/dns.go index 57f5a6b..9a665dd 100644 --- a/dns.go +++ b/dns.go @@ -23,7 +23,7 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) { var rcode int = dns.RcodeNameError var domain = strings.ToLower(q.Name) - atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain)) + atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain)) if err != nil { log.Errorf("Error while trying to get record [%v]", err) return ra, dns.RcodeNameError, err diff --git a/main.go b/main.go index b648ef9..6fac7e8 100644 --- a/main.go +++ b/main.go @@ -28,36 +28,8 @@ func main() { os.Exit(1) } DNSConf = configTmp - // Setup logging - var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format) - var logBackend *logging.LogBackend - switch DNSConf.Logconfig.Logtype { - default: - // Setup logging - stdout - logBackend = logging.NewLogBackend(os.Stdout, "", 0) - case "file": - // Logging to file - logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File) - os.Exit(1) - } - defer logfh.Close() - logBackend = logging.NewLogBackend(logfh, "", 0) - } - logFormatter := logging.NewBackendFormatter(logBackend, logformat) - logLevel := logging.AddModuleLevel(logFormatter) - switch DNSConf.Logconfig.Level { - default: - logLevel.SetLevel(logging.DEBUG, "") - case "warning": - logLevel.SetLevel(logging.WARNING, "") - case "error": - logLevel.SetLevel(logging.ERROR, "") - case "info": - logLevel.SetLevel(logging.INFO, "") - } - logging.SetBackend(logFormatter) + + setupLogging() // Read the default records in RR.Parse(DNSConf.General.StaticRecords) diff --git a/util.go b/util.go index 938c179..9e14f90 100644 --- a/util.go +++ b/util.go @@ -3,9 +3,12 @@ package main import ( "crypto/rand" "errors" + "fmt" "github.com/BurntSushi/toml" + "github.com/op/go-logging" "github.com/satori/go.uuid" "math/big" + "os" "regexp" "strings" ) @@ -18,7 +21,7 @@ func readConfig(fname string) (DNSConfig, error) { return conf, nil } -func SanitizeString(s string) string { +func sanitizeString(s string) string { // URL safe base64 alphabet without padding as defined in ACME re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+") if err != nil { @@ -28,7 +31,7 @@ func SanitizeString(s string) string { return re.ReplaceAllString(s, "") } -func GeneratePassword(length int) (string, error) { +func generatePassword(length int) (string, error) { ret := make([]byte, length) const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_" alphalen := big.NewInt(int64(len(alphabet))) @@ -43,7 +46,7 @@ func GeneratePassword(length int) (string, error) { return string(ret), nil } -func SanitizeDomainQuestion(d string) string { +func sanitizeDomainQuestion(d string) string { var dom string suffix := DNSConf.General.Domain + "." if strings.HasSuffix(d, suffix) { @@ -54,9 +57,9 @@ func SanitizeDomainQuestion(d string) string { return dom } -func NewACMETxt() (ACMETxt, error) { +func newACMETxt() (ACMETxt, error) { var a = ACMETxt{} - password, err := GeneratePassword(40) + password, err := generatePassword(40) if err != nil { return a, err } @@ -65,3 +68,36 @@ func NewACMETxt() (ACMETxt, error) { a.Subdomain = uuid.NewV4().String() return a, nil } + +func setupLogging() { + var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format) + var logBackend *logging.LogBackend + switch DNSConf.Logconfig.Logtype { + default: + // Setup logging - stdout + logBackend = logging.NewLogBackend(os.Stdout, "", 0) + case "file": + // Logging to file + logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File) + os.Exit(1) + } + defer logfh.Close() + logBackend = logging.NewLogBackend(logfh, "", 0) + } + logFormatter := logging.NewBackendFormatter(logBackend, logformat) + logLevel := logging.AddModuleLevel(logFormatter) + switch DNSConf.Logconfig.Level { + default: + logLevel.SetLevel(logging.DEBUG, "") + case "warning": + logLevel.SetLevel(logging.WARNING, "") + case "error": + logLevel.SetLevel(logging.ERROR, "") + case "info": + logLevel.SetLevel(logging.INFO, "") + } + logging.SetBackend(logFormatter) + +} diff --git a/validation.go b/validation.go index 589012c..036b9bc 100644 --- a/validation.go +++ b/validation.go @@ -6,7 +6,7 @@ import ( "unicode/utf8" ) -func GetValidUsername(u string) (uuid.UUID, error) { +func getValidUsername(u string) (uuid.UUID, error) { uname, err := uuid.FromString(u) if err != nil { return uuid.UUID{}, err @@ -14,8 +14,8 @@ func GetValidUsername(u string) (uuid.UUID, error) { return uname, nil } -func ValidKey(k string) bool { - kn := SanitizeString(k) +func validKey(k string) bool { + kn := sanitizeString(k) if utf8.RuneCountInString(k) == 40 && utf8.RuneCountInString(kn) == 40 { // Correct length and all chars valid return true @@ -23,7 +23,7 @@ func ValidKey(k string) bool { return false } -func ValidSubdomain(s string) bool { +func validSubdomain(s string) bool { _, err := uuid.FromString(s) if err == nil { return true @@ -31,8 +31,8 @@ func ValidSubdomain(s string) bool { return false } -func ValidTXT(s string) bool { - sn := SanitizeString(s) +func validTXT(s string) bool { + sn := sanitizeString(s) if utf8.RuneCountInString(s) == 43 && utf8.RuneCountInString(sn) == 43 { // 43 chars is the current LE auth key size, but not limited / defined by ACME return true @@ -40,7 +40,7 @@ func ValidTXT(s string) bool { return false } -func CorrectPassword(pw string, hash string) bool { +func correctPassword(pw string, hash string) bool { if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil { return true } diff --git a/validation_test.go b/validation_test.go index 1869e6e..58801ca 100644 --- a/validation_test.go +++ b/validation_test.go @@ -17,7 +17,7 @@ func TestGetValidUsername(t *testing.T) { {"", uuid.UUID{}, true}, {"&!#!25123!%!'%", uuid.UUID{}, true}, } { - ret, err := GetValidUsername(test.uname) + ret, err := getValidUsername(test.uname) if test.shouldErr && err == nil { t.Errorf("Test %d: Expected error, but there was none", i) } @@ -41,7 +41,7 @@ func TestValidKey(t *testing.T) { {"aaaaaaaa-aaa-aaaaaa#aaaaaaaa-aaa_aacaaaa", false}, {"aaaaaaaa-aaa-aaaaaa-aaaaaaaa-aaa_aacaaaaa", false}, } { - ret := ValidKey(test.key) + ret := validKey(test.key) if ret != test.output { t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) } @@ -58,7 +58,7 @@ func TestGetValidSubdomain(t *testing.T) { {"", false}, {"&!#!25123!%!'%", false}, } { - ret := ValidSubdomain(test.subdomain) + ret := validSubdomain(test.subdomain) if ret != test.output { t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) } @@ -76,7 +76,7 @@ func TestValidTXT(t *testing.T) { {"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false}, {"", false}, } { - ret := ValidTXT(test.txt) + ret := validTXT(test.txt) if ret != test.output { t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) } @@ -100,7 +100,7 @@ func TestCorrectPassword(t *testing.T) { false}, {"", "", false}, } { - ret := CorrectPassword(test.pw, test.hash) + ret := correctPassword(test.pw, test.hash) if ret != test.output { t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) }