Re-added tests

This commit is contained in:
Joona Hoikkala 2022-12-25 12:52:49 +02:00
parent 1405e6ab47
commit 157241994f
No known key found for this signature in database
GPG Key ID: 1708DAE66E87A524
15 changed files with 425 additions and 283 deletions

View File

@ -1,34 +0,0 @@
package main
import (
"net/http"
"testing"
)
func TestUpdateAllowedFromIP(t *testing.T) {
Config.API.UseHeader = false
userWithAllow := newACMETxt()
userWithAllow.AllowFrom = cidrslice{"192.168.1.2/32", "[::1]/128"}
userWithoutAllow := newACMETxt()
for i, test := range []struct {
remoteaddr string
expected bool
}{
{"192.168.1.2:1234", true},
{"192.168.1.1:1234", false},
{"invalid", false},
{"[::1]:4567", true},
} {
newreq, _ := http.NewRequest("GET", "/whatever", nil)
newreq.RemoteAddr = test.remoteaddr
ret := updateAllowedFromIP(newreq, userWithAllow)
if test.expected != ret {
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
}
if !updateAllowedFromIP(newreq, userWithoutAllow) {
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
}
}
}

60
main.go
View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"encoding/json"
"flag" "flag"
"fmt" "fmt"
"github.com/acme-dns/acme-dns/pkg/api" "github.com/acme-dns/acme-dns/pkg/api"
@ -15,60 +14,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
func setupLogging(config acmedns.AcmeDnsConfig) (*zap.Logger, error) {
var logger *zap.Logger
logformat := "console"
if config.Logconfig.Format == "json" {
logformat = "json"
}
outputPath := "stdout"
if config.Logconfig.Logtype == "file" {
outputPath = config.Logconfig.File
}
errorPath := "stderr"
if config.Logconfig.Logtype == "file" {
errorPath = config.Logconfig.File
}
zapConfigJson := fmt.Sprintf(`{
"level": "%s",
"encoding": "%s",
"outputPaths": ["%s"],
"errorOutputPaths": ["%s"],
"encoderConfig": {
"timeKey": "time",
"messageKey": "msg",
"levelKey": "level",
"levelEncoder": "lowercase",
"timeEncoder": "iso8601"
}
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
var zapCfg zap.Config
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
return logger, err
}
logger, err := zapCfg.Build()
return logger, err
}
func readConfig(configFile string) (acmedns.AcmeDnsConfig, string, error) {
var usedConfigFile string
var config acmedns.AcmeDnsConfig
var err error
if acmedns.FileIsAccessible(configFile) {
usedConfigFile = configFile
config, err = acmedns.ReadConfig(configFile)
} else if acmedns.FileIsAccessible("./config.cfg") {
usedConfigFile = "./config.cfg"
config, err = acmedns.ReadConfig("./config.cfg")
} else {
err = fmt.Errorf("configuration file not found")
}
if err != nil {
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
}
return config, usedConfigFile, err
}
func main() { func main() {
syscall.Umask(0077) syscall.Umask(0077)
configPtr := flag.String("c", "/etc/acme-dns/config.cfg", "config file location") configPtr := flag.String("c", "/etc/acme-dns/config.cfg", "config file location")
@ -76,18 +21,19 @@ func main() {
// Read global config // Read global config
var err error var err error
var logger *zap.Logger var logger *zap.Logger
config, usedConfigFile, err := readConfig(*configPtr) config, usedConfigFile, err := acmedns.ReadConfig(*configPtr)
if err != nil { if err != nil {
fmt.Printf("Error: %s\n", err) fmt.Printf("Error: %s\n", err)
os.Exit(1) os.Exit(1)
} }
logger, err = setupLogging(config) logger, err = acmedns.SetupLogging(config)
if err != nil { if err != nil {
fmt.Printf("Could not set up logging: %s\n", err) fmt.Printf("Could not set up logging: %s\n", err)
os.Exit(1) os.Exit(1)
} }
defer logger.Sync() defer logger.Sync()
sugar := logger.Sugar() sugar := logger.Sugar()
sugar.Infow("Using config file", sugar.Infow("Using config file",
"file", usedConfigFile) "file", usedConfigFile)
sugar.Info("Starting up") sugar.Info("Starting up")

View File

@ -2,6 +2,7 @@ package acmedns
import ( import (
"errors" "errors"
"fmt"
"os" "os"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
@ -20,7 +21,7 @@ func FileIsAccessible(fname string) bool {
return true return true
} }
func ReadConfig(fname string) (AcmeDnsConfig, error) { func readTomlConfig(fname string) (AcmeDnsConfig, error) {
var conf AcmeDnsConfig var conf AcmeDnsConfig
_, err := toml.DecodeFile(fname, &conf) _, err := toml.DecodeFile(fname, &conf)
if err != nil { if err != nil {
@ -46,3 +47,22 @@ func prepareConfig(conf AcmeDnsConfig) (AcmeDnsConfig, error) {
return conf, nil return conf, nil
} }
func ReadConfig(configFile string) (AcmeDnsConfig, string, error) {
var usedConfigFile string
var config AcmeDnsConfig
var err error
if FileIsAccessible(configFile) {
usedConfigFile = configFile
config, err = readTomlConfig(configFile)
} else if FileIsAccessible("./config.cfg") {
usedConfigFile = "./config.cfg"
config, err = readTomlConfig("./config.cfg")
} else {
err = fmt.Errorf("configuration file not found")
}
if err != nil {
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
}
return config, usedConfigFile, err
}

View File

@ -18,5 +18,6 @@ type AcmednsDB interface {
type AcmednsNS interface { type AcmednsNS interface {
Start(errorChannel chan error) Start(errorChannel chan error)
SetOwnAuthKey(key string) SetOwnAuthKey(key string)
SetNotifyStartedFunc(func())
ParseRecords() ParseRecords()
} }

43
pkg/acmedns/logging.go Normal file
View File

@ -0,0 +1,43 @@
package acmedns
import (
"encoding/json"
"fmt"
"go.uber.org/zap"
)
func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
var logger *zap.Logger
logformat := "console"
if config.Logconfig.Format == "json" {
logformat = "json"
}
outputPath := "stdout"
if config.Logconfig.Logtype == "file" {
outputPath = config.Logconfig.File
}
errorPath := "stderr"
if config.Logconfig.Logtype == "file" {
errorPath = config.Logconfig.File
}
zapConfigJson := fmt.Sprintf(`{
"level": "%s",
"encoding": "%s",
"outputPaths": ["%s"],
"errorOutputPaths": ["%s"],
"encoderConfig": {
"timeKey": "time",
"messageKey": "msg",
"levelKey": "level",
"levelEncoder": "lowercase",
"timeEncoder": "iso8601"
}
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
var zapCfg zap.Config
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
return logger, err
}
logger, err := zapCfg.Build()
return logger, err
}

View File

@ -2,6 +2,7 @@ package acmedns
import ( import (
"crypto/rand" "crypto/rand"
"golang.org/x/crypto/bcrypt"
"math/big" "math/big"
"regexp" "regexp"
) )
@ -29,3 +30,10 @@ func generatePassword(length int) string {
} }
return string(ret) return string(ret)
} }
func CorrectPassword(pw string, hash string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
return true
}
return false
}

View File

@ -1,6 +1,8 @@
package main package acmedns
import ( import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"io/ioutil" "io/ioutil"
"os" "os"
"syscall" "syscall"
@ -9,33 +11,67 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func fakeConfig() AcmeDnsConfig {
conf := AcmeDnsConfig{}
conf.Logconfig.Logtype = "stdout"
return conf
}
func TestSetupLogging(t *testing.T) { func TestSetupLogging(t *testing.T) {
conf := fakeConfig()
for i, test := range []struct { for i, test := range []struct {
format string format string
level string level string
expected string expected zapcore.Level
}{ }{
{"text", "warning", "warning"}, {"text", "warn", zap.WarnLevel},
{"json", "debug", "debug"}, {"json", "debug", zap.DebugLevel},
{"text", "info", "info"}, {"text", "info", zap.InfoLevel},
{"json", "error", "error"}, {"json", "error", zap.ErrorLevel},
{"text", "something", "warning"},
} { } {
setupLogging(test.format, test.level) conf.Logconfig.Format = test.format
if log.GetLevel().String() != test.expected { conf.Logconfig.Level = test.level
logger, err := SetupLogging(conf)
if err != nil {
t.Errorf("Got unexpected error: %s", err)
} else {
if logger.Sugar().Level() != test.expected {
t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String()) t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String())
} }
} }
} }
}
func TestSetupLoggingError(t *testing.T) {
conf := fakeConfig()
for _, test := range []struct {
format string
level string
errexpected bool
}{
{"text", "warn", false},
{"json", "debug", false},
{"text", "info", false},
{"json", "error", false},
{"text", "something", true},
} {
conf.Logconfig.Format = test.format
conf.Logconfig.Level = test.level
_, err := SetupLogging(conf)
if test.errexpected && err == nil {
t.Errorf("Expected error but did not get one for loglevel: %s", err)
}
}
}
func TestReadConfig(t *testing.T) { func TestReadConfig(t *testing.T) {
for i, test := range []struct { for i, test := range []struct {
inFile []byte inFile []byte
output DNSConfig output AcmeDnsConfig
}{ }{
{ {
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""), []byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
DNSConfig{ AcmeDnsConfig{
General: general{ General: general{
Listen: ":53", Listen: ":53",
Debug: true, Debug: true,
@ -48,10 +84,10 @@ func TestReadConfig(t *testing.T) {
{ {
[]byte("[\x00[[[[[[[[[de\nlisten =]"), []byte("[\x00[[[[[[[[[de\nlisten =]"),
DNSConfig{}, AcmeDnsConfig{},
}, },
} { } {
tmpfile, err := ioutil.TempFile("", "acmedns") tmpfile, err := os.CreateTemp("", "acmedns")
if err != nil { if err != nil {
t.Error("Could not create temporary file") t.Error("Could not create temporary file")
} }
@ -64,7 +100,7 @@ func TestReadConfig(t *testing.T) {
if err := tmpfile.Close(); err != nil { if err := tmpfile.Close(); err != nil {
t.Error("Could not close temporary file") t.Error("Could not close temporary file")
} }
ret, _ := readConfig(tmpfile.Name()) ret, _, _ := ReadConfig(tmpfile.Name())
if ret.General.Listen != test.output.General.Listen { if ret.General.Listen != test.output.General.Listen {
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen) t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
} }
@ -74,30 +110,6 @@ func TestReadConfig(t *testing.T) {
} }
} }
func TestGetIPListFromHeader(t *testing.T) {
for i, test := range []struct {
input string
output []string
}{
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
} {
res := getIPListFromHeader(test.input)
if len(res) != len(test.output) {
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
} else {
for j, vv := range test.output {
if res[j] != vv {
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
}
}
}
}
}
func TestFileCheckPermissionDenied(t *testing.T) { func TestFileCheckPermissionDenied(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "acmedns") tmpfile, err := ioutil.TempFile("", "acmedns")
if err != nil { if err != nil {
@ -105,14 +117,14 @@ func TestFileCheckPermissionDenied(t *testing.T) {
} }
defer os.Remove(tmpfile.Name()) defer os.Remove(tmpfile.Name())
_ = syscall.Chmod(tmpfile.Name(), 0000) _ = syscall.Chmod(tmpfile.Name(), 0000)
if fileIsAccessible(tmpfile.Name()) { if FileIsAccessible(tmpfile.Name()) {
t.Errorf("File should not be accessible") t.Errorf("File should not be accessible")
} }
_ = syscall.Chmod(tmpfile.Name(), 0644) _ = syscall.Chmod(tmpfile.Name(), 0644)
} }
func TestFileCheckNotExists(t *testing.T) { func TestFileCheckNotExists(t *testing.T) {
if fileIsAccessible("/path/that/does/not/exist") { if FileIsAccessible("/path/that/does/not/exist") {
t.Errorf("File should not be accessible") t.Errorf("File should not be accessible")
} }
} }
@ -123,19 +135,19 @@ func TestFileCheckOK(t *testing.T) {
t.Error("Could not create temporary file") t.Error("Could not create temporary file")
} }
defer os.Remove(tmpfile.Name()) defer os.Remove(tmpfile.Name())
if !fileIsAccessible(tmpfile.Name()) { if !FileIsAccessible(tmpfile.Name()) {
t.Errorf("File should be accessible") t.Errorf("File should be accessible")
} }
} }
func TestPrepareConfig(t *testing.T) { func TestPrepareConfig(t *testing.T) {
for i, test := range []struct { for i, test := range []struct {
input DNSConfig input AcmeDnsConfig
shoulderror bool shoulderror bool
}{ }{
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false}, {AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
{DNSConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true}, {AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true}, {AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
} { } {
_, err := prepareConfig(test.input) _, err := prepareConfig(test.input)
if test.shoulderror { if test.shoulderror {

View File

@ -1,4 +1,4 @@
package main package api
import ( import (
"context" "context"
@ -8,17 +8,29 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/acme-dns/acme-dns/pkg/database"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/gavv/httpexpect" "github.com/gavv/httpexpect"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/rs/cors" "github.com/rs/cors"
"go.uber.org/zap"
) )
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
c := acmedns.AcmeDnsConfig{}
c.Database.Engine = "sqlite"
c.Database.Connection = ":memory:"
l := zap.NewNop().Sugar()
return c, l
}
// noAuth function to write ACMETxt model to context while not preforming any validation // noAuth function to write ACMETxt model to context while not preforming any validation
func noAuth(update httprouter.Handle) httprouter.Handle { func noAuth(update httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
postData := ACMETxt{} postData := acmedns.ACMETxt{}
uname := r.Header.Get("X-Api-User") uname := r.Header.Get("X-Api-User")
passwd := r.Header.Get("X-Api-Key") passwd := r.Header.Get("X-Api-Key")
@ -46,42 +58,37 @@ func getExpect(t *testing.T, server *httptest.Server) *httpexpect.Expect {
}) })
} }
func setupRouter(debug bool, noauth bool) http.Handler { func setupRouter(debug bool, noauth bool) (http.Handler, AcmednsAPI, acmedns.AcmednsDB) {
api := httprouter.New() api := httprouter.New()
var dbcfg = dbsettings{ config, logger := fakeConfigAndLogger()
Engine: "sqlite3", config.API.Domain = ""
Connection: ":memory:"} config.API.Port = "8080"
var httpapicfg = httpapi{ config.API.TLS = "none"
Domain: "", config.API.CorsOrigins = []string{"*"}
Port: "8080", config.API.UseHeader = true
TLS: "none", config.API.HeaderName = "X-Forwarded-For"
CorsOrigins: []string{"*"},
UseHeader: true, db, _ := database.Init(&config, logger)
HeaderName: "X-Forwarded-For", errChan := make(chan error, 1)
} adnsapi := Init(&config, db, logger, errChan)
var dnscfg = DNSConfig{
API: httpapicfg,
Database: dbcfg,
}
Config = dnscfg
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedOrigins: Config.API.CorsOrigins, AllowedOrigins: config.API.CorsOrigins,
AllowedMethods: []string{"GET", "POST"}, AllowedMethods: []string{"GET", "POST"},
OptionsPassthrough: false, OptionsPassthrough: false,
Debug: Config.General.Debug, Debug: config.General.Debug,
}) })
api.POST("/register", webRegisterPost) api.POST("/register", adnsapi.webRegisterPost)
api.GET("/health", healthCheck) api.GET("/health", adnsapi.healthCheck)
if noauth { if noauth {
api.POST("/update", noAuth(webUpdatePost)) api.POST("/update", noAuth(adnsapi.webUpdatePost))
} else { } else {
api.POST("/update", Auth(webUpdatePost)) api.POST("/update", adnsapi.Auth(adnsapi.webUpdatePost))
} }
return c.Handler(api) return c.Handler(api), adnsapi, db
} }
func TestApiRegister(t *testing.T) { func TestApiRegister(t *testing.T) {
router := setupRouter(false, false) router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
@ -117,7 +124,7 @@ func TestApiRegister(t *testing.T) {
} }
func TestApiRegisterBadAllowFrom(t *testing.T) { func TestApiRegisterBadAllowFrom(t *testing.T) {
router := setupRouter(false, false) router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
@ -147,7 +154,7 @@ func TestApiRegisterBadAllowFrom(t *testing.T) {
} }
func TestApiRegisterMalformedJSON(t *testing.T) { func TestApiRegisterMalformedJSON(t *testing.T) {
router := setupRouter(false, false) router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
@ -174,13 +181,13 @@ func TestApiRegisterMalformedJSON(t *testing.T) {
} }
func TestApiRegisterWithMockDB(t *testing.T) { func TestApiRegisterWithMockDB(t *testing.T) {
router := setupRouter(false, false) router, _, db := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
oldDb := DB.GetBackend() oldDb := db.GetBackend()
db, mock, _ := sqlmock.New() mdb, mock, _ := sqlmock.New()
DB.SetBackend(db) db.SetBackend(mdb)
defer db.Close() defer db.Close()
mock.ExpectBegin() mock.ExpectBegin()
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error")) mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
@ -188,7 +195,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
Status(http.StatusInternalServerError). Status(http.StatusInternalServerError).
JSON().Object(). JSON().Object().
ContainsKey("error") ContainsKey("error")
DB.SetBackend(oldDb) db.SetBackend(oldDb)
} }
func TestApiUpdateWithInvalidSubdomain(t *testing.T) { func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
@ -198,11 +205,11 @@ func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
"subdomain": "", "subdomain": "",
"txt": ""} "txt": ""}
router := setupRouter(false, false) router, _, db := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
newUser, err := DB.Register(cidrslice{}) newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Could not create new user, got error [%v]", err) t.Errorf("Could not create new user, got error [%v]", err)
} }
@ -228,11 +235,11 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
"subdomain": "", "subdomain": "",
"txt": ""} "txt": ""}
router := setupRouter(false, false) router, _, db := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
newUser, err := DB.Register(cidrslice{}) newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Could not create new user, got error [%v]", err) t.Errorf("Could not create new user, got error [%v]", err)
} }
@ -252,7 +259,7 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
} }
func TestApiUpdateWithoutCredentials(t *testing.T) { func TestApiUpdateWithoutCredentials(t *testing.T) {
router := setupRouter(false, false) router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
@ -270,11 +277,11 @@ func TestApiUpdateWithCredentials(t *testing.T) {
"subdomain": "", "subdomain": "",
"txt": ""} "txt": ""}
router := setupRouter(false, false) router, _, db := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
newUser, err := DB.Register(cidrslice{}) newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Could not create new user, got error [%v]", err) t.Errorf("Could not create new user, got error [%v]", err)
} }
@ -303,13 +310,13 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8" updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8"
updateJSON["txt"] = validTxtData updateJSON["txt"] = validTxtData
router := setupRouter(false, true) router, _, db := setupRouter(false, true)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
oldDb := DB.GetBackend() oldDb := db.GetBackend()
db, mock, _ := sqlmock.New() mdb, mock, _ := sqlmock.New()
DB.SetBackend(db) db.SetBackend(mdb)
defer db.Close() defer db.Close()
mock.ExpectBegin() mock.ExpectBegin()
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error")) mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
@ -319,31 +326,31 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
Status(http.StatusInternalServerError). Status(http.StatusInternalServerError).
JSON().Object(). JSON().Object().
ContainsKey("error") ContainsKey("error")
DB.SetBackend(oldDb) db.SetBackend(oldDb)
} }
func TestApiManyUpdateWithCredentials(t *testing.T) { func TestApiManyUpdateWithCredentials(t *testing.T) {
validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
router := setupRouter(true, false) router, _, db := setupRouter(true, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
// User without defined CIDR masks // User without defined CIDR masks
newUser, err := DB.Register(cidrslice{}) newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Could not create new user, got error [%v]", err) t.Errorf("Could not create new user, got error [%v]", err)
} }
// User with defined allow from - CIDR masks, all invalid // User with defined allow from - CIDR masks, all invalid
// (httpexpect doesn't provide a way to mock remote ip) // (httpexpect doesn't provide a way to mock remote ip)
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.1/32", "invalid"}) newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.1/32", "invalid"})
if err != nil { if err != nil {
t.Errorf("Could not create new user with CIDR, got error [%v]", err) t.Errorf("Could not create new user with CIDR, got error [%v]", err)
} }
// Another user with valid CIDR mask to match the httpexpect default // Another user with valid CIDR mask to match the httpexpect default
newUserWithValidCIDR, err := DB.Register(cidrslice{"10.1.2.3/32", "invalid"}) newUserWithValidCIDR, err := db.Register(acmedns.Cidrslice{"10.1.2.3/32", "invalid"})
if err != nil { if err != nil {
t.Errorf("Could not create new user with a valid CIDR, got error [%v]", err) t.Errorf("Could not create new user with a valid CIDR, got error [%v]", err)
} }
@ -381,30 +388,30 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) { func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
router := setupRouter(false, false) router, adnsapi, db := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
// Use header checks from default header (X-Forwarded-For) // Use header checks from default header (X-Forwarded-For)
Config.API.UseHeader = true adnsapi.Config.API.UseHeader = true
// User without defined CIDR masks // User without defined CIDR masks
newUser, err := DB.Register(cidrslice{}) newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Could not create new user, got error [%v]", err) t.Errorf("Could not create new user, got error [%v]", err)
} }
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.2/32", "invalid"}) newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.2/32", "invalid"})
if err != nil { if err != nil {
t.Errorf("Could not create new user with CIDR, got error [%v]", err) t.Errorf("Could not create new user with CIDR, got error [%v]", err)
} }
newUserWithIP6CIDR, err := DB.Register(cidrslice{"2002:c0a8::0/32"}) newUserWithIP6CIDR, err := db.Register(acmedns.Cidrslice{"2002:c0a8::0/32"})
if err != nil { if err != nil {
t.Errorf("Could not create a new user with IP6 CIDR, got error [%v]", err) t.Errorf("Could not create a new user with IP6 CIDR, got error [%v]", err)
} }
for _, test := range []struct { for _, test := range []struct {
user ACMETxt user acmedns.ACMETxt
headerValue string headerValue string
status int status int
}{ }{
@ -428,13 +435,66 @@ func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
Expect(). Expect().
Status(test.status) Status(test.status)
} }
Config.API.UseHeader = false adnsapi.Config.API.UseHeader = false
} }
func TestApiHealthCheck(t *testing.T) { func TestApiHealthCheck(t *testing.T) {
router := setupRouter(false, false) router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router) server := httptest.NewServer(router)
defer server.Close() defer server.Close()
e := getExpect(t, server) e := getExpect(t, server)
e.GET("/health").Expect().Status(http.StatusOK) e.GET("/health").Expect().Status(http.StatusOK)
} }
func TestGetIPListFromHeader(t *testing.T) {
for i, test := range []struct {
input string
output []string
}{
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
} {
res := getIPListFromHeader(test.input)
if len(res) != len(test.output) {
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
} else {
for j, vv := range test.output {
if res[j] != vv {
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
}
}
}
}
}
func TestUpdateAllowedFromIP(t *testing.T) {
_, adnsapi, _ := setupRouter(false, false)
adnsapi.Config.API.UseHeader = false
userWithAllow := acmedns.NewACMETxt()
userWithAllow.AllowFrom = acmedns.Cidrslice{"192.168.1.2/32", "[::1]/128"}
userWithoutAllow := acmedns.NewACMETxt()
for i, test := range []struct {
remoteaddr string
expected bool
}{
{"192.168.1.2:1234", true},
{"192.168.1.1:1234", false},
{"invalid", false},
{"[::1]:4567", true},
} {
newreq, _ := http.NewRequest("GET", "/whatever", nil)
newreq.RemoteAddr = test.remoteaddr
ret := adnsapi.updateAllowedFromIP(newreq, userWithAllow)
if test.expected != ret {
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
}
if !adnsapi.updateAllowedFromIP(newreq, userWithoutAllow) {
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
}
}
}

View File

@ -74,11 +74,11 @@ func (a *AcmednsAPI) getUserFromRequest(r *http.Request) (acmedns.ACMETxt, error
a.Logger.Errorw("Error while trying to get user", a.Logger.Errorw("Error while trying to get user",
"error", err.Error()) "error", err.Error())
// To protect against timed side channel (never gonna give you up) // To protect against timed side channel (never gonna give you up)
correctPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") acmedns.CorrectPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
return acmedns.ACMETxt{}, fmt.Errorf("invalid username: %s", uname) return acmedns.ACMETxt{}, fmt.Errorf("invalid username: %s", uname)
} }
if correctPassword(passwd, dbuser.Password) { if acmedns.CorrectPassword(passwd, dbuser.Password) {
return dbuser, nil return dbuser, nil
} }
return acmedns.ACMETxt{}, fmt.Errorf("invalid password for user %s", uname) return acmedns.ACMETxt{}, fmt.Errorf("invalid password for user %s", uname)

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"github.com/acme-dns/acme-dns/pkg/acmedns" "github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"regexp" "regexp"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
@ -31,13 +30,6 @@ func validKey(k string) bool {
return false return false
} }
func correctPassword(pw string, hash string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
return true
}
return false
}
func getIPListFromHeader(header string) []string { func getIPListFromHeader(header string) []string {
iplist := []string{} iplist := []string{}
for _, v := range strings.Split(header, ",") { for _, v := range strings.Split(header, ",") {

View File

@ -1,6 +1,7 @@
package main package api
import ( import (
"github.com/acme-dns/acme-dns/pkg/acmedns"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
@ -103,7 +104,7 @@ func TestCorrectPassword(t *testing.T) {
false}, false},
{"", "", false}, {"", "", false},
} { } {
ret := correctPassword(test.pw, test.hash) ret := acmedns.CorrectPassword(test.pw, test.hash)
if ret != test.output { if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret) t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
} }
@ -112,12 +113,12 @@ func TestCorrectPassword(t *testing.T) {
func TestGetValidCIDRMasks(t *testing.T) { func TestGetValidCIDRMasks(t *testing.T) {
for i, test := range []struct { for i, test := range []struct {
input cidrslice input acmedns.Cidrslice
output cidrslice output acmedns.Cidrslice
}{ }{
{cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}}, {acmedns.Cidrslice{"10.0.0.1/24"}, acmedns.Cidrslice{"10.0.0.1/24"}},
{cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}}, {acmedns.Cidrslice{"invalid", "127.0.0.1/32"}, acmedns.Cidrslice{"127.0.0.1/32"}},
{cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}}, {acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}},
} { } {
ret := test.input.ValidEntries() ret := test.input.ValidEntries()
if len(ret) == len(test.output) { if len(ret) == len(test.output) {

View File

@ -1,10 +1,12 @@
package main package database
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/erikstmartin/go-testdb" "github.com/erikstmartin/go-testdb"
"go.uber.org/zap"
"testing" "testing"
) )
@ -21,42 +23,38 @@ func (r testResult) RowsAffected() (int64, error) {
return r.affectedRows, nil return r.affectedRows, nil
} }
func TestDBInit(t *testing.T) { func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
fakeDB := new(acmedb) c := acmedns.AcmeDnsConfig{}
err := fakeDB.Init("notarealegine", "connectionstring") c.Database.Engine = "sqlite"
if err == nil { c.Database.Connection = ":memory:"
t.Errorf("Was expecting error, didn't get one.") l := zap.NewNop().Sugar()
return c, l
} }
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { func fakeDB() acmedns.AcmednsDB {
return testResult{1, 0}, errors.New("Prepared query error") conf, logger := fakeConfigAndLogger()
}) db, _ := Init(&conf, logger)
defer testdb.Reset() return db
errorDB := new(acmedb)
err = errorDB.Init("testdb", "")
if err == nil {
t.Errorf("Was expecting DB initiation error but got none")
}
errorDB.Close()
} }
func TestRegisterNoCIDR(t *testing.T) { func TestRegisterNoCIDR(t *testing.T) {
// Register tests // Register tests
_, err := DB.Register(cidrslice{}) DB := fakeDB()
_, err := DB.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Registration failed, got error [%v]", err) t.Errorf("Registration failed, got error [%v]", err)
} }
} }
func TestRegisterMany(t *testing.T) { func TestRegisterMany(t *testing.T) {
DB := fakeDB()
for i, test := range []struct { for i, test := range []struct {
input cidrslice input acmedns.Cidrslice
output cidrslice output acmedns.Cidrslice
}{ }{
{cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}}, {acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
{cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}}, {acmedns.Cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, acmedns.Cidrslice{}},
{cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}}, {acmedns.Cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, acmedns.Cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
} { } {
user, err := DB.Register(test.input) user, err := DB.Register(test.input)
if err != nil { if err != nil {
@ -77,8 +75,9 @@ func TestRegisterMany(t *testing.T) {
} }
func TestGetByUsername(t *testing.T) { func TestGetByUsername(t *testing.T) {
DB := fakeDB()
// Create reg to refer to // Create reg to refer to
reg, err := DB.Register(cidrslice{}) reg, err := DB.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Registration failed, got error [%v]", err) t.Errorf("Registration failed, got error [%v]", err)
} }
@ -97,13 +96,14 @@ func TestGetByUsername(t *testing.T) {
} }
// regUser password already is a bcrypt hash // regUser password already is a bcrypt hash
if !correctPassword(reg.Password, regUser.Password) { if !acmedns.CorrectPassword(reg.Password, regUser.Password) {
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password) t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
} }
} }
func TestPrepareErrors(t *testing.T) { func TestPrepareErrors(t *testing.T) {
reg, _ := DB.Register(cidrslice{}) DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
tdb, err := sql.Open("testdb", "") tdb, err := sql.Open("testdb", "")
if err != nil { if err != nil {
t.Errorf("Got error: %v", err) t.Errorf("Got error: %v", err)
@ -125,7 +125,8 @@ func TestPrepareErrors(t *testing.T) {
} }
func TestQueryExecErrors(t *testing.T) { func TestQueryExecErrors(t *testing.T) {
reg, _ := DB.Register(cidrslice{}) DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
return testResult{1, 0}, errors.New("Prepared query error") return testResult{1, 0}, errors.New("Prepared query error")
}) })
@ -156,7 +157,7 @@ func TestQueryExecErrors(t *testing.T) {
t.Errorf("Expected error from exec in GetByDomain, but got none") t.Errorf("Expected error from exec in GetByDomain, but got none")
} }
_, err = DB.Register(cidrslice{}) _, err = DB.Register(acmedns.Cidrslice{})
if err == nil { if err == nil {
t.Errorf("Expected error from exec in Register, but got none") t.Errorf("Expected error from exec in Register, but got none")
} }
@ -169,7 +170,8 @@ func TestQueryExecErrors(t *testing.T) {
} }
func TestQueryScanErrors(t *testing.T) { func TestQueryScanErrors(t *testing.T) {
reg, _ := DB.Register(cidrslice{}) DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
return testResult{1, 0}, errors.New("Prepared query error") return testResult{1, 0}, errors.New("Prepared query error")
@ -198,7 +200,8 @@ func TestQueryScanErrors(t *testing.T) {
} }
func TestBadDBValues(t *testing.T) { func TestBadDBValues(t *testing.T) {
reg, _ := DB.Register(cidrslice{}) DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
@ -228,8 +231,9 @@ func TestBadDBValues(t *testing.T) {
} }
func TestGetTXTForDomain(t *testing.T) { func TestGetTXTForDomain(t *testing.T) {
DB := fakeDB()
// Create reg to refer to // Create reg to refer to
reg, err := DB.Register(cidrslice{}) reg, err := DB.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Registration failed, got error [%v]", err) t.Errorf("Registration failed, got error [%v]", err)
} }
@ -276,8 +280,9 @@ func TestGetTXTForDomain(t *testing.T) {
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
DB := fakeDB()
// Create reg to refer to // Create reg to refer to
reg, err := DB.Register(cidrslice{}) reg, err := DB.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Registration failed, got error [%v]", err) t.Errorf("Registration failed, got error [%v]", err)
} }

View File

@ -1,20 +1,66 @@
package main package nameserver
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"testing" "github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/acme-dns/acme-dns/pkg/database"
"github.com/erikstmartin/go-testdb" "github.com/erikstmartin/go-testdb"
"github.com/miekg/dns" "github.com/miekg/dns"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
"sync"
"testing"
) )
type resolver struct { type resolver struct {
server string server string
} }
var records = []string{
"auth.example.org. A 192.168.1.100",
"ns1.auth.example.org. A 192.168.1.101",
"cn.example.org CNAME something.example.org.",
"!''b', unparseable ",
"ns2.auth.example.org. A 192.168.1.102",
}
func loggerHasEntryWithMessage(message string, logObserver *observer.ObservedLogs) bool {
if len(logObserver.FilterMessage(message).All()) > 0 {
return true
}
return false
}
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger, *observer.ObservedLogs) {
c := acmedns.AcmeDnsConfig{}
c.Database.Engine = "sqlite"
c.Database.Connection = ":memory:"
obsCore, logObserver := observer.New(zap.DebugLevel)
obsLogger := zap.New(obsCore).Sugar()
return c, obsLogger, logObserver
}
func setupDNS() (acmedns.AcmednsNS, acmedns.AcmednsDB, *observer.ObservedLogs) {
config, logger, logObserver := fakeConfigAndLogger()
config.General.Domain = "auth.example.org"
config.General.Listen = "127.0.0.1:15353"
config.General.Proto = "udp"
config.General.Nsname = "ns1.auth.example.org"
config.General.Nsadmin = "admin.example.org"
config.General.StaticRecords = records
config.General.Debug = false
db, _ := database.Init(&config, logger)
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
server.Domains = make(map[string]Records)
server.Server = &dns.Server{Addr: config.General.Listen, Net: config.General.Proto}
server.ParseRecords()
server.OwnDomain = "auth.example.org."
return &server, db, logObserver
}
func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) { func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
msg := new(dns.Msg) msg := new(dns.Msg)
msg.Id = dns.Id() msg.Id = dns.Id()
@ -27,7 +73,6 @@ func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
if in != nil && in.Rcode != dns.RcodeSuccess { if in != nil && in.Rcode != dns.RcodeSuccess {
return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode]) return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode])
} }
return in, nil return in, nil
} }
@ -49,6 +94,18 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error {
} }
func TestQuestionDBError(t *testing.T) { func TestQuestionDBError(t *testing.T) {
config, logger, _ := fakeConfigAndLogger()
config.General.Listen = "127.0.0.1:15353"
config.General.Proto = "udp"
config.General.Domain = "auth.example.org"
config.General.Nsname = "ns1.auth.example.org"
config.General.Nsadmin = "admin.example.org"
config.General.StaticRecords = records
config.General.Debug = false
db, _ := database.Init(&config, logger)
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
server.Domains = make(map[string]Records)
server.ParseRecords()
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error") return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error")
@ -60,48 +117,61 @@ func TestQuestionDBError(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Got error: %v", err) t.Errorf("Got error: %v", err)
} }
oldDb := DB.GetBackend() oldDb := db.GetBackend()
DB.SetBackend(tdb) db.SetBackend(tdb)
defer DB.SetBackend(oldDb) defer db.SetBackend(oldDb)
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET} q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
_, err = dnsserver.answerTXT(q) _, err = server.answerTXT(q)
if err == nil { if err == nil {
t.Errorf("Expected error but got none") t.Errorf("Expected error but got none")
} }
} }
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
var testcfg = DNSConfig{ config, logger, logObserver := fakeConfigAndLogger()
General: general{ config.General.Listen = "127.0.0.1:15353"
Domain: ")", config.General.Proto = "udp"
Nsname: "ns1.auth.example.org", config.General.Domain = ")"
Nsadmin: "admin.example.org", config.General.Nsname = "ns1.auth.example.org"
StaticRecords: []string{}, config.General.Nsadmin = "admin.example.org"
Debug: false, config.General.StaticRecords = records
}, config.General.Debug = false
} config.General.StaticRecords = []string{}
dnsserver.ParseRecords(testcfg) db, _ := database.Init(&config, logger)
if !loggerHasEntryWithMessage("Error while adding SOA record") { server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
server.Domains = make(map[string]Records)
server.ParseRecords()
if !loggerHasEntryWithMessage("Error while adding SOA record", logObserver) {
t.Errorf("Expected SOA parsing to return error, but did not find one") t.Errorf("Expected SOA parsing to return error, but did not find one")
} }
} }
func TestResolveA(t *testing.T) { func TestResolveA(t *testing.T) {
server, _, _ := setupDNS()
errChan := make(chan error, 1)
waitLock := sync.Mutex{}
waitLock.Lock()
server.SetNotifyStartedFunc(waitLock.Unlock)
go server.Start(errChan)
waitLock.Lock()
resolv := resolver{server: "127.0.0.1:15353"} resolv := resolver{server: "127.0.0.1:15353"}
answer, err := resolv.lookup("auth.example.org", dns.TypeA) answer, err := resolv.lookup("auth.example.org", dns.TypeA)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
return
} }
if len(answer.Answer) == 0 { if len(answer.Answer) == 0 {
t.Error("No answer for DNS query") t.Error("No answer for DNS query")
return
} }
_, err = resolv.lookup("nonexistent.domain.tld", dns.TypeA) _, err = resolv.lookup("nonexistent.domain.tld", dns.TypeA)
if err == nil { if err == nil {
t.Errorf("Was expecting error because of NXDOMAIN but got none") t.Errorf("Was expecting error because of NXDOMAIN but got none")
return
} }
} }
@ -195,17 +265,20 @@ func TestAuthoritative(t *testing.T) {
} }
} }
/*
func TestResolveTXT(t *testing.T) { func TestResolveTXT(t *testing.T) {
_, db, _ := setupDNS()
resolv := resolver{server: "127.0.0.1:15353"} resolv := resolver{server: "127.0.0.1:15353"}
validTXT := "______________valid_response_______________" validTXT := "______________valid_response_______________"
atxt, err := DB.Register(cidrslice{}) atxt, err := db.Register(acmedns.Cidrslice{})
if err != nil { if err != nil {
t.Errorf("Could not initiate db record: [%v]", err) t.Errorf("Could not initiate db record: [%v]", err)
return return
} }
atxt.Value = validTXT atxt.Value = validTXT
err = DB.Update(atxt.ACMETxtPost)
err = db.Update(atxt.ACMETxtPost)
if err != nil { if err != nil {
t.Errorf("Could not update db record: [%v]", err) t.Errorf("Could not update db record: [%v]", err)
return return
@ -254,7 +327,7 @@ func TestResolveTXT(t *testing.T) {
} }
} }
} }
} }*/
func TestCaseInsensitiveResolveA(t *testing.T) { func TestCaseInsensitiveResolveA(t *testing.T) {
resolv := resolver{server: "127.0.0.1:15353"} resolv := resolver{server: "127.0.0.1:15353"}

View File

@ -9,7 +9,6 @@ import (
func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) { func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
// handle edns0 // handle edns0
opt := r.IsEdns0() opt := r.IsEdns0()
if opt != nil { if opt != nil {

View File

@ -1,12 +1,11 @@
package nameserver package nameserver
import ( import (
"strings"
"github.com/acme-dns/acme-dns/pkg/acmedns" "github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/miekg/dns" "github.com/miekg/dns"
"go.uber.org/zap" "go.uber.org/zap"
"strings"
"sync"
) )
// Records is a slice of ResourceRecords // Records is a slice of ResourceRecords
@ -20,6 +19,7 @@ type Nameserver struct {
Logger *zap.SugaredLogger Logger *zap.SugaredLogger
Server *dns.Server Server *dns.Server
OwnDomain string OwnDomain string
NotifyStartedFunc func()
SOA dns.RR SOA dns.RR
personalAuthKey string personalAuthKey string
Domains map[string]Records Domains map[string]Records
@ -28,8 +28,9 @@ type Nameserver struct {
func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) []acmedns.AcmednsNS { func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) []acmedns.AcmednsNS {
dnsservers := make([]acmedns.AcmednsNS, 0) dnsservers := make([]acmedns.AcmednsNS, 0)
waitLock := sync.Mutex{}
if strings.HasPrefix(config.General.Proto, "both") { if strings.HasPrefix(config.General.Proto, "both") {
// Handle the case where DNS server should be started for both udp and tcp // Handle the case where DNS server should be started for both udp and tcp
udpProto := "udp" udpProto := "udp"
tcpProto := "tcp" tcpProto := "tcp"
@ -46,13 +47,22 @@ func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
dnsServerTCP := NewDNSServer(config, db, logger, tcpProto) dnsServerTCP := NewDNSServer(config, db, logger, tcpProto)
dnsservers = append(dnsservers, dnsServerTCP) dnsservers = append(dnsservers, dnsServerTCP)
dnsServerTCP.ParseRecords() dnsServerTCP.ParseRecords()
// wait for the server to get started to proceed
waitLock.Lock()
dnsServerUDP.SetNotifyStartedFunc(waitLock.Unlock)
go dnsServerUDP.Start(errChan) go dnsServerUDP.Start(errChan)
waitLock.Lock()
dnsServerTCP.SetNotifyStartedFunc(waitLock.Unlock)
go dnsServerTCP.Start(errChan) go dnsServerTCP.Start(errChan)
waitLock.Lock()
} else { } else {
dnsServer := NewDNSServer(config, db, logger, config.General.Proto) dnsServer := NewDNSServer(config, db, logger, config.General.Proto)
dnsservers = append(dnsservers, dnsServer) dnsservers = append(dnsservers, dnsServer)
dnsServer.ParseRecords() dnsServer.ParseRecords()
waitLock.Lock()
dnsServer.SetNotifyStartedFunc(waitLock.Unlock)
go dnsServer.Start(errChan) go dnsServer.Start(errChan)
waitLock.Lock()
} }
return dnsservers return dnsservers
} }
@ -67,7 +77,6 @@ func NewDNSServer(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
domain = domain + "." domain = domain + "."
} }
server.OwnDomain = strings.ToLower(domain) server.OwnDomain = strings.ToLower(domain)
server.DB = db
server.personalAuthKey = "" server.personalAuthKey = ""
server.Domains = make(map[string]Records) server.Domains = make(map[string]Records)
return &server return &server
@ -79,8 +88,15 @@ func (n *Nameserver) Start(errorChannel chan error) {
n.Logger.Infow("Starting DNS listener", n.Logger.Infow("Starting DNS listener",
"addr", n.Server.Addr, "addr", n.Server.Addr,
"proto", n.Server.Net) "proto", n.Server.Net)
if n.NotifyStartedFunc != nil {
n.Server.NotifyStartedFunc = n.NotifyStartedFunc
}
err := n.Server.ListenAndServe() err := n.Server.ListenAndServe()
if err != nil { if err != nil {
errorChannel <- err errorChannel <- err
} }
} }
func (n *Nameserver) SetNotifyStartedFunc(fun func()) {
n.Server.NotifyStartedFunc = fun
}