Refactoring

This commit is contained in:
Joona Hoikkala 2016-11-23 17:11:31 +02:00
parent f32c4940e1
commit ba63bad793
8 changed files with 66 additions and 58 deletions

12
api.go
View File

@ -24,10 +24,10 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
password := ctx.RequestHeader("X-Api-Key") password := ctx.RequestHeader("X-Api-Key")
postData := ACMETxt{} postData := ACMETxt{}
username, err := GetValidUsername(usernameStr) username, err := getValidUsername(usernameStr)
if err == nil && ValidKey(password) { if err == nil && validKey(password) {
au, err := DB.GetByUsername(username) au, err := DB.GetByUsername(username)
if err == nil && CorrectPassword(password, au.Password) { if err == nil && correctPassword(password, au.Password) {
// Password ok // Password ok
if err := ctx.ReadJSON(&postData); err == nil { if err := ctx.ReadJSON(&postData); err == nil {
// Check that the subdomain belongs to the user // Check that the subdomain belongs to the user
@ -39,7 +39,7 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
} }
} }
// To protect against timed side channel (never gonna give you up) // To protect against timed side channel (never gonna give you up)
CorrectPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
} }
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"}) ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
} }
@ -72,7 +72,7 @@ func WebUpdatePost(ctx *iris.Context) {
// User auth done in middleware // User auth done in middleware
a := ACMETxt{} a := ACMETxt{}
userStr := ctx.RequestHeader("X-API-User") userStr := ctx.RequestHeader("X-API-User")
username, err := GetValidUsername(userStr) username, err := getValidUsername(userStr)
if err != nil { if err != nil {
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr) log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr)
WebUpdatePostError(ctx, err, iris.StatusUnauthorized) WebUpdatePostError(ctx, err, iris.StatusUnauthorized)
@ -86,7 +86,7 @@ func WebUpdatePost(ctx *iris.Context) {
} }
a.Username = username a.Username = username
// Do update // Do update
if ValidSubdomain(a.Subdomain) && ValidTXT(a.Value) { if validSubdomain(a.Subdomain) && validTXT(a.Value) {
err := DB.Update(a) err := DB.Update(a)
if err != nil { if err != nil {
log.Warningf("Error trying to update [%v]", err) log.Warningf("Error trying to update [%v]", err)

4
db.go
View File

@ -49,7 +49,7 @@ func (d *Database) Init(engine string, connection string) error {
} }
func (d *Database) Register() (ACMETxt, error) { func (d *Database) Register() (ACMETxt, error) {
a, err := NewACMETxt() a, err := newACMETxt()
if err != nil { if err != nil {
return ACMETxt{}, err return ACMETxt{}, err
} }
@ -121,7 +121,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
} }
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
domain = SanitizeString(domain) domain = sanitizeString(domain)
log.Debugf("Trying to select domain [%s] from table", domain) log.Debugf("Trying to select domain [%s] from table", domain)
var a []ACMETxt var a []ACMETxt
getSQL := ` getSQL := `

View File

@ -66,7 +66,7 @@ func TestGetByUsername(t *testing.T) {
} }
// regUser password already is a bcrypt hash // regUser password already is a bcrypt hash
if !CorrectPassword(reg.Password, regUser.Password) { if !correctPassword(reg.Password, regUser.Password) {
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password) t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
} }
} }
@ -113,7 +113,7 @@ func TestGetByDomain(t *testing.T) {
} }
// regDomain password already is a bcrypt hash // regDomain password already is a bcrypt hash
if !CorrectPassword(reg.Password, regDomain.Password) { if !correctPassword(reg.Password, regDomain.Password) {
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password) t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password)
} }

2
dns.go
View File

@ -23,7 +23,7 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var rcode int = dns.RcodeNameError var rcode int = dns.RcodeNameError
var domain = strings.ToLower(q.Name) var domain = strings.ToLower(q.Name)
atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain)) atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain))
if err != nil { if err != nil {
log.Errorf("Error while trying to get record [%v]", err) log.Errorf("Error while trying to get record [%v]", err)
return ra, dns.RcodeNameError, err return ra, dns.RcodeNameError, err

32
main.go
View File

@ -28,36 +28,8 @@ func main() {
os.Exit(1) os.Exit(1)
} }
DNSConf = configTmp DNSConf = configTmp
// Setup logging
var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format) setupLogging()
var logBackend *logging.LogBackend
switch DNSConf.Logconfig.Logtype {
default:
// Setup logging - stdout
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
case "file":
// Logging to file
logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
os.Exit(1)
}
defer logfh.Close()
logBackend = logging.NewLogBackend(logfh, "", 0)
}
logFormatter := logging.NewBackendFormatter(logBackend, logformat)
logLevel := logging.AddModuleLevel(logFormatter)
switch DNSConf.Logconfig.Level {
default:
logLevel.SetLevel(logging.DEBUG, "")
case "warning":
logLevel.SetLevel(logging.WARNING, "")
case "error":
logLevel.SetLevel(logging.ERROR, "")
case "info":
logLevel.SetLevel(logging.INFO, "")
}
logging.SetBackend(logFormatter)
// Read the default records in // Read the default records in
RR.Parse(DNSConf.General.StaticRecords) RR.Parse(DNSConf.General.StaticRecords)

46
util.go
View File

@ -3,9 +3,12 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"errors" "errors"
"fmt"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/op/go-logging"
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
"math/big" "math/big"
"os"
"regexp" "regexp"
"strings" "strings"
) )
@ -18,7 +21,7 @@ func readConfig(fname string) (DNSConfig, error) {
return conf, nil return conf, nil
} }
func SanitizeString(s string) string { func sanitizeString(s string) string {
// URL safe base64 alphabet without padding as defined in ACME // URL safe base64 alphabet without padding as defined in ACME
re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+") re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+")
if err != nil { if err != nil {
@ -28,7 +31,7 @@ func SanitizeString(s string) string {
return re.ReplaceAllString(s, "") return re.ReplaceAllString(s, "")
} }
func GeneratePassword(length int) (string, error) { func generatePassword(length int) (string, error) {
ret := make([]byte, length) ret := make([]byte, length)
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_" const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_"
alphalen := big.NewInt(int64(len(alphabet))) alphalen := big.NewInt(int64(len(alphabet)))
@ -43,7 +46,7 @@ func GeneratePassword(length int) (string, error) {
return string(ret), nil return string(ret), nil
} }
func SanitizeDomainQuestion(d string) string { func sanitizeDomainQuestion(d string) string {
var dom string var dom string
suffix := DNSConf.General.Domain + "." suffix := DNSConf.General.Domain + "."
if strings.HasSuffix(d, suffix) { if strings.HasSuffix(d, suffix) {
@ -54,9 +57,9 @@ func SanitizeDomainQuestion(d string) string {
return dom return dom
} }
func NewACMETxt() (ACMETxt, error) { func newACMETxt() (ACMETxt, error) {
var a = ACMETxt{} var a = ACMETxt{}
password, err := GeneratePassword(40) password, err := generatePassword(40)
if err != nil { if err != nil {
return a, err return a, err
} }
@ -65,3 +68,36 @@ func NewACMETxt() (ACMETxt, error) {
a.Subdomain = uuid.NewV4().String() a.Subdomain = uuid.NewV4().String()
return a, nil return a, nil
} }
func setupLogging() {
var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format)
var logBackend *logging.LogBackend
switch DNSConf.Logconfig.Logtype {
default:
// Setup logging - stdout
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
case "file":
// Logging to file
logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
os.Exit(1)
}
defer logfh.Close()
logBackend = logging.NewLogBackend(logfh, "", 0)
}
logFormatter := logging.NewBackendFormatter(logBackend, logformat)
logLevel := logging.AddModuleLevel(logFormatter)
switch DNSConf.Logconfig.Level {
default:
logLevel.SetLevel(logging.DEBUG, "")
case "warning":
logLevel.SetLevel(logging.WARNING, "")
case "error":
logLevel.SetLevel(logging.ERROR, "")
case "info":
logLevel.SetLevel(logging.INFO, "")
}
logging.SetBackend(logFormatter)
}

View File

@ -6,7 +6,7 @@ import (
"unicode/utf8" "unicode/utf8"
) )
func GetValidUsername(u string) (uuid.UUID, error) { func getValidUsername(u string) (uuid.UUID, error) {
uname, err := uuid.FromString(u) uname, err := uuid.FromString(u)
if err != nil { if err != nil {
return uuid.UUID{}, err return uuid.UUID{}, err
@ -14,8 +14,8 @@ func GetValidUsername(u string) (uuid.UUID, error) {
return uname, nil return uname, nil
} }
func ValidKey(k string) bool { func validKey(k string) bool {
kn := SanitizeString(k) kn := sanitizeString(k)
if utf8.RuneCountInString(k) == 40 && utf8.RuneCountInString(kn) == 40 { if utf8.RuneCountInString(k) == 40 && utf8.RuneCountInString(kn) == 40 {
// Correct length and all chars valid // Correct length and all chars valid
return true return true
@ -23,7 +23,7 @@ func ValidKey(k string) bool {
return false return false
} }
func ValidSubdomain(s string) bool { func validSubdomain(s string) bool {
_, err := uuid.FromString(s) _, err := uuid.FromString(s)
if err == nil { if err == nil {
return true return true
@ -31,8 +31,8 @@ func ValidSubdomain(s string) bool {
return false return false
} }
func ValidTXT(s string) bool { func validTXT(s string) bool {
sn := SanitizeString(s) sn := sanitizeString(s)
if utf8.RuneCountInString(s) == 43 && utf8.RuneCountInString(sn) == 43 { if utf8.RuneCountInString(s) == 43 && utf8.RuneCountInString(sn) == 43 {
// 43 chars is the current LE auth key size, but not limited / defined by ACME // 43 chars is the current LE auth key size, but not limited / defined by ACME
return true return true
@ -40,7 +40,7 @@ func ValidTXT(s string) bool {
return false return false
} }
func CorrectPassword(pw string, hash string) bool { func correctPassword(pw string, hash string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil { if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
return true return true
} }

View File

@ -17,7 +17,7 @@ func TestGetValidUsername(t *testing.T) {
{"", uuid.UUID{}, true}, {"", uuid.UUID{}, true},
{"&!#!25123!%!'%", uuid.UUID{}, true}, {"&!#!25123!%!'%", uuid.UUID{}, true},
} { } {
ret, err := GetValidUsername(test.uname) ret, err := getValidUsername(test.uname)
if test.shouldErr && err == nil { if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but there was none", i) t.Errorf("Test %d: Expected error, but there was none", i)
} }
@ -41,7 +41,7 @@ func TestValidKey(t *testing.T) {
{"aaaaaaaa-aaa-aaaaaa#aaaaaaaa-aaa_aacaaaa", false}, {"aaaaaaaa-aaa-aaaaaa#aaaaaaaa-aaa_aacaaaa", false},
{"aaaaaaaa-aaa-aaaaaa-aaaaaaaa-aaa_aacaaaaa", false}, {"aaaaaaaa-aaa-aaaaaa-aaaaaaaa-aaa_aacaaaaa", false},
} { } {
ret := ValidKey(test.key) ret := validKey(test.key)
if ret != test.output { if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
} }
@ -58,7 +58,7 @@ func TestGetValidSubdomain(t *testing.T) {
{"", false}, {"", false},
{"&!#!25123!%!'%", false}, {"&!#!25123!%!'%", false},
} { } {
ret := ValidSubdomain(test.subdomain) ret := validSubdomain(test.subdomain)
if ret != test.output { if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
} }
@ -76,7 +76,7 @@ func TestValidTXT(t *testing.T) {
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false}, {"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
{"", false}, {"", false},
} { } {
ret := ValidTXT(test.txt) ret := validTXT(test.txt)
if ret != test.output { if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
} }
@ -100,7 +100,7 @@ func TestCorrectPassword(t *testing.T) {
false}, false},
{"", "", false}, {"", "", false},
} { } {
ret := CorrectPassword(test.pw, test.hash) ret := correctPassword(test.pw, test.hash)
if ret != test.output { if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
} }