Re-added tests
This commit is contained in:
parent
1405e6ab47
commit
157241994f
34
auth_test.go
34
auth_test.go
@ -1,34 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpdateAllowedFromIP(t *testing.T) {
|
||||
Config.API.UseHeader = false
|
||||
userWithAllow := newACMETxt()
|
||||
userWithAllow.AllowFrom = cidrslice{"192.168.1.2/32", "[::1]/128"}
|
||||
userWithoutAllow := newACMETxt()
|
||||
|
||||
for i, test := range []struct {
|
||||
remoteaddr string
|
||||
expected bool
|
||||
}{
|
||||
{"192.168.1.2:1234", true},
|
||||
{"192.168.1.1:1234", false},
|
||||
{"invalid", false},
|
||||
{"[::1]:4567", true},
|
||||
} {
|
||||
newreq, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
newreq.RemoteAddr = test.remoteaddr
|
||||
ret := updateAllowedFromIP(newreq, userWithAllow)
|
||||
if test.expected != ret {
|
||||
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
|
||||
}
|
||||
|
||||
if !updateAllowedFromIP(newreq, userWithoutAllow) {
|
||||
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
60
main.go
60
main.go
@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/acme-dns/acme-dns/pkg/api"
|
||||
@ -15,60 +14,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func setupLogging(config acmedns.AcmeDnsConfig) (*zap.Logger, error) {
|
||||
var logger *zap.Logger
|
||||
logformat := "console"
|
||||
if config.Logconfig.Format == "json" {
|
||||
logformat = "json"
|
||||
}
|
||||
outputPath := "stdout"
|
||||
if config.Logconfig.Logtype == "file" {
|
||||
outputPath = config.Logconfig.File
|
||||
}
|
||||
errorPath := "stderr"
|
||||
if config.Logconfig.Logtype == "file" {
|
||||
errorPath = config.Logconfig.File
|
||||
}
|
||||
zapConfigJson := fmt.Sprintf(`{
|
||||
"level": "%s",
|
||||
"encoding": "%s",
|
||||
"outputPaths": ["%s"],
|
||||
"errorOutputPaths": ["%s"],
|
||||
"encoderConfig": {
|
||||
"timeKey": "time",
|
||||
"messageKey": "msg",
|
||||
"levelKey": "level",
|
||||
"levelEncoder": "lowercase",
|
||||
"timeEncoder": "iso8601"
|
||||
}
|
||||
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
|
||||
var zapCfg zap.Config
|
||||
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
|
||||
return logger, err
|
||||
}
|
||||
logger, err := zapCfg.Build()
|
||||
return logger, err
|
||||
}
|
||||
|
||||
func readConfig(configFile string) (acmedns.AcmeDnsConfig, string, error) {
|
||||
var usedConfigFile string
|
||||
var config acmedns.AcmeDnsConfig
|
||||
var err error
|
||||
if acmedns.FileIsAccessible(configFile) {
|
||||
usedConfigFile = configFile
|
||||
config, err = acmedns.ReadConfig(configFile)
|
||||
} else if acmedns.FileIsAccessible("./config.cfg") {
|
||||
usedConfigFile = "./config.cfg"
|
||||
config, err = acmedns.ReadConfig("./config.cfg")
|
||||
} else {
|
||||
err = fmt.Errorf("configuration file not found")
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
|
||||
}
|
||||
return config, usedConfigFile, err
|
||||
}
|
||||
|
||||
func main() {
|
||||
syscall.Umask(0077)
|
||||
configPtr := flag.String("c", "/etc/acme-dns/config.cfg", "config file location")
|
||||
@ -76,18 +21,19 @@ func main() {
|
||||
// Read global config
|
||||
var err error
|
||||
var logger *zap.Logger
|
||||
config, usedConfigFile, err := readConfig(*configPtr)
|
||||
config, usedConfigFile, err := acmedns.ReadConfig(*configPtr)
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger, err = setupLogging(config)
|
||||
logger, err = acmedns.SetupLogging(config)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not set up logging: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logger.Sync()
|
||||
sugar := logger.Sugar()
|
||||
|
||||
sugar.Infow("Using config file",
|
||||
"file", usedConfigFile)
|
||||
sugar.Info("Starting up")
|
||||
|
||||
@ -2,6 +2,7 @@ package acmedns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
@ -20,7 +21,7 @@ func FileIsAccessible(fname string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func ReadConfig(fname string) (AcmeDnsConfig, error) {
|
||||
func readTomlConfig(fname string) (AcmeDnsConfig, error) {
|
||||
var conf AcmeDnsConfig
|
||||
_, err := toml.DecodeFile(fname, &conf)
|
||||
if err != nil {
|
||||
@ -46,3 +47,22 @@ func prepareConfig(conf AcmeDnsConfig) (AcmeDnsConfig, error) {
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func ReadConfig(configFile string) (AcmeDnsConfig, string, error) {
|
||||
var usedConfigFile string
|
||||
var config AcmeDnsConfig
|
||||
var err error
|
||||
if FileIsAccessible(configFile) {
|
||||
usedConfigFile = configFile
|
||||
config, err = readTomlConfig(configFile)
|
||||
} else if FileIsAccessible("./config.cfg") {
|
||||
usedConfigFile = "./config.cfg"
|
||||
config, err = readTomlConfig("./config.cfg")
|
||||
} else {
|
||||
err = fmt.Errorf("configuration file not found")
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
|
||||
}
|
||||
return config, usedConfigFile, err
|
||||
}
|
||||
|
||||
@ -18,5 +18,6 @@ type AcmednsDB interface {
|
||||
type AcmednsNS interface {
|
||||
Start(errorChannel chan error)
|
||||
SetOwnAuthKey(key string)
|
||||
SetNotifyStartedFunc(func())
|
||||
ParseRecords()
|
||||
}
|
||||
|
||||
43
pkg/acmedns/logging.go
Normal file
43
pkg/acmedns/logging.go
Normal file
@ -0,0 +1,43 @@
|
||||
package acmedns
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
|
||||
var logger *zap.Logger
|
||||
logformat := "console"
|
||||
if config.Logconfig.Format == "json" {
|
||||
logformat = "json"
|
||||
}
|
||||
outputPath := "stdout"
|
||||
if config.Logconfig.Logtype == "file" {
|
||||
outputPath = config.Logconfig.File
|
||||
}
|
||||
errorPath := "stderr"
|
||||
if config.Logconfig.Logtype == "file" {
|
||||
errorPath = config.Logconfig.File
|
||||
}
|
||||
zapConfigJson := fmt.Sprintf(`{
|
||||
"level": "%s",
|
||||
"encoding": "%s",
|
||||
"outputPaths": ["%s"],
|
||||
"errorOutputPaths": ["%s"],
|
||||
"encoderConfig": {
|
||||
"timeKey": "time",
|
||||
"messageKey": "msg",
|
||||
"levelKey": "level",
|
||||
"levelEncoder": "lowercase",
|
||||
"timeEncoder": "iso8601"
|
||||
}
|
||||
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
|
||||
var zapCfg zap.Config
|
||||
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
|
||||
return logger, err
|
||||
}
|
||||
logger, err := zapCfg.Build()
|
||||
return logger, err
|
||||
}
|
||||
@ -2,6 +2,7 @@ package acmedns
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"math/big"
|
||||
"regexp"
|
||||
)
|
||||
@ -29,3 +30,10 @@ func generatePassword(length int) string {
|
||||
}
|
||||
return string(ret)
|
||||
}
|
||||
|
||||
func CorrectPassword(pw string, hash string) bool {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package main
|
||||
package acmedns
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"syscall"
|
||||
@ -9,21 +11,55 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func fakeConfig() AcmeDnsConfig {
|
||||
conf := AcmeDnsConfig{}
|
||||
conf.Logconfig.Logtype = "stdout"
|
||||
return conf
|
||||
}
|
||||
|
||||
func TestSetupLogging(t *testing.T) {
|
||||
conf := fakeConfig()
|
||||
for i, test := range []struct {
|
||||
format string
|
||||
level string
|
||||
expected string
|
||||
expected zapcore.Level
|
||||
}{
|
||||
{"text", "warning", "warning"},
|
||||
{"json", "debug", "debug"},
|
||||
{"text", "info", "info"},
|
||||
{"json", "error", "error"},
|
||||
{"text", "something", "warning"},
|
||||
{"text", "warn", zap.WarnLevel},
|
||||
{"json", "debug", zap.DebugLevel},
|
||||
{"text", "info", zap.InfoLevel},
|
||||
{"json", "error", zap.ErrorLevel},
|
||||
} {
|
||||
setupLogging(test.format, test.level)
|
||||
if log.GetLevel().String() != test.expected {
|
||||
t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String())
|
||||
conf.Logconfig.Format = test.format
|
||||
conf.Logconfig.Level = test.level
|
||||
logger, err := SetupLogging(conf)
|
||||
if err != nil {
|
||||
t.Errorf("Got unexpected error: %s", err)
|
||||
} else {
|
||||
if logger.Sugar().Level() != test.expected {
|
||||
t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupLoggingError(t *testing.T) {
|
||||
conf := fakeConfig()
|
||||
for _, test := range []struct {
|
||||
format string
|
||||
level string
|
||||
errexpected bool
|
||||
}{
|
||||
{"text", "warn", false},
|
||||
{"json", "debug", false},
|
||||
{"text", "info", false},
|
||||
{"json", "error", false},
|
||||
{"text", "something", true},
|
||||
} {
|
||||
conf.Logconfig.Format = test.format
|
||||
conf.Logconfig.Level = test.level
|
||||
_, err := SetupLogging(conf)
|
||||
if test.errexpected && err == nil {
|
||||
t.Errorf("Expected error but did not get one for loglevel: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -31,11 +67,11 @@ func TestSetupLogging(t *testing.T) {
|
||||
func TestReadConfig(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
inFile []byte
|
||||
output DNSConfig
|
||||
output AcmeDnsConfig
|
||||
}{
|
||||
{
|
||||
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
|
||||
DNSConfig{
|
||||
AcmeDnsConfig{
|
||||
General: general{
|
||||
Listen: ":53",
|
||||
Debug: true,
|
||||
@ -48,10 +84,10 @@ func TestReadConfig(t *testing.T) {
|
||||
|
||||
{
|
||||
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
|
||||
DNSConfig{},
|
||||
AcmeDnsConfig{},
|
||||
},
|
||||
} {
|
||||
tmpfile, err := ioutil.TempFile("", "acmedns")
|
||||
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||
if err != nil {
|
||||
t.Error("Could not create temporary file")
|
||||
}
|
||||
@ -64,7 +100,7 @@ func TestReadConfig(t *testing.T) {
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Error("Could not close temporary file")
|
||||
}
|
||||
ret, _ := readConfig(tmpfile.Name())
|
||||
ret, _, _ := ReadConfig(tmpfile.Name())
|
||||
if ret.General.Listen != test.output.General.Listen {
|
||||
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
||||
}
|
||||
@ -74,30 +110,6 @@ func TestReadConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIPListFromHeader(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input string
|
||||
output []string
|
||||
}{
|
||||
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||
} {
|
||||
res := getIPListFromHeader(test.input)
|
||||
if len(res) != len(test.output) {
|
||||
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
|
||||
} else {
|
||||
|
||||
for j, vv := range test.output {
|
||||
if res[j] != vv {
|
||||
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileCheckPermissionDenied(t *testing.T) {
|
||||
tmpfile, err := ioutil.TempFile("", "acmedns")
|
||||
if err != nil {
|
||||
@ -105,14 +117,14 @@ func TestFileCheckPermissionDenied(t *testing.T) {
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
_ = syscall.Chmod(tmpfile.Name(), 0000)
|
||||
if fileIsAccessible(tmpfile.Name()) {
|
||||
if FileIsAccessible(tmpfile.Name()) {
|
||||
t.Errorf("File should not be accessible")
|
||||
}
|
||||
_ = syscall.Chmod(tmpfile.Name(), 0644)
|
||||
}
|
||||
|
||||
func TestFileCheckNotExists(t *testing.T) {
|
||||
if fileIsAccessible("/path/that/does/not/exist") {
|
||||
if FileIsAccessible("/path/that/does/not/exist") {
|
||||
t.Errorf("File should not be accessible")
|
||||
}
|
||||
}
|
||||
@ -123,19 +135,19 @@ func TestFileCheckOK(t *testing.T) {
|
||||
t.Error("Could not create temporary file")
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
if !fileIsAccessible(tmpfile.Name()) {
|
||||
if !FileIsAccessible(tmpfile.Name()) {
|
||||
t.Errorf("File should be accessible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareConfig(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input DNSConfig
|
||||
input AcmeDnsConfig
|
||||
shoulderror bool
|
||||
}{
|
||||
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
|
||||
{DNSConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
|
||||
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
|
||||
} {
|
||||
_, err := prepareConfig(test.input)
|
||||
if test.shoulderror {
|
||||
@ -1,4 +1,4 @@
|
||||
package main
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -8,17 +8,29 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||
"github.com/acme-dns/acme-dns/pkg/database"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gavv/httpexpect"
|
||||
"github.com/google/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/rs/cors"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
|
||||
c := acmedns.AcmeDnsConfig{}
|
||||
c.Database.Engine = "sqlite"
|
||||
c.Database.Connection = ":memory:"
|
||||
l := zap.NewNop().Sugar()
|
||||
return c, l
|
||||
}
|
||||
|
||||
// noAuth function to write ACMETxt model to context while not preforming any validation
|
||||
func noAuth(update httprouter.Handle) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
postData := ACMETxt{}
|
||||
postData := acmedns.ACMETxt{}
|
||||
uname := r.Header.Get("X-Api-User")
|
||||
passwd := r.Header.Get("X-Api-Key")
|
||||
|
||||
@ -46,42 +58,37 @@ func getExpect(t *testing.T, server *httptest.Server) *httpexpect.Expect {
|
||||
})
|
||||
}
|
||||
|
||||
func setupRouter(debug bool, noauth bool) http.Handler {
|
||||
func setupRouter(debug bool, noauth bool) (http.Handler, AcmednsAPI, acmedns.AcmednsDB) {
|
||||
api := httprouter.New()
|
||||
var dbcfg = dbsettings{
|
||||
Engine: "sqlite3",
|
||||
Connection: ":memory:"}
|
||||
var httpapicfg = httpapi{
|
||||
Domain: "",
|
||||
Port: "8080",
|
||||
TLS: "none",
|
||||
CorsOrigins: []string{"*"},
|
||||
UseHeader: true,
|
||||
HeaderName: "X-Forwarded-For",
|
||||
}
|
||||
var dnscfg = DNSConfig{
|
||||
API: httpapicfg,
|
||||
Database: dbcfg,
|
||||
}
|
||||
Config = dnscfg
|
||||
config, logger := fakeConfigAndLogger()
|
||||
config.API.Domain = ""
|
||||
config.API.Port = "8080"
|
||||
config.API.TLS = "none"
|
||||
config.API.CorsOrigins = []string{"*"}
|
||||
config.API.UseHeader = true
|
||||
config.API.HeaderName = "X-Forwarded-For"
|
||||
|
||||
db, _ := database.Init(&config, logger)
|
||||
errChan := make(chan error, 1)
|
||||
adnsapi := Init(&config, db, logger, errChan)
|
||||
c := cors.New(cors.Options{
|
||||
AllowedOrigins: Config.API.CorsOrigins,
|
||||
AllowedOrigins: config.API.CorsOrigins,
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
OptionsPassthrough: false,
|
||||
Debug: Config.General.Debug,
|
||||
Debug: config.General.Debug,
|
||||
})
|
||||
api.POST("/register", webRegisterPost)
|
||||
api.GET("/health", healthCheck)
|
||||
api.POST("/register", adnsapi.webRegisterPost)
|
||||
api.GET("/health", adnsapi.healthCheck)
|
||||
if noauth {
|
||||
api.POST("/update", noAuth(webUpdatePost))
|
||||
api.POST("/update", noAuth(adnsapi.webUpdatePost))
|
||||
} else {
|
||||
api.POST("/update", Auth(webUpdatePost))
|
||||
api.POST("/update", adnsapi.Auth(adnsapi.webUpdatePost))
|
||||
}
|
||||
return c.Handler(api)
|
||||
return c.Handler(api), adnsapi, db
|
||||
}
|
||||
|
||||
func TestApiRegister(t *testing.T) {
|
||||
router := setupRouter(false, false)
|
||||
router, _, _ := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
@ -117,7 +124,7 @@ func TestApiRegister(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApiRegisterBadAllowFrom(t *testing.T) {
|
||||
router := setupRouter(false, false)
|
||||
router, _, _ := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
@ -147,7 +154,7 @@ func TestApiRegisterBadAllowFrom(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApiRegisterMalformedJSON(t *testing.T) {
|
||||
router := setupRouter(false, false)
|
||||
router, _, _ := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
@ -174,13 +181,13 @@ func TestApiRegisterMalformedJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApiRegisterWithMockDB(t *testing.T) {
|
||||
router := setupRouter(false, false)
|
||||
router, _, db := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
oldDb := DB.GetBackend()
|
||||
db, mock, _ := sqlmock.New()
|
||||
DB.SetBackend(db)
|
||||
oldDb := db.GetBackend()
|
||||
mdb, mock, _ := sqlmock.New()
|
||||
db.SetBackend(mdb)
|
||||
defer db.Close()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
|
||||
@ -188,7 +195,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
|
||||
Status(http.StatusInternalServerError).
|
||||
JSON().Object().
|
||||
ContainsKey("error")
|
||||
DB.SetBackend(oldDb)
|
||||
db.SetBackend(oldDb)
|
||||
}
|
||||
|
||||
func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
|
||||
@ -198,11 +205,11 @@ func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
|
||||
"subdomain": "",
|
||||
"txt": ""}
|
||||
|
||||
router := setupRouter(false, false)
|
||||
router, _, db := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
newUser, err := DB.Register(cidrslice{})
|
||||
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
}
|
||||
@ -228,11 +235,11 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
|
||||
"subdomain": "",
|
||||
"txt": ""}
|
||||
|
||||
router := setupRouter(false, false)
|
||||
router, _, db := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
newUser, err := DB.Register(cidrslice{})
|
||||
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
}
|
||||
@ -252,7 +259,7 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApiUpdateWithoutCredentials(t *testing.T) {
|
||||
router := setupRouter(false, false)
|
||||
router, _, _ := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
@ -270,11 +277,11 @@ func TestApiUpdateWithCredentials(t *testing.T) {
|
||||
"subdomain": "",
|
||||
"txt": ""}
|
||||
|
||||
router := setupRouter(false, false)
|
||||
router, _, db := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
newUser, err := DB.Register(cidrslice{})
|
||||
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
}
|
||||
@ -303,13 +310,13 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
||||
updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8"
|
||||
updateJSON["txt"] = validTxtData
|
||||
|
||||
router := setupRouter(false, true)
|
||||
router, _, db := setupRouter(false, true)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
oldDb := DB.GetBackend()
|
||||
db, mock, _ := sqlmock.New()
|
||||
DB.SetBackend(db)
|
||||
oldDb := db.GetBackend()
|
||||
mdb, mock, _ := sqlmock.New()
|
||||
db.SetBackend(mdb)
|
||||
defer db.Close()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
|
||||
@ -319,31 +326,31 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
||||
Status(http.StatusInternalServerError).
|
||||
JSON().Object().
|
||||
ContainsKey("error")
|
||||
DB.SetBackend(oldDb)
|
||||
db.SetBackend(oldDb)
|
||||
}
|
||||
|
||||
func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||
validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
|
||||
router := setupRouter(true, false)
|
||||
router, _, db := setupRouter(true, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
// User without defined CIDR masks
|
||||
newUser, err := DB.Register(cidrslice{})
|
||||
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
}
|
||||
|
||||
// User with defined allow from - CIDR masks, all invalid
|
||||
// (httpexpect doesn't provide a way to mock remote ip)
|
||||
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.1/32", "invalid"})
|
||||
newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.1/32", "invalid"})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
|
||||
}
|
||||
|
||||
// Another user with valid CIDR mask to match the httpexpect default
|
||||
newUserWithValidCIDR, err := DB.Register(cidrslice{"10.1.2.3/32", "invalid"})
|
||||
newUserWithValidCIDR, err := db.Register(acmedns.Cidrslice{"10.1.2.3/32", "invalid"})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user with a valid CIDR, got error [%v]", err)
|
||||
}
|
||||
@ -381,30 +388,30 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||
|
||||
func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
|
||||
|
||||
router := setupRouter(false, false)
|
||||
router, adnsapi, db := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
// Use header checks from default header (X-Forwarded-For)
|
||||
Config.API.UseHeader = true
|
||||
adnsapi.Config.API.UseHeader = true
|
||||
// User without defined CIDR masks
|
||||
newUser, err := DB.Register(cidrslice{})
|
||||
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
}
|
||||
|
||||
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.2/32", "invalid"})
|
||||
newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.2/32", "invalid"})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
|
||||
}
|
||||
|
||||
newUserWithIP6CIDR, err := DB.Register(cidrslice{"2002:c0a8::0/32"})
|
||||
newUserWithIP6CIDR, err := db.Register(acmedns.Cidrslice{"2002:c0a8::0/32"})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create a new user with IP6 CIDR, got error [%v]", err)
|
||||
}
|
||||
|
||||
for _, test := range []struct {
|
||||
user ACMETxt
|
||||
user acmedns.ACMETxt
|
||||
headerValue string
|
||||
status int
|
||||
}{
|
||||
@ -428,13 +435,66 @@ func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
|
||||
Expect().
|
||||
Status(test.status)
|
||||
}
|
||||
Config.API.UseHeader = false
|
||||
adnsapi.Config.API.UseHeader = false
|
||||
}
|
||||
|
||||
func TestApiHealthCheck(t *testing.T) {
|
||||
router := setupRouter(false, false)
|
||||
router, _, _ := setupRouter(false, false)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
e := getExpect(t, server)
|
||||
e.GET("/health").Expect().Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func TestGetIPListFromHeader(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input string
|
||||
output []string
|
||||
}{
|
||||
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||
} {
|
||||
res := getIPListFromHeader(test.input)
|
||||
if len(res) != len(test.output) {
|
||||
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
|
||||
} else {
|
||||
|
||||
for j, vv := range test.output {
|
||||
if res[j] != vv {
|
||||
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAllowedFromIP(t *testing.T) {
|
||||
_, adnsapi, _ := setupRouter(false, false)
|
||||
adnsapi.Config.API.UseHeader = false
|
||||
userWithAllow := acmedns.NewACMETxt()
|
||||
userWithAllow.AllowFrom = acmedns.Cidrslice{"192.168.1.2/32", "[::1]/128"}
|
||||
userWithoutAllow := acmedns.NewACMETxt()
|
||||
|
||||
for i, test := range []struct {
|
||||
remoteaddr string
|
||||
expected bool
|
||||
}{
|
||||
{"192.168.1.2:1234", true},
|
||||
{"192.168.1.1:1234", false},
|
||||
{"invalid", false},
|
||||
{"[::1]:4567", true},
|
||||
} {
|
||||
newreq, _ := http.NewRequest("GET", "/whatever", nil)
|
||||
newreq.RemoteAddr = test.remoteaddr
|
||||
ret := adnsapi.updateAllowedFromIP(newreq, userWithAllow)
|
||||
if test.expected != ret {
|
||||
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
|
||||
}
|
||||
|
||||
if !adnsapi.updateAllowedFromIP(newreq, userWithoutAllow) {
|
||||
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -74,11 +74,11 @@ func (a *AcmednsAPI) getUserFromRequest(r *http.Request) (acmedns.ACMETxt, error
|
||||
a.Logger.Errorw("Error while trying to get user",
|
||||
"error", err.Error())
|
||||
// To protect against timed side channel (never gonna give you up)
|
||||
correctPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
||||
acmedns.CorrectPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
||||
|
||||
return acmedns.ACMETxt{}, fmt.Errorf("invalid username: %s", uname)
|
||||
}
|
||||
if correctPassword(passwd, dbuser.Password) {
|
||||
if acmedns.CorrectPassword(passwd, dbuser.Password) {
|
||||
return dbuser, nil
|
||||
}
|
||||
return acmedns.ACMETxt{}, fmt.Errorf("invalid password for user %s", uname)
|
||||
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
@ -31,13 +30,6 @@ func validKey(k string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func correctPassword(pw string, hash string) bool {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getIPListFromHeader(header string) []string {
|
||||
iplist := []string{}
|
||||
for _, v := range strings.Split(header, ",") {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package main
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@ -103,7 +104,7 @@ func TestCorrectPassword(t *testing.T) {
|
||||
false},
|
||||
{"", "", false},
|
||||
} {
|
||||
ret := correctPassword(test.pw, test.hash)
|
||||
ret := acmedns.CorrectPassword(test.pw, test.hash)
|
||||
if ret != test.output {
|
||||
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
|
||||
}
|
||||
@ -112,12 +113,12 @@ func TestCorrectPassword(t *testing.T) {
|
||||
|
||||
func TestGetValidCIDRMasks(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input cidrslice
|
||||
output cidrslice
|
||||
input acmedns.Cidrslice
|
||||
output acmedns.Cidrslice
|
||||
}{
|
||||
{cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}},
|
||||
{cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}},
|
||||
{cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}},
|
||||
{acmedns.Cidrslice{"10.0.0.1/24"}, acmedns.Cidrslice{"10.0.0.1/24"}},
|
||||
{acmedns.Cidrslice{"invalid", "127.0.0.1/32"}, acmedns.Cidrslice{"127.0.0.1/32"}},
|
||||
{acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}},
|
||||
} {
|
||||
ret := test.input.ValidEntries()
|
||||
if len(ret) == len(test.output) {
|
||||
@ -1,10 +1,12 @@
|
||||
package main
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||
"github.com/erikstmartin/go-testdb"
|
||||
"go.uber.org/zap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -21,42 +23,38 @@ func (r testResult) RowsAffected() (int64, error) {
|
||||
return r.affectedRows, nil
|
||||
}
|
||||
|
||||
func TestDBInit(t *testing.T) {
|
||||
fakeDB := new(acmedb)
|
||||
err := fakeDB.Init("notarealegine", "connectionstring")
|
||||
if err == nil {
|
||||
t.Errorf("Was expecting error, didn't get one.")
|
||||
}
|
||||
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
|
||||
c := acmedns.AcmeDnsConfig{}
|
||||
c.Database.Engine = "sqlite"
|
||||
c.Database.Connection = ":memory:"
|
||||
l := zap.NewNop().Sugar()
|
||||
return c, l
|
||||
}
|
||||
|
||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||
return testResult{1, 0}, errors.New("Prepared query error")
|
||||
})
|
||||
defer testdb.Reset()
|
||||
|
||||
errorDB := new(acmedb)
|
||||
err = errorDB.Init("testdb", "")
|
||||
if err == nil {
|
||||
t.Errorf("Was expecting DB initiation error but got none")
|
||||
}
|
||||
errorDB.Close()
|
||||
func fakeDB() acmedns.AcmednsDB {
|
||||
conf, logger := fakeConfigAndLogger()
|
||||
db, _ := Init(&conf, logger)
|
||||
return db
|
||||
}
|
||||
|
||||
func TestRegisterNoCIDR(t *testing.T) {
|
||||
// Register tests
|
||||
_, err := DB.Register(cidrslice{})
|
||||
DB := fakeDB()
|
||||
_, err := DB.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterMany(t *testing.T) {
|
||||
DB := fakeDB()
|
||||
for i, test := range []struct {
|
||||
input cidrslice
|
||||
output cidrslice
|
||||
input acmedns.Cidrslice
|
||||
output acmedns.Cidrslice
|
||||
}{
|
||||
{cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
|
||||
{cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}},
|
||||
{cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
|
||||
{acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
|
||||
{acmedns.Cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, acmedns.Cidrslice{}},
|
||||
{acmedns.Cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, acmedns.Cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
|
||||
} {
|
||||
user, err := DB.Register(test.input)
|
||||
if err != nil {
|
||||
@ -77,8 +75,9 @@ func TestRegisterMany(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetByUsername(t *testing.T) {
|
||||
DB := fakeDB()
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register(cidrslice{})
|
||||
reg, err := DB.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
@ -97,13 +96,14 @@ func TestGetByUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
// regUser password already is a bcrypt hash
|
||||
if !correctPassword(reg.Password, regUser.Password) {
|
||||
if !acmedns.CorrectPassword(reg.Password, regUser.Password) {
|
||||
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareErrors(t *testing.T) {
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
DB := fakeDB()
|
||||
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||
tdb, err := sql.Open("testdb", "")
|
||||
if err != nil {
|
||||
t.Errorf("Got error: %v", err)
|
||||
@ -125,7 +125,8 @@ func TestPrepareErrors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryExecErrors(t *testing.T) {
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
DB := fakeDB()
|
||||
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||
return testResult{1, 0}, errors.New("Prepared query error")
|
||||
})
|
||||
@ -156,7 +157,7 @@ func TestQueryExecErrors(t *testing.T) {
|
||||
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
||||
}
|
||||
|
||||
_, err = DB.Register(cidrslice{})
|
||||
_, err = DB.Register(acmedns.Cidrslice{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from exec in Register, but got none")
|
||||
}
|
||||
@ -169,7 +170,8 @@ func TestQueryExecErrors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryScanErrors(t *testing.T) {
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
DB := fakeDB()
|
||||
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||
|
||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||
return testResult{1, 0}, errors.New("Prepared query error")
|
||||
@ -198,7 +200,8 @@ func TestQueryScanErrors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBadDBValues(t *testing.T) {
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
DB := fakeDB()
|
||||
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||
|
||||
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||
@ -228,8 +231,9 @@ func TestBadDBValues(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetTXTForDomain(t *testing.T) {
|
||||
DB := fakeDB()
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register(cidrslice{})
|
||||
reg, err := DB.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
@ -276,8 +280,9 @@ func TestGetTXTForDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
DB := fakeDB()
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register(cidrslice{})
|
||||
reg, err := DB.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
|
||||
@ -1,20 +1,66 @@
|
||||
package main
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||
"github.com/acme-dns/acme-dns/pkg/database"
|
||||
"github.com/erikstmartin/go-testdb"
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type resolver struct {
|
||||
server string
|
||||
}
|
||||
|
||||
var records = []string{
|
||||
"auth.example.org. A 192.168.1.100",
|
||||
"ns1.auth.example.org. A 192.168.1.101",
|
||||
"cn.example.org CNAME something.example.org.",
|
||||
"!''b', unparseable ",
|
||||
"ns2.auth.example.org. A 192.168.1.102",
|
||||
}
|
||||
|
||||
func loggerHasEntryWithMessage(message string, logObserver *observer.ObservedLogs) bool {
|
||||
if len(logObserver.FilterMessage(message).All()) > 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger, *observer.ObservedLogs) {
|
||||
c := acmedns.AcmeDnsConfig{}
|
||||
c.Database.Engine = "sqlite"
|
||||
c.Database.Connection = ":memory:"
|
||||
obsCore, logObserver := observer.New(zap.DebugLevel)
|
||||
obsLogger := zap.New(obsCore).Sugar()
|
||||
return c, obsLogger, logObserver
|
||||
}
|
||||
|
||||
func setupDNS() (acmedns.AcmednsNS, acmedns.AcmednsDB, *observer.ObservedLogs) {
|
||||
config, logger, logObserver := fakeConfigAndLogger()
|
||||
config.General.Domain = "auth.example.org"
|
||||
config.General.Listen = "127.0.0.1:15353"
|
||||
config.General.Proto = "udp"
|
||||
config.General.Nsname = "ns1.auth.example.org"
|
||||
config.General.Nsadmin = "admin.example.org"
|
||||
config.General.StaticRecords = records
|
||||
config.General.Debug = false
|
||||
db, _ := database.Init(&config, logger)
|
||||
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
|
||||
server.Domains = make(map[string]Records)
|
||||
server.Server = &dns.Server{Addr: config.General.Listen, Net: config.General.Proto}
|
||||
server.ParseRecords()
|
||||
server.OwnDomain = "auth.example.org."
|
||||
return &server, db, logObserver
|
||||
}
|
||||
|
||||
func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
|
||||
msg := new(dns.Msg)
|
||||
msg.Id = dns.Id()
|
||||
@ -27,7 +73,6 @@ func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
|
||||
if in != nil && in.Rcode != dns.RcodeSuccess {
|
||||
return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode])
|
||||
}
|
||||
|
||||
return in, nil
|
||||
}
|
||||
|
||||
@ -49,6 +94,18 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error {
|
||||
}
|
||||
|
||||
func TestQuestionDBError(t *testing.T) {
|
||||
config, logger, _ := fakeConfigAndLogger()
|
||||
config.General.Listen = "127.0.0.1:15353"
|
||||
config.General.Proto = "udp"
|
||||
config.General.Domain = "auth.example.org"
|
||||
config.General.Nsname = "ns1.auth.example.org"
|
||||
config.General.Nsadmin = "admin.example.org"
|
||||
config.General.StaticRecords = records
|
||||
config.General.Debug = false
|
||||
db, _ := database.Init(&config, logger)
|
||||
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
|
||||
server.Domains = make(map[string]Records)
|
||||
server.ParseRecords()
|
||||
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||
return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error")
|
||||
@ -60,48 +117,61 @@ func TestQuestionDBError(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("Got error: %v", err)
|
||||
}
|
||||
oldDb := DB.GetBackend()
|
||||
oldDb := db.GetBackend()
|
||||
|
||||
DB.SetBackend(tdb)
|
||||
defer DB.SetBackend(oldDb)
|
||||
db.SetBackend(tdb)
|
||||
defer db.SetBackend(oldDb)
|
||||
|
||||
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
|
||||
_, err = dnsserver.answerTXT(q)
|
||||
_, err = server.answerTXT(q)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
var testcfg = DNSConfig{
|
||||
General: general{
|
||||
Domain: ")",
|
||||
Nsname: "ns1.auth.example.org",
|
||||
Nsadmin: "admin.example.org",
|
||||
StaticRecords: []string{},
|
||||
Debug: false,
|
||||
},
|
||||
}
|
||||
dnsserver.ParseRecords(testcfg)
|
||||
if !loggerHasEntryWithMessage("Error while adding SOA record") {
|
||||
config, logger, logObserver := fakeConfigAndLogger()
|
||||
config.General.Listen = "127.0.0.1:15353"
|
||||
config.General.Proto = "udp"
|
||||
config.General.Domain = ")"
|
||||
config.General.Nsname = "ns1.auth.example.org"
|
||||
config.General.Nsadmin = "admin.example.org"
|
||||
config.General.StaticRecords = records
|
||||
config.General.Debug = false
|
||||
config.General.StaticRecords = []string{}
|
||||
db, _ := database.Init(&config, logger)
|
||||
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
|
||||
server.Domains = make(map[string]Records)
|
||||
server.ParseRecords()
|
||||
if !loggerHasEntryWithMessage("Error while adding SOA record", logObserver) {
|
||||
t.Errorf("Expected SOA parsing to return error, but did not find one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveA(t *testing.T) {
|
||||
server, _, _ := setupDNS()
|
||||
errChan := make(chan error, 1)
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
server.SetNotifyStartedFunc(waitLock.Unlock)
|
||||
go server.Start(errChan)
|
||||
waitLock.Lock()
|
||||
resolv := resolver{server: "127.0.0.1:15353"}
|
||||
answer, err := resolv.lookup("auth.example.org", dns.TypeA)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(answer.Answer) == 0 {
|
||||
t.Error("No answer for DNS query")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = resolv.lookup("nonexistent.domain.tld", dns.TypeA)
|
||||
if err == nil {
|
||||
t.Errorf("Was expecting error because of NXDOMAIN but got none")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -195,17 +265,20 @@ func TestAuthoritative(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestResolveTXT(t *testing.T) {
|
||||
_, db, _ := setupDNS()
|
||||
resolv := resolver{server: "127.0.0.1:15353"}
|
||||
validTXT := "______________valid_response_______________"
|
||||
|
||||
atxt, err := DB.Register(cidrslice{})
|
||||
atxt, err := db.Register(acmedns.Cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not initiate db record: [%v]", err)
|
||||
return
|
||||
}
|
||||
atxt.Value = validTXT
|
||||
err = DB.Update(atxt.ACMETxtPost)
|
||||
|
||||
err = db.Update(atxt.ACMETxtPost)
|
||||
if err != nil {
|
||||
t.Errorf("Could not update db record: [%v]", err)
|
||||
return
|
||||
@ -254,7 +327,7 @@ func TestResolveTXT(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
func TestCaseInsensitiveResolveA(t *testing.T) {
|
||||
resolv := resolver{server: "127.0.0.1:15353"}
|
||||
@ -9,7 +9,6 @@ import (
|
||||
func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
|
||||
// handle edns0
|
||||
opt := r.IsEdns0()
|
||||
if opt != nil {
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"go.uber.org/zap"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Records is a slice of ResourceRecords
|
||||
@ -15,21 +14,23 @@ type Records struct {
|
||||
}
|
||||
|
||||
type Nameserver struct {
|
||||
Config *acmedns.AcmeDnsConfig
|
||||
DB acmedns.AcmednsDB
|
||||
Logger *zap.SugaredLogger
|
||||
Server *dns.Server
|
||||
OwnDomain string
|
||||
SOA dns.RR
|
||||
personalAuthKey string
|
||||
Domains map[string]Records
|
||||
errChan chan error
|
||||
Config *acmedns.AcmeDnsConfig
|
||||
DB acmedns.AcmednsDB
|
||||
Logger *zap.SugaredLogger
|
||||
Server *dns.Server
|
||||
OwnDomain string
|
||||
NotifyStartedFunc func()
|
||||
SOA dns.RR
|
||||
personalAuthKey string
|
||||
Domains map[string]Records
|
||||
errChan chan error
|
||||
}
|
||||
|
||||
func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) []acmedns.AcmednsNS {
|
||||
dnsservers := make([]acmedns.AcmednsNS, 0)
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
if strings.HasPrefix(config.General.Proto, "both") {
|
||||
|
||||
// Handle the case where DNS server should be started for both udp and tcp
|
||||
udpProto := "udp"
|
||||
tcpProto := "tcp"
|
||||
@ -46,13 +47,22 @@ func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
|
||||
dnsServerTCP := NewDNSServer(config, db, logger, tcpProto)
|
||||
dnsservers = append(dnsservers, dnsServerTCP)
|
||||
dnsServerTCP.ParseRecords()
|
||||
// wait for the server to get started to proceed
|
||||
waitLock.Lock()
|
||||
dnsServerUDP.SetNotifyStartedFunc(waitLock.Unlock)
|
||||
go dnsServerUDP.Start(errChan)
|
||||
waitLock.Lock()
|
||||
dnsServerTCP.SetNotifyStartedFunc(waitLock.Unlock)
|
||||
go dnsServerTCP.Start(errChan)
|
||||
waitLock.Lock()
|
||||
} else {
|
||||
dnsServer := NewDNSServer(config, db, logger, config.General.Proto)
|
||||
dnsservers = append(dnsservers, dnsServer)
|
||||
dnsServer.ParseRecords()
|
||||
waitLock.Lock()
|
||||
dnsServer.SetNotifyStartedFunc(waitLock.Unlock)
|
||||
go dnsServer.Start(errChan)
|
||||
waitLock.Lock()
|
||||
}
|
||||
return dnsservers
|
||||
}
|
||||
@ -67,7 +77,6 @@ func NewDNSServer(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
|
||||
domain = domain + "."
|
||||
}
|
||||
server.OwnDomain = strings.ToLower(domain)
|
||||
server.DB = db
|
||||
server.personalAuthKey = ""
|
||||
server.Domains = make(map[string]Records)
|
||||
return &server
|
||||
@ -79,8 +88,15 @@ func (n *Nameserver) Start(errorChannel chan error) {
|
||||
n.Logger.Infow("Starting DNS listener",
|
||||
"addr", n.Server.Addr,
|
||||
"proto", n.Server.Net)
|
||||
if n.NotifyStartedFunc != nil {
|
||||
n.Server.NotifyStartedFunc = n.NotifyStartedFunc
|
||||
}
|
||||
err := n.Server.ListenAndServe()
|
||||
if err != nil {
|
||||
errorChannel <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Nameserver) SetNotifyStartedFunc(fun func()) {
|
||||
n.Server.NotifyStartedFunc = fun
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user