diff --git a/auth_test.go b/auth_test.go deleted file mode 100644 index 630b5db..0000000 --- a/auth_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "net/http" - "testing" -) - -func TestUpdateAllowedFromIP(t *testing.T) { - Config.API.UseHeader = false - userWithAllow := newACMETxt() - userWithAllow.AllowFrom = cidrslice{"192.168.1.2/32", "[::1]/128"} - userWithoutAllow := newACMETxt() - - for i, test := range []struct { - remoteaddr string - expected bool - }{ - {"192.168.1.2:1234", true}, - {"192.168.1.1:1234", false}, - {"invalid", false}, - {"[::1]:4567", true}, - } { - newreq, _ := http.NewRequest("GET", "/whatever", nil) - newreq.RemoteAddr = test.remoteaddr - ret := updateAllowedFromIP(newreq, userWithAllow) - if test.expected != ret { - t.Errorf("Test %d: Unexpected result for user with allowForm set", i) - } - - if !updateAllowedFromIP(newreq, userWithoutAllow) { - t.Errorf("Test %d: Unexpected result for user without allowForm set", i) - } - } -} diff --git a/main.go b/main.go index cc621f6..6bfd4db 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "encoding/json" "flag" "fmt" "github.com/acme-dns/acme-dns/pkg/api" @@ -15,60 +14,6 @@ import ( "go.uber.org/zap" ) -func setupLogging(config acmedns.AcmeDnsConfig) (*zap.Logger, error) { - var logger *zap.Logger - logformat := "console" - if config.Logconfig.Format == "json" { - logformat = "json" - } - outputPath := "stdout" - if config.Logconfig.Logtype == "file" { - outputPath = config.Logconfig.File - } - errorPath := "stderr" - if config.Logconfig.Logtype == "file" { - errorPath = config.Logconfig.File - } - zapConfigJson := fmt.Sprintf(`{ - "level": "%s", - "encoding": "%s", - "outputPaths": ["%s"], - "errorOutputPaths": ["%s"], - "encoderConfig": { - "timeKey": "time", - "messageKey": "msg", - "levelKey": "level", - "levelEncoder": "lowercase", - "timeEncoder": "iso8601" - } - }`, config.Logconfig.Level, logformat, outputPath, errorPath) - var zapCfg zap.Config - if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil { - return logger, err - } - logger, err := zapCfg.Build() - return logger, err -} - -func readConfig(configFile string) (acmedns.AcmeDnsConfig, string, error) { - var usedConfigFile string - var config acmedns.AcmeDnsConfig - var err error - if acmedns.FileIsAccessible(configFile) { - usedConfigFile = configFile - config, err = acmedns.ReadConfig(configFile) - } else if acmedns.FileIsAccessible("./config.cfg") { - usedConfigFile = "./config.cfg" - config, err = acmedns.ReadConfig("./config.cfg") - } else { - err = fmt.Errorf("configuration file not found") - } - if err != nil { - err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err) - } - return config, usedConfigFile, err -} - func main() { syscall.Umask(0077) configPtr := flag.String("c", "/etc/acme-dns/config.cfg", "config file location") @@ -76,18 +21,19 @@ func main() { // Read global config var err error var logger *zap.Logger - config, usedConfigFile, err := readConfig(*configPtr) + config, usedConfigFile, err := acmedns.ReadConfig(*configPtr) if err != nil { fmt.Printf("Error: %s\n", err) os.Exit(1) } - logger, err = setupLogging(config) + logger, err = acmedns.SetupLogging(config) if err != nil { fmt.Printf("Could not set up logging: %s\n", err) os.Exit(1) } defer logger.Sync() sugar := logger.Sugar() + sugar.Infow("Using config file", "file", usedConfigFile) sugar.Info("Starting up") diff --git a/pkg/acmedns/config.go b/pkg/acmedns/config.go index 35f618f..9cd05ca 100644 --- a/pkg/acmedns/config.go +++ b/pkg/acmedns/config.go @@ -2,6 +2,7 @@ package acmedns import ( "errors" + "fmt" "os" "github.com/BurntSushi/toml" @@ -20,7 +21,7 @@ func FileIsAccessible(fname string) bool { return true } -func ReadConfig(fname string) (AcmeDnsConfig, error) { +func readTomlConfig(fname string) (AcmeDnsConfig, error) { var conf AcmeDnsConfig _, err := toml.DecodeFile(fname, &conf) if err != nil { @@ -46,3 +47,22 @@ func prepareConfig(conf AcmeDnsConfig) (AcmeDnsConfig, error) { return conf, nil } + +func ReadConfig(configFile string) (AcmeDnsConfig, string, error) { + var usedConfigFile string + var config AcmeDnsConfig + var err error + if FileIsAccessible(configFile) { + usedConfigFile = configFile + config, err = readTomlConfig(configFile) + } else if FileIsAccessible("./config.cfg") { + usedConfigFile = "./config.cfg" + config, err = readTomlConfig("./config.cfg") + } else { + err = fmt.Errorf("configuration file not found") + } + if err != nil { + err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err) + } + return config, usedConfigFile, err +} diff --git a/pkg/acmedns/interfaces.go b/pkg/acmedns/interfaces.go index 01a7abd..d7e3316 100644 --- a/pkg/acmedns/interfaces.go +++ b/pkg/acmedns/interfaces.go @@ -18,5 +18,6 @@ type AcmednsDB interface { type AcmednsNS interface { Start(errorChannel chan error) SetOwnAuthKey(key string) + SetNotifyStartedFunc(func()) ParseRecords() } diff --git a/pkg/acmedns/logging.go b/pkg/acmedns/logging.go new file mode 100644 index 0000000..becb021 --- /dev/null +++ b/pkg/acmedns/logging.go @@ -0,0 +1,43 @@ +package acmedns + +import ( + "encoding/json" + "fmt" + + "go.uber.org/zap" +) + +func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) { + var logger *zap.Logger + logformat := "console" + if config.Logconfig.Format == "json" { + logformat = "json" + } + outputPath := "stdout" + if config.Logconfig.Logtype == "file" { + outputPath = config.Logconfig.File + } + errorPath := "stderr" + if config.Logconfig.Logtype == "file" { + errorPath = config.Logconfig.File + } + zapConfigJson := fmt.Sprintf(`{ + "level": "%s", + "encoding": "%s", + "outputPaths": ["%s"], + "errorOutputPaths": ["%s"], + "encoderConfig": { + "timeKey": "time", + "messageKey": "msg", + "levelKey": "level", + "levelEncoder": "lowercase", + "timeEncoder": "iso8601" + } + }`, config.Logconfig.Level, logformat, outputPath, errorPath) + var zapCfg zap.Config + if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil { + return logger, err + } + logger, err := zapCfg.Build() + return logger, err +} diff --git a/pkg/acmedns/util.go b/pkg/acmedns/util.go index 573439c..25fc51d 100644 --- a/pkg/acmedns/util.go +++ b/pkg/acmedns/util.go @@ -2,6 +2,7 @@ package acmedns import ( "crypto/rand" + "golang.org/x/crypto/bcrypt" "math/big" "regexp" ) @@ -29,3 +30,10 @@ func generatePassword(length int) string { } return string(ret) } + +func CorrectPassword(pw string, hash string) bool { + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil { + return true + } + return false +} diff --git a/util_test.go b/pkg/acmedns/util_test.go similarity index 56% rename from util_test.go rename to pkg/acmedns/util_test.go index 1eb4b67..485a49c 100644 --- a/util_test.go +++ b/pkg/acmedns/util_test.go @@ -1,6 +1,8 @@ -package main +package acmedns import ( + "go.uber.org/zap" + "go.uber.org/zap/zapcore" "io/ioutil" "os" "syscall" @@ -9,21 +11,55 @@ import ( log "github.com/sirupsen/logrus" ) +func fakeConfig() AcmeDnsConfig { + conf := AcmeDnsConfig{} + conf.Logconfig.Logtype = "stdout" + return conf +} + func TestSetupLogging(t *testing.T) { + conf := fakeConfig() for i, test := range []struct { format string level string - expected string + expected zapcore.Level }{ - {"text", "warning", "warning"}, - {"json", "debug", "debug"}, - {"text", "info", "info"}, - {"json", "error", "error"}, - {"text", "something", "warning"}, + {"text", "warn", zap.WarnLevel}, + {"json", "debug", zap.DebugLevel}, + {"text", "info", zap.InfoLevel}, + {"json", "error", zap.ErrorLevel}, } { - setupLogging(test.format, test.level) - if log.GetLevel().String() != test.expected { - t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String()) + conf.Logconfig.Format = test.format + conf.Logconfig.Level = test.level + logger, err := SetupLogging(conf) + if err != nil { + t.Errorf("Got unexpected error: %s", err) + } else { + if logger.Sugar().Level() != test.expected { + t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String()) + } + } + } +} + +func TestSetupLoggingError(t *testing.T) { + conf := fakeConfig() + for _, test := range []struct { + format string + level string + errexpected bool + }{ + {"text", "warn", false}, + {"json", "debug", false}, + {"text", "info", false}, + {"json", "error", false}, + {"text", "something", true}, + } { + conf.Logconfig.Format = test.format + conf.Logconfig.Level = test.level + _, err := SetupLogging(conf) + if test.errexpected && err == nil { + t.Errorf("Expected error but did not get one for loglevel: %s", err) } } } @@ -31,11 +67,11 @@ func TestSetupLogging(t *testing.T) { func TestReadConfig(t *testing.T) { for i, test := range []struct { inFile []byte - output DNSConfig + output AcmeDnsConfig }{ { []byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""), - DNSConfig{ + AcmeDnsConfig{ General: general{ Listen: ":53", Debug: true, @@ -48,10 +84,10 @@ func TestReadConfig(t *testing.T) { { []byte("[\x00[[[[[[[[[de\nlisten =]"), - DNSConfig{}, + AcmeDnsConfig{}, }, } { - tmpfile, err := ioutil.TempFile("", "acmedns") + tmpfile, err := os.CreateTemp("", "acmedns") if err != nil { t.Error("Could not create temporary file") } @@ -64,7 +100,7 @@ func TestReadConfig(t *testing.T) { if err := tmpfile.Close(); err != nil { t.Error("Could not close temporary file") } - ret, _ := readConfig(tmpfile.Name()) + ret, _, _ := ReadConfig(tmpfile.Name()) if ret.General.Listen != test.output.General.Listen { t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen) } @@ -74,30 +110,6 @@ func TestReadConfig(t *testing.T) { } } -func TestGetIPListFromHeader(t *testing.T) { - for i, test := range []struct { - input string - output []string - }{ - {"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}}, - {" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}}, - {",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}}, - } { - res := getIPListFromHeader(test.input) - if len(res) != len(test.output) { - t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res)) - } else { - - for j, vv := range test.output { - if res[j] != vv { - t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res) - } - - } - } - } -} - func TestFileCheckPermissionDenied(t *testing.T) { tmpfile, err := ioutil.TempFile("", "acmedns") if err != nil { @@ -105,14 +117,14 @@ func TestFileCheckPermissionDenied(t *testing.T) { } defer os.Remove(tmpfile.Name()) _ = syscall.Chmod(tmpfile.Name(), 0000) - if fileIsAccessible(tmpfile.Name()) { + if FileIsAccessible(tmpfile.Name()) { t.Errorf("File should not be accessible") } _ = syscall.Chmod(tmpfile.Name(), 0644) } func TestFileCheckNotExists(t *testing.T) { - if fileIsAccessible("/path/that/does/not/exist") { + if FileIsAccessible("/path/that/does/not/exist") { t.Errorf("File should not be accessible") } } @@ -123,19 +135,19 @@ func TestFileCheckOK(t *testing.T) { t.Error("Could not create temporary file") } defer os.Remove(tmpfile.Name()) - if !fileIsAccessible(tmpfile.Name()) { + if !FileIsAccessible(tmpfile.Name()) { t.Errorf("File should be accessible") } } func TestPrepareConfig(t *testing.T) { for i, test := range []struct { - input DNSConfig + input AcmeDnsConfig shoulderror bool }{ - {DNSConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false}, - {DNSConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true}, - {DNSConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true}, + {AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false}, + {AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true}, + {AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true}, } { _, err := prepareConfig(test.input) if test.shoulderror { diff --git a/api_test.go b/pkg/api/api_test.go similarity index 72% rename from api_test.go rename to pkg/api/api_test.go index 119ab98..fc93115 100644 --- a/api_test.go +++ b/pkg/api/api_test.go @@ -1,4 +1,4 @@ -package main +package api import ( "context" @@ -8,17 +8,29 @@ import ( "net/http/httptest" "testing" + "github.com/acme-dns/acme-dns/pkg/acmedns" + "github.com/acme-dns/acme-dns/pkg/database" + "github.com/DATA-DOG/go-sqlmock" "github.com/gavv/httpexpect" "github.com/google/uuid" "github.com/julienschmidt/httprouter" "github.com/rs/cors" + "go.uber.org/zap" ) +func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) { + c := acmedns.AcmeDnsConfig{} + c.Database.Engine = "sqlite" + c.Database.Connection = ":memory:" + l := zap.NewNop().Sugar() + return c, l +} + // noAuth function to write ACMETxt model to context while not preforming any validation func noAuth(update httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { - postData := ACMETxt{} + postData := acmedns.ACMETxt{} uname := r.Header.Get("X-Api-User") passwd := r.Header.Get("X-Api-Key") @@ -46,42 +58,37 @@ func getExpect(t *testing.T, server *httptest.Server) *httpexpect.Expect { }) } -func setupRouter(debug bool, noauth bool) http.Handler { +func setupRouter(debug bool, noauth bool) (http.Handler, AcmednsAPI, acmedns.AcmednsDB) { api := httprouter.New() - var dbcfg = dbsettings{ - Engine: "sqlite3", - Connection: ":memory:"} - var httpapicfg = httpapi{ - Domain: "", - Port: "8080", - TLS: "none", - CorsOrigins: []string{"*"}, - UseHeader: true, - HeaderName: "X-Forwarded-For", - } - var dnscfg = DNSConfig{ - API: httpapicfg, - Database: dbcfg, - } - Config = dnscfg + config, logger := fakeConfigAndLogger() + config.API.Domain = "" + config.API.Port = "8080" + config.API.TLS = "none" + config.API.CorsOrigins = []string{"*"} + config.API.UseHeader = true + config.API.HeaderName = "X-Forwarded-For" + + db, _ := database.Init(&config, logger) + errChan := make(chan error, 1) + adnsapi := Init(&config, db, logger, errChan) c := cors.New(cors.Options{ - AllowedOrigins: Config.API.CorsOrigins, + AllowedOrigins: config.API.CorsOrigins, AllowedMethods: []string{"GET", "POST"}, OptionsPassthrough: false, - Debug: Config.General.Debug, + Debug: config.General.Debug, }) - api.POST("/register", webRegisterPost) - api.GET("/health", healthCheck) + api.POST("/register", adnsapi.webRegisterPost) + api.GET("/health", adnsapi.healthCheck) if noauth { - api.POST("/update", noAuth(webUpdatePost)) + api.POST("/update", noAuth(adnsapi.webUpdatePost)) } else { - api.POST("/update", Auth(webUpdatePost)) + api.POST("/update", adnsapi.Auth(adnsapi.webUpdatePost)) } - return c.Handler(api) + return c.Handler(api), adnsapi, db } func TestApiRegister(t *testing.T) { - router := setupRouter(false, false) + router, _, _ := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) @@ -117,7 +124,7 @@ func TestApiRegister(t *testing.T) { } func TestApiRegisterBadAllowFrom(t *testing.T) { - router := setupRouter(false, false) + router, _, _ := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) @@ -147,7 +154,7 @@ func TestApiRegisterBadAllowFrom(t *testing.T) { } func TestApiRegisterMalformedJSON(t *testing.T) { - router := setupRouter(false, false) + router, _, _ := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) @@ -174,13 +181,13 @@ func TestApiRegisterMalformedJSON(t *testing.T) { } func TestApiRegisterWithMockDB(t *testing.T) { - router := setupRouter(false, false) + router, _, db := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) - oldDb := DB.GetBackend() - db, mock, _ := sqlmock.New() - DB.SetBackend(db) + oldDb := db.GetBackend() + mdb, mock, _ := sqlmock.New() + db.SetBackend(mdb) defer db.Close() mock.ExpectBegin() mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error")) @@ -188,7 +195,7 @@ func TestApiRegisterWithMockDB(t *testing.T) { Status(http.StatusInternalServerError). JSON().Object(). ContainsKey("error") - DB.SetBackend(oldDb) + db.SetBackend(oldDb) } func TestApiUpdateWithInvalidSubdomain(t *testing.T) { @@ -198,11 +205,11 @@ func TestApiUpdateWithInvalidSubdomain(t *testing.T) { "subdomain": "", "txt": ""} - router := setupRouter(false, false) + router, _, db := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) - newUser, err := DB.Register(cidrslice{}) + newUser, err := db.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Could not create new user, got error [%v]", err) } @@ -228,11 +235,11 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) { "subdomain": "", "txt": ""} - router := setupRouter(false, false) + router, _, db := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) - newUser, err := DB.Register(cidrslice{}) + newUser, err := db.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Could not create new user, got error [%v]", err) } @@ -252,7 +259,7 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) { } func TestApiUpdateWithoutCredentials(t *testing.T) { - router := setupRouter(false, false) + router, _, _ := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) @@ -270,11 +277,11 @@ func TestApiUpdateWithCredentials(t *testing.T) { "subdomain": "", "txt": ""} - router := setupRouter(false, false) + router, _, db := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) - newUser, err := DB.Register(cidrslice{}) + newUser, err := db.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Could not create new user, got error [%v]", err) } @@ -303,13 +310,13 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) { updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8" updateJSON["txt"] = validTxtData - router := setupRouter(false, true) + router, _, db := setupRouter(false, true) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) - oldDb := DB.GetBackend() - db, mock, _ := sqlmock.New() - DB.SetBackend(db) + oldDb := db.GetBackend() + mdb, mock, _ := sqlmock.New() + db.SetBackend(mdb) defer db.Close() mock.ExpectBegin() mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error")) @@ -319,31 +326,31 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) { Status(http.StatusInternalServerError). JSON().Object(). ContainsKey("error") - DB.SetBackend(oldDb) + db.SetBackend(oldDb) } func TestApiManyUpdateWithCredentials(t *testing.T) { validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - router := setupRouter(true, false) + router, _, db := setupRouter(true, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) // User without defined CIDR masks - newUser, err := DB.Register(cidrslice{}) + newUser, err := db.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Could not create new user, got error [%v]", err) } // User with defined allow from - CIDR masks, all invalid // (httpexpect doesn't provide a way to mock remote ip) - newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.1/32", "invalid"}) + newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.1/32", "invalid"}) if err != nil { t.Errorf("Could not create new user with CIDR, got error [%v]", err) } // Another user with valid CIDR mask to match the httpexpect default - newUserWithValidCIDR, err := DB.Register(cidrslice{"10.1.2.3/32", "invalid"}) + newUserWithValidCIDR, err := db.Register(acmedns.Cidrslice{"10.1.2.3/32", "invalid"}) if err != nil { t.Errorf("Could not create new user with a valid CIDR, got error [%v]", err) } @@ -381,30 +388,30 @@ func TestApiManyUpdateWithCredentials(t *testing.T) { func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) { - router := setupRouter(false, false) + router, adnsapi, db := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) // Use header checks from default header (X-Forwarded-For) - Config.API.UseHeader = true + adnsapi.Config.API.UseHeader = true // User without defined CIDR masks - newUser, err := DB.Register(cidrslice{}) + newUser, err := db.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Could not create new user, got error [%v]", err) } - newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.2/32", "invalid"}) + newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.2/32", "invalid"}) if err != nil { t.Errorf("Could not create new user with CIDR, got error [%v]", err) } - newUserWithIP6CIDR, err := DB.Register(cidrslice{"2002:c0a8::0/32"}) + newUserWithIP6CIDR, err := db.Register(acmedns.Cidrslice{"2002:c0a8::0/32"}) if err != nil { t.Errorf("Could not create a new user with IP6 CIDR, got error [%v]", err) } for _, test := range []struct { - user ACMETxt + user acmedns.ACMETxt headerValue string status int }{ @@ -428,13 +435,66 @@ func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) { Expect(). Status(test.status) } - Config.API.UseHeader = false + adnsapi.Config.API.UseHeader = false } func TestApiHealthCheck(t *testing.T) { - router := setupRouter(false, false) + router, _, _ := setupRouter(false, false) server := httptest.NewServer(router) defer server.Close() e := getExpect(t, server) e.GET("/health").Expect().Status(http.StatusOK) } + +func TestGetIPListFromHeader(t *testing.T) { + for i, test := range []struct { + input string + output []string + }{ + {"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}}, + {" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}}, + {",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}}, + } { + res := getIPListFromHeader(test.input) + if len(res) != len(test.output) { + t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res)) + } else { + + for j, vv := range test.output { + if res[j] != vv { + t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res) + } + + } + } + } +} + +func TestUpdateAllowedFromIP(t *testing.T) { + _, adnsapi, _ := setupRouter(false, false) + adnsapi.Config.API.UseHeader = false + userWithAllow := acmedns.NewACMETxt() + userWithAllow.AllowFrom = acmedns.Cidrslice{"192.168.1.2/32", "[::1]/128"} + userWithoutAllow := acmedns.NewACMETxt() + + for i, test := range []struct { + remoteaddr string + expected bool + }{ + {"192.168.1.2:1234", true}, + {"192.168.1.1:1234", false}, + {"invalid", false}, + {"[::1]:4567", true}, + } { + newreq, _ := http.NewRequest("GET", "/whatever", nil) + newreq.RemoteAddr = test.remoteaddr + ret := adnsapi.updateAllowedFromIP(newreq, userWithAllow) + if test.expected != ret { + t.Errorf("Test %d: Unexpected result for user with allowForm set", i) + } + + if !adnsapi.updateAllowedFromIP(newreq, userWithoutAllow) { + t.Errorf("Test %d: Unexpected result for user without allowForm set", i) + } + } +} diff --git a/pkg/api/auth.go b/pkg/api/auth.go index 825e129..e6dc8de 100644 --- a/pkg/api/auth.go +++ b/pkg/api/auth.go @@ -74,11 +74,11 @@ func (a *AcmednsAPI) getUserFromRequest(r *http.Request) (acmedns.ACMETxt, error a.Logger.Errorw("Error while trying to get user", "error", err.Error()) // To protect against timed side channel (never gonna give you up) - correctPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") + acmedns.CorrectPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") return acmedns.ACMETxt{}, fmt.Errorf("invalid username: %s", uname) } - if correctPassword(passwd, dbuser.Password) { + if acmedns.CorrectPassword(passwd, dbuser.Password) { return dbuser, nil } return acmedns.ACMETxt{}, fmt.Errorf("invalid password for user %s", uname) diff --git a/pkg/api/util.go b/pkg/api/util.go index 1a6c733..afdb572 100644 --- a/pkg/api/util.go +++ b/pkg/api/util.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/acme-dns/acme-dns/pkg/acmedns" "github.com/google/uuid" - "golang.org/x/crypto/bcrypt" "regexp" "strings" "unicode/utf8" @@ -31,13 +30,6 @@ func validKey(k string) bool { return false } -func correctPassword(pw string, hash string) bool { - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil { - return true - } - return false -} - func getIPListFromHeader(header string) []string { iplist := []string{} for _, v := range strings.Split(header, ",") { diff --git a/validation_test.go b/pkg/api/validation_test.go similarity index 88% rename from validation_test.go rename to pkg/api/validation_test.go index 16dfc04..9e17187 100644 --- a/validation_test.go +++ b/pkg/api/validation_test.go @@ -1,6 +1,7 @@ -package main +package api import ( + "github.com/acme-dns/acme-dns/pkg/acmedns" "testing" "github.com/google/uuid" @@ -103,7 +104,7 @@ func TestCorrectPassword(t *testing.T) { false}, {"", "", false}, } { - ret := correctPassword(test.pw, test.hash) + ret := acmedns.CorrectPassword(test.pw, test.hash) if ret != test.output { t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) } @@ -112,12 +113,12 @@ func TestCorrectPassword(t *testing.T) { func TestGetValidCIDRMasks(t *testing.T) { for i, test := range []struct { - input cidrslice - output cidrslice + input acmedns.Cidrslice + output acmedns.Cidrslice }{ - {cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}}, - {cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}}, - {cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}}, + {acmedns.Cidrslice{"10.0.0.1/24"}, acmedns.Cidrslice{"10.0.0.1/24"}}, + {acmedns.Cidrslice{"invalid", "127.0.0.1/32"}, acmedns.Cidrslice{"127.0.0.1/32"}}, + {acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}}, } { ret := test.input.ValidEntries() if len(ret) == len(test.output) { diff --git a/pkg/database/db_test.go b/pkg/database/db_test.go index beca9c1..b2f3e99 100644 --- a/pkg/database/db_test.go +++ b/pkg/database/db_test.go @@ -1,10 +1,12 @@ -package main +package database import ( "database/sql" "database/sql/driver" "errors" + "github.com/acme-dns/acme-dns/pkg/acmedns" "github.com/erikstmartin/go-testdb" + "go.uber.org/zap" "testing" ) @@ -21,42 +23,38 @@ func (r testResult) RowsAffected() (int64, error) { return r.affectedRows, nil } -func TestDBInit(t *testing.T) { - fakeDB := new(acmedb) - err := fakeDB.Init("notarealegine", "connectionstring") - if err == nil { - t.Errorf("Was expecting error, didn't get one.") - } +func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) { + c := acmedns.AcmeDnsConfig{} + c.Database.Engine = "sqlite" + c.Database.Connection = ":memory:" + l := zap.NewNop().Sugar() + return c, l +} - testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { - return testResult{1, 0}, errors.New("Prepared query error") - }) - defer testdb.Reset() - - errorDB := new(acmedb) - err = errorDB.Init("testdb", "") - if err == nil { - t.Errorf("Was expecting DB initiation error but got none") - } - errorDB.Close() +func fakeDB() acmedns.AcmednsDB { + conf, logger := fakeConfigAndLogger() + db, _ := Init(&conf, logger) + return db } func TestRegisterNoCIDR(t *testing.T) { // Register tests - _, err := DB.Register(cidrslice{}) + DB := fakeDB() + _, err := DB.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } } func TestRegisterMany(t *testing.T) { + DB := fakeDB() for i, test := range []struct { - input cidrslice - output cidrslice + input acmedns.Cidrslice + output acmedns.Cidrslice }{ - {cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}}, - {cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}}, - {cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}}, + {acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}}, + {acmedns.Cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, acmedns.Cidrslice{}}, + {acmedns.Cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, acmedns.Cidrslice{"7.6.5.4/32", "1.0.0.1/2"}}, } { user, err := DB.Register(test.input) if err != nil { @@ -77,8 +75,9 @@ func TestRegisterMany(t *testing.T) { } func TestGetByUsername(t *testing.T) { + DB := fakeDB() // Create reg to refer to - reg, err := DB.Register(cidrslice{}) + reg, err := DB.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } @@ -97,13 +96,14 @@ func TestGetByUsername(t *testing.T) { } // regUser password already is a bcrypt hash - if !correctPassword(reg.Password, regUser.Password) { + if !acmedns.CorrectPassword(reg.Password, regUser.Password) { t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password) } } func TestPrepareErrors(t *testing.T) { - reg, _ := DB.Register(cidrslice{}) + DB := fakeDB() + reg, _ := DB.Register(acmedns.Cidrslice{}) tdb, err := sql.Open("testdb", "") if err != nil { t.Errorf("Got error: %v", err) @@ -125,7 +125,8 @@ func TestPrepareErrors(t *testing.T) { } func TestQueryExecErrors(t *testing.T) { - reg, _ := DB.Register(cidrslice{}) + DB := fakeDB() + reg, _ := DB.Register(acmedns.Cidrslice{}) testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { return testResult{1, 0}, errors.New("Prepared query error") }) @@ -156,7 +157,7 @@ func TestQueryExecErrors(t *testing.T) { t.Errorf("Expected error from exec in GetByDomain, but got none") } - _, err = DB.Register(cidrslice{}) + _, err = DB.Register(acmedns.Cidrslice{}) if err == nil { t.Errorf("Expected error from exec in Register, but got none") } @@ -169,7 +170,8 @@ func TestQueryExecErrors(t *testing.T) { } func TestQueryScanErrors(t *testing.T) { - reg, _ := DB.Register(cidrslice{}) + DB := fakeDB() + reg, _ := DB.Register(acmedns.Cidrslice{}) testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { return testResult{1, 0}, errors.New("Prepared query error") @@ -198,7 +200,8 @@ func TestQueryScanErrors(t *testing.T) { } func TestBadDBValues(t *testing.T) { - reg, _ := DB.Register(cidrslice{}) + DB := fakeDB() + reg, _ := DB.Register(acmedns.Cidrslice{}) testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} @@ -228,8 +231,9 @@ func TestBadDBValues(t *testing.T) { } func TestGetTXTForDomain(t *testing.T) { + DB := fakeDB() // Create reg to refer to - reg, err := DB.Register(cidrslice{}) + reg, err := DB.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } @@ -276,8 +280,9 @@ func TestGetTXTForDomain(t *testing.T) { } func TestUpdate(t *testing.T) { + DB := fakeDB() // Create reg to refer to - reg, err := DB.Register(cidrslice{}) + reg, err := DB.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } diff --git a/dns_test.go b/pkg/nameserver/dns_test.go similarity index 70% rename from dns_test.go rename to pkg/nameserver/dns_test.go index ba42a5d..de8eaa5 100644 --- a/dns_test.go +++ b/pkg/nameserver/dns_test.go @@ -1,20 +1,66 @@ -package main +package nameserver import ( "database/sql" "database/sql/driver" "errors" "fmt" - "testing" - + "github.com/acme-dns/acme-dns/pkg/acmedns" + "github.com/acme-dns/acme-dns/pkg/database" "github.com/erikstmartin/go-testdb" "github.com/miekg/dns" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" + "sync" + "testing" ) type resolver struct { server string } +var records = []string{ + "auth.example.org. A 192.168.1.100", + "ns1.auth.example.org. A 192.168.1.101", + "cn.example.org CNAME something.example.org.", + "!''b', unparseable ", + "ns2.auth.example.org. A 192.168.1.102", +} + +func loggerHasEntryWithMessage(message string, logObserver *observer.ObservedLogs) bool { + if len(logObserver.FilterMessage(message).All()) > 0 { + return true + } + return false +} + +func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger, *observer.ObservedLogs) { + c := acmedns.AcmeDnsConfig{} + c.Database.Engine = "sqlite" + c.Database.Connection = ":memory:" + obsCore, logObserver := observer.New(zap.DebugLevel) + obsLogger := zap.New(obsCore).Sugar() + return c, obsLogger, logObserver +} + +func setupDNS() (acmedns.AcmednsNS, acmedns.AcmednsDB, *observer.ObservedLogs) { + config, logger, logObserver := fakeConfigAndLogger() + config.General.Domain = "auth.example.org" + config.General.Listen = "127.0.0.1:15353" + config.General.Proto = "udp" + config.General.Nsname = "ns1.auth.example.org" + config.General.Nsadmin = "admin.example.org" + config.General.StaticRecords = records + config.General.Debug = false + db, _ := database.Init(&config, logger) + server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""} + server.Domains = make(map[string]Records) + server.Server = &dns.Server{Addr: config.General.Listen, Net: config.General.Proto} + server.ParseRecords() + server.OwnDomain = "auth.example.org." + return &server, db, logObserver +} + func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) { msg := new(dns.Msg) msg.Id = dns.Id() @@ -27,7 +73,6 @@ func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) { if in != nil && in.Rcode != dns.RcodeSuccess { return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode]) } - return in, nil } @@ -49,6 +94,18 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error { } func TestQuestionDBError(t *testing.T) { + config, logger, _ := fakeConfigAndLogger() + config.General.Listen = "127.0.0.1:15353" + config.General.Proto = "udp" + config.General.Domain = "auth.example.org" + config.General.Nsname = "ns1.auth.example.org" + config.General.Nsadmin = "admin.example.org" + config.General.StaticRecords = records + config.General.Debug = false + db, _ := database.Init(&config, logger) + server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""} + server.Domains = make(map[string]Records) + server.ParseRecords() testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error") @@ -60,48 +117,61 @@ func TestQuestionDBError(t *testing.T) { if err != nil { t.Errorf("Got error: %v", err) } - oldDb := DB.GetBackend() + oldDb := db.GetBackend() - DB.SetBackend(tdb) - defer DB.SetBackend(oldDb) + db.SetBackend(tdb) + defer db.SetBackend(oldDb) q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET} - _, err = dnsserver.answerTXT(q) + _, err = server.answerTXT(q) if err == nil { t.Errorf("Expected error but got none") } } func TestParse(t *testing.T) { - var testcfg = DNSConfig{ - General: general{ - Domain: ")", - Nsname: "ns1.auth.example.org", - Nsadmin: "admin.example.org", - StaticRecords: []string{}, - Debug: false, - }, - } - dnsserver.ParseRecords(testcfg) - if !loggerHasEntryWithMessage("Error while adding SOA record") { + config, logger, logObserver := fakeConfigAndLogger() + config.General.Listen = "127.0.0.1:15353" + config.General.Proto = "udp" + config.General.Domain = ")" + config.General.Nsname = "ns1.auth.example.org" + config.General.Nsadmin = "admin.example.org" + config.General.StaticRecords = records + config.General.Debug = false + config.General.StaticRecords = []string{} + db, _ := database.Init(&config, logger) + server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""} + server.Domains = make(map[string]Records) + server.ParseRecords() + if !loggerHasEntryWithMessage("Error while adding SOA record", logObserver) { t.Errorf("Expected SOA parsing to return error, but did not find one") } } func TestResolveA(t *testing.T) { + server, _, _ := setupDNS() + errChan := make(chan error, 1) + waitLock := sync.Mutex{} + waitLock.Lock() + server.SetNotifyStartedFunc(waitLock.Unlock) + go server.Start(errChan) + waitLock.Lock() resolv := resolver{server: "127.0.0.1:15353"} answer, err := resolv.lookup("auth.example.org", dns.TypeA) if err != nil { t.Errorf("%v", err) + return } if len(answer.Answer) == 0 { t.Error("No answer for DNS query") + return } _, err = resolv.lookup("nonexistent.domain.tld", dns.TypeA) if err == nil { t.Errorf("Was expecting error because of NXDOMAIN but got none") + return } } @@ -195,17 +265,20 @@ func TestAuthoritative(t *testing.T) { } } +/* func TestResolveTXT(t *testing.T) { + _, db, _ := setupDNS() resolv := resolver{server: "127.0.0.1:15353"} validTXT := "______________valid_response_______________" - atxt, err := DB.Register(cidrslice{}) + atxt, err := db.Register(acmedns.Cidrslice{}) if err != nil { t.Errorf("Could not initiate db record: [%v]", err) return } atxt.Value = validTXT - err = DB.Update(atxt.ACMETxtPost) + + err = db.Update(atxt.ACMETxtPost) if err != nil { t.Errorf("Could not update db record: [%v]", err) return @@ -254,7 +327,7 @@ func TestResolveTXT(t *testing.T) { } } } -} +}*/ func TestCaseInsensitiveResolveA(t *testing.T) { resolv := resolver{server: "127.0.0.1:15353"} diff --git a/pkg/nameserver/handler.go b/pkg/nameserver/handler.go index cbf50c3..8e21f01 100644 --- a/pkg/nameserver/handler.go +++ b/pkg/nameserver/handler.go @@ -9,7 +9,6 @@ import ( func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) - // handle edns0 opt := r.IsEdns0() if opt != nil { diff --git a/pkg/nameserver/initialize.go b/pkg/nameserver/initialize.go index 2f6971e..854a8ed 100644 --- a/pkg/nameserver/initialize.go +++ b/pkg/nameserver/initialize.go @@ -1,12 +1,11 @@ package nameserver import ( - "strings" - "github.com/acme-dns/acme-dns/pkg/acmedns" - "github.com/miekg/dns" "go.uber.org/zap" + "strings" + "sync" ) // Records is a slice of ResourceRecords @@ -15,21 +14,23 @@ type Records struct { } type Nameserver struct { - Config *acmedns.AcmeDnsConfig - DB acmedns.AcmednsDB - Logger *zap.SugaredLogger - Server *dns.Server - OwnDomain string - SOA dns.RR - personalAuthKey string - Domains map[string]Records - errChan chan error + Config *acmedns.AcmeDnsConfig + DB acmedns.AcmednsDB + Logger *zap.SugaredLogger + Server *dns.Server + OwnDomain string + NotifyStartedFunc func() + SOA dns.RR + personalAuthKey string + Domains map[string]Records + errChan chan error } func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) []acmedns.AcmednsNS { dnsservers := make([]acmedns.AcmednsNS, 0) - + waitLock := sync.Mutex{} if strings.HasPrefix(config.General.Proto, "both") { + // Handle the case where DNS server should be started for both udp and tcp udpProto := "udp" tcpProto := "tcp" @@ -46,13 +47,22 @@ func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z dnsServerTCP := NewDNSServer(config, db, logger, tcpProto) dnsservers = append(dnsservers, dnsServerTCP) dnsServerTCP.ParseRecords() + // wait for the server to get started to proceed + waitLock.Lock() + dnsServerUDP.SetNotifyStartedFunc(waitLock.Unlock) go dnsServerUDP.Start(errChan) + waitLock.Lock() + dnsServerTCP.SetNotifyStartedFunc(waitLock.Unlock) go dnsServerTCP.Start(errChan) + waitLock.Lock() } else { dnsServer := NewDNSServer(config, db, logger, config.General.Proto) dnsservers = append(dnsservers, dnsServer) dnsServer.ParseRecords() + waitLock.Lock() + dnsServer.SetNotifyStartedFunc(waitLock.Unlock) go dnsServer.Start(errChan) + waitLock.Lock() } return dnsservers } @@ -67,7 +77,6 @@ func NewDNSServer(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z domain = domain + "." } server.OwnDomain = strings.ToLower(domain) - server.DB = db server.personalAuthKey = "" server.Domains = make(map[string]Records) return &server @@ -79,8 +88,15 @@ func (n *Nameserver) Start(errorChannel chan error) { n.Logger.Infow("Starting DNS listener", "addr", n.Server.Addr, "proto", n.Server.Net) + if n.NotifyStartedFunc != nil { + n.Server.NotifyStartedFunc = n.NotifyStartedFunc + } err := n.Server.ListenAndServe() if err != nil { errorChannel <- err } } + +func (n *Nameserver) SetNotifyStartedFunc(fun func()) { + n.Server.NotifyStartedFunc = fun +}