Re-added tests

This commit is contained in:
Joona Hoikkala 2022-12-25 12:52:49 +02:00
parent 1405e6ab47
commit 157241994f
No known key found for this signature in database
GPG Key ID: 1708DAE66E87A524
15 changed files with 425 additions and 283 deletions

View File

@ -1,34 +0,0 @@
package main
import (
"net/http"
"testing"
)
func TestUpdateAllowedFromIP(t *testing.T) {
Config.API.UseHeader = false
userWithAllow := newACMETxt()
userWithAllow.AllowFrom = cidrslice{"192.168.1.2/32", "[::1]/128"}
userWithoutAllow := newACMETxt()
for i, test := range []struct {
remoteaddr string
expected bool
}{
{"192.168.1.2:1234", true},
{"192.168.1.1:1234", false},
{"invalid", false},
{"[::1]:4567", true},
} {
newreq, _ := http.NewRequest("GET", "/whatever", nil)
newreq.RemoteAddr = test.remoteaddr
ret := updateAllowedFromIP(newreq, userWithAllow)
if test.expected != ret {
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
}
if !updateAllowedFromIP(newreq, userWithoutAllow) {
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
}
}
}

60
main.go
View File

@ -1,7 +1,6 @@
package main
import (
"encoding/json"
"flag"
"fmt"
"github.com/acme-dns/acme-dns/pkg/api"
@ -15,60 +14,6 @@ import (
"go.uber.org/zap"
)
func setupLogging(config acmedns.AcmeDnsConfig) (*zap.Logger, error) {
var logger *zap.Logger
logformat := "console"
if config.Logconfig.Format == "json" {
logformat = "json"
}
outputPath := "stdout"
if config.Logconfig.Logtype == "file" {
outputPath = config.Logconfig.File
}
errorPath := "stderr"
if config.Logconfig.Logtype == "file" {
errorPath = config.Logconfig.File
}
zapConfigJson := fmt.Sprintf(`{
"level": "%s",
"encoding": "%s",
"outputPaths": ["%s"],
"errorOutputPaths": ["%s"],
"encoderConfig": {
"timeKey": "time",
"messageKey": "msg",
"levelKey": "level",
"levelEncoder": "lowercase",
"timeEncoder": "iso8601"
}
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
var zapCfg zap.Config
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
return logger, err
}
logger, err := zapCfg.Build()
return logger, err
}
func readConfig(configFile string) (acmedns.AcmeDnsConfig, string, error) {
var usedConfigFile string
var config acmedns.AcmeDnsConfig
var err error
if acmedns.FileIsAccessible(configFile) {
usedConfigFile = configFile
config, err = acmedns.ReadConfig(configFile)
} else if acmedns.FileIsAccessible("./config.cfg") {
usedConfigFile = "./config.cfg"
config, err = acmedns.ReadConfig("./config.cfg")
} else {
err = fmt.Errorf("configuration file not found")
}
if err != nil {
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
}
return config, usedConfigFile, err
}
func main() {
syscall.Umask(0077)
configPtr := flag.String("c", "/etc/acme-dns/config.cfg", "config file location")
@ -76,18 +21,19 @@ func main() {
// Read global config
var err error
var logger *zap.Logger
config, usedConfigFile, err := readConfig(*configPtr)
config, usedConfigFile, err := acmedns.ReadConfig(*configPtr)
if err != nil {
fmt.Printf("Error: %s\n", err)
os.Exit(1)
}
logger, err = setupLogging(config)
logger, err = acmedns.SetupLogging(config)
if err != nil {
fmt.Printf("Could not set up logging: %s\n", err)
os.Exit(1)
}
defer logger.Sync()
sugar := logger.Sugar()
sugar.Infow("Using config file",
"file", usedConfigFile)
sugar.Info("Starting up")

View File

@ -2,6 +2,7 @@ package acmedns
import (
"errors"
"fmt"
"os"
"github.com/BurntSushi/toml"
@ -20,7 +21,7 @@ func FileIsAccessible(fname string) bool {
return true
}
func ReadConfig(fname string) (AcmeDnsConfig, error) {
func readTomlConfig(fname string) (AcmeDnsConfig, error) {
var conf AcmeDnsConfig
_, err := toml.DecodeFile(fname, &conf)
if err != nil {
@ -46,3 +47,22 @@ func prepareConfig(conf AcmeDnsConfig) (AcmeDnsConfig, error) {
return conf, nil
}
func ReadConfig(configFile string) (AcmeDnsConfig, string, error) {
var usedConfigFile string
var config AcmeDnsConfig
var err error
if FileIsAccessible(configFile) {
usedConfigFile = configFile
config, err = readTomlConfig(configFile)
} else if FileIsAccessible("./config.cfg") {
usedConfigFile = "./config.cfg"
config, err = readTomlConfig("./config.cfg")
} else {
err = fmt.Errorf("configuration file not found")
}
if err != nil {
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
}
return config, usedConfigFile, err
}

View File

@ -18,5 +18,6 @@ type AcmednsDB interface {
type AcmednsNS interface {
Start(errorChannel chan error)
SetOwnAuthKey(key string)
SetNotifyStartedFunc(func())
ParseRecords()
}

43
pkg/acmedns/logging.go Normal file
View File

@ -0,0 +1,43 @@
package acmedns
import (
"encoding/json"
"fmt"
"go.uber.org/zap"
)
func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
var logger *zap.Logger
logformat := "console"
if config.Logconfig.Format == "json" {
logformat = "json"
}
outputPath := "stdout"
if config.Logconfig.Logtype == "file" {
outputPath = config.Logconfig.File
}
errorPath := "stderr"
if config.Logconfig.Logtype == "file" {
errorPath = config.Logconfig.File
}
zapConfigJson := fmt.Sprintf(`{
"level": "%s",
"encoding": "%s",
"outputPaths": ["%s"],
"errorOutputPaths": ["%s"],
"encoderConfig": {
"timeKey": "time",
"messageKey": "msg",
"levelKey": "level",
"levelEncoder": "lowercase",
"timeEncoder": "iso8601"
}
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
var zapCfg zap.Config
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
return logger, err
}
logger, err := zapCfg.Build()
return logger, err
}

View File

@ -2,6 +2,7 @@ package acmedns
import (
"crypto/rand"
"golang.org/x/crypto/bcrypt"
"math/big"
"regexp"
)
@ -29,3 +30,10 @@ func generatePassword(length int) string {
}
return string(ret)
}
func CorrectPassword(pw string, hash string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
return true
}
return false
}

View File

@ -1,6 +1,8 @@
package main
package acmedns
import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"io/ioutil"
"os"
"syscall"
@ -9,21 +11,55 @@ import (
log "github.com/sirupsen/logrus"
)
func fakeConfig() AcmeDnsConfig {
conf := AcmeDnsConfig{}
conf.Logconfig.Logtype = "stdout"
return conf
}
func TestSetupLogging(t *testing.T) {
conf := fakeConfig()
for i, test := range []struct {
format string
level string
expected string
expected zapcore.Level
}{
{"text", "warning", "warning"},
{"json", "debug", "debug"},
{"text", "info", "info"},
{"json", "error", "error"},
{"text", "something", "warning"},
{"text", "warn", zap.WarnLevel},
{"json", "debug", zap.DebugLevel},
{"text", "info", zap.InfoLevel},
{"json", "error", zap.ErrorLevel},
} {
setupLogging(test.format, test.level)
if log.GetLevel().String() != test.expected {
t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String())
conf.Logconfig.Format = test.format
conf.Logconfig.Level = test.level
logger, err := SetupLogging(conf)
if err != nil {
t.Errorf("Got unexpected error: %s", err)
} else {
if logger.Sugar().Level() != test.expected {
t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String())
}
}
}
}
func TestSetupLoggingError(t *testing.T) {
conf := fakeConfig()
for _, test := range []struct {
format string
level string
errexpected bool
}{
{"text", "warn", false},
{"json", "debug", false},
{"text", "info", false},
{"json", "error", false},
{"text", "something", true},
} {
conf.Logconfig.Format = test.format
conf.Logconfig.Level = test.level
_, err := SetupLogging(conf)
if test.errexpected && err == nil {
t.Errorf("Expected error but did not get one for loglevel: %s", err)
}
}
}
@ -31,11 +67,11 @@ func TestSetupLogging(t *testing.T) {
func TestReadConfig(t *testing.T) {
for i, test := range []struct {
inFile []byte
output DNSConfig
output AcmeDnsConfig
}{
{
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
DNSConfig{
AcmeDnsConfig{
General: general{
Listen: ":53",
Debug: true,
@ -48,10 +84,10 @@ func TestReadConfig(t *testing.T) {
{
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
DNSConfig{},
AcmeDnsConfig{},
},
} {
tmpfile, err := ioutil.TempFile("", "acmedns")
tmpfile, err := os.CreateTemp("", "acmedns")
if err != nil {
t.Error("Could not create temporary file")
}
@ -64,7 +100,7 @@ func TestReadConfig(t *testing.T) {
if err := tmpfile.Close(); err != nil {
t.Error("Could not close temporary file")
}
ret, _ := readConfig(tmpfile.Name())
ret, _, _ := ReadConfig(tmpfile.Name())
if ret.General.Listen != test.output.General.Listen {
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
}
@ -74,30 +110,6 @@ func TestReadConfig(t *testing.T) {
}
}
func TestGetIPListFromHeader(t *testing.T) {
for i, test := range []struct {
input string
output []string
}{
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
} {
res := getIPListFromHeader(test.input)
if len(res) != len(test.output) {
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
} else {
for j, vv := range test.output {
if res[j] != vv {
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
}
}
}
}
}
func TestFileCheckPermissionDenied(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "acmedns")
if err != nil {
@ -105,14 +117,14 @@ func TestFileCheckPermissionDenied(t *testing.T) {
}
defer os.Remove(tmpfile.Name())
_ = syscall.Chmod(tmpfile.Name(), 0000)
if fileIsAccessible(tmpfile.Name()) {
if FileIsAccessible(tmpfile.Name()) {
t.Errorf("File should not be accessible")
}
_ = syscall.Chmod(tmpfile.Name(), 0644)
}
func TestFileCheckNotExists(t *testing.T) {
if fileIsAccessible("/path/that/does/not/exist") {
if FileIsAccessible("/path/that/does/not/exist") {
t.Errorf("File should not be accessible")
}
}
@ -123,19 +135,19 @@ func TestFileCheckOK(t *testing.T) {
t.Error("Could not create temporary file")
}
defer os.Remove(tmpfile.Name())
if !fileIsAccessible(tmpfile.Name()) {
if !FileIsAccessible(tmpfile.Name()) {
t.Errorf("File should be accessible")
}
}
func TestPrepareConfig(t *testing.T) {
for i, test := range []struct {
input DNSConfig
input AcmeDnsConfig
shoulderror bool
}{
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
{DNSConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
{AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
} {
_, err := prepareConfig(test.input)
if test.shoulderror {

View File

@ -1,4 +1,4 @@
package main
package api
import (
"context"
@ -8,17 +8,29 @@ import (
"net/http/httptest"
"testing"
"github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/acme-dns/acme-dns/pkg/database"
"github.com/DATA-DOG/go-sqlmock"
"github.com/gavv/httpexpect"
"github.com/google/uuid"
"github.com/julienschmidt/httprouter"
"github.com/rs/cors"
"go.uber.org/zap"
)
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
c := acmedns.AcmeDnsConfig{}
c.Database.Engine = "sqlite"
c.Database.Connection = ":memory:"
l := zap.NewNop().Sugar()
return c, l
}
// noAuth function to write ACMETxt model to context while not preforming any validation
func noAuth(update httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
postData := ACMETxt{}
postData := acmedns.ACMETxt{}
uname := r.Header.Get("X-Api-User")
passwd := r.Header.Get("X-Api-Key")
@ -46,42 +58,37 @@ func getExpect(t *testing.T, server *httptest.Server) *httpexpect.Expect {
})
}
func setupRouter(debug bool, noauth bool) http.Handler {
func setupRouter(debug bool, noauth bool) (http.Handler, AcmednsAPI, acmedns.AcmednsDB) {
api := httprouter.New()
var dbcfg = dbsettings{
Engine: "sqlite3",
Connection: ":memory:"}
var httpapicfg = httpapi{
Domain: "",
Port: "8080",
TLS: "none",
CorsOrigins: []string{"*"},
UseHeader: true,
HeaderName: "X-Forwarded-For",
}
var dnscfg = DNSConfig{
API: httpapicfg,
Database: dbcfg,
}
Config = dnscfg
config, logger := fakeConfigAndLogger()
config.API.Domain = ""
config.API.Port = "8080"
config.API.TLS = "none"
config.API.CorsOrigins = []string{"*"}
config.API.UseHeader = true
config.API.HeaderName = "X-Forwarded-For"
db, _ := database.Init(&config, logger)
errChan := make(chan error, 1)
adnsapi := Init(&config, db, logger, errChan)
c := cors.New(cors.Options{
AllowedOrigins: Config.API.CorsOrigins,
AllowedOrigins: config.API.CorsOrigins,
AllowedMethods: []string{"GET", "POST"},
OptionsPassthrough: false,
Debug: Config.General.Debug,
Debug: config.General.Debug,
})
api.POST("/register", webRegisterPost)
api.GET("/health", healthCheck)
api.POST("/register", adnsapi.webRegisterPost)
api.GET("/health", adnsapi.healthCheck)
if noauth {
api.POST("/update", noAuth(webUpdatePost))
api.POST("/update", noAuth(adnsapi.webUpdatePost))
} else {
api.POST("/update", Auth(webUpdatePost))
api.POST("/update", adnsapi.Auth(adnsapi.webUpdatePost))
}
return c.Handler(api)
return c.Handler(api), adnsapi, db
}
func TestApiRegister(t *testing.T) {
router := setupRouter(false, false)
router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
@ -117,7 +124,7 @@ func TestApiRegister(t *testing.T) {
}
func TestApiRegisterBadAllowFrom(t *testing.T) {
router := setupRouter(false, false)
router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
@ -147,7 +154,7 @@ func TestApiRegisterBadAllowFrom(t *testing.T) {
}
func TestApiRegisterMalformedJSON(t *testing.T) {
router := setupRouter(false, false)
router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
@ -174,13 +181,13 @@ func TestApiRegisterMalformedJSON(t *testing.T) {
}
func TestApiRegisterWithMockDB(t *testing.T) {
router := setupRouter(false, false)
router, _, db := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
oldDb := DB.GetBackend()
db, mock, _ := sqlmock.New()
DB.SetBackend(db)
oldDb := db.GetBackend()
mdb, mock, _ := sqlmock.New()
db.SetBackend(mdb)
defer db.Close()
mock.ExpectBegin()
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
@ -188,7 +195,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
Status(http.StatusInternalServerError).
JSON().Object().
ContainsKey("error")
DB.SetBackend(oldDb)
db.SetBackend(oldDb)
}
func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
@ -198,11 +205,11 @@ func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
"subdomain": "",
"txt": ""}
router := setupRouter(false, false)
router, _, db := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
newUser, err := DB.Register(cidrslice{})
newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Could not create new user, got error [%v]", err)
}
@ -228,11 +235,11 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
"subdomain": "",
"txt": ""}
router := setupRouter(false, false)
router, _, db := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
newUser, err := DB.Register(cidrslice{})
newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Could not create new user, got error [%v]", err)
}
@ -252,7 +259,7 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
}
func TestApiUpdateWithoutCredentials(t *testing.T) {
router := setupRouter(false, false)
router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
@ -270,11 +277,11 @@ func TestApiUpdateWithCredentials(t *testing.T) {
"subdomain": "",
"txt": ""}
router := setupRouter(false, false)
router, _, db := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
newUser, err := DB.Register(cidrslice{})
newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Could not create new user, got error [%v]", err)
}
@ -303,13 +310,13 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8"
updateJSON["txt"] = validTxtData
router := setupRouter(false, true)
router, _, db := setupRouter(false, true)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
oldDb := DB.GetBackend()
db, mock, _ := sqlmock.New()
DB.SetBackend(db)
oldDb := db.GetBackend()
mdb, mock, _ := sqlmock.New()
db.SetBackend(mdb)
defer db.Close()
mock.ExpectBegin()
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
@ -319,31 +326,31 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
Status(http.StatusInternalServerError).
JSON().Object().
ContainsKey("error")
DB.SetBackend(oldDb)
db.SetBackend(oldDb)
}
func TestApiManyUpdateWithCredentials(t *testing.T) {
validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
router := setupRouter(true, false)
router, _, db := setupRouter(true, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
// User without defined CIDR masks
newUser, err := DB.Register(cidrslice{})
newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Could not create new user, got error [%v]", err)
}
// User with defined allow from - CIDR masks, all invalid
// (httpexpect doesn't provide a way to mock remote ip)
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.1/32", "invalid"})
newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.1/32", "invalid"})
if err != nil {
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
}
// Another user with valid CIDR mask to match the httpexpect default
newUserWithValidCIDR, err := DB.Register(cidrslice{"10.1.2.3/32", "invalid"})
newUserWithValidCIDR, err := db.Register(acmedns.Cidrslice{"10.1.2.3/32", "invalid"})
if err != nil {
t.Errorf("Could not create new user with a valid CIDR, got error [%v]", err)
}
@ -381,30 +388,30 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
router := setupRouter(false, false)
router, adnsapi, db := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
// Use header checks from default header (X-Forwarded-For)
Config.API.UseHeader = true
adnsapi.Config.API.UseHeader = true
// User without defined CIDR masks
newUser, err := DB.Register(cidrslice{})
newUser, err := db.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Could not create new user, got error [%v]", err)
}
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.2/32", "invalid"})
newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.2/32", "invalid"})
if err != nil {
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
}
newUserWithIP6CIDR, err := DB.Register(cidrslice{"2002:c0a8::0/32"})
newUserWithIP6CIDR, err := db.Register(acmedns.Cidrslice{"2002:c0a8::0/32"})
if err != nil {
t.Errorf("Could not create a new user with IP6 CIDR, got error [%v]", err)
}
for _, test := range []struct {
user ACMETxt
user acmedns.ACMETxt
headerValue string
status int
}{
@ -428,13 +435,66 @@ func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
Expect().
Status(test.status)
}
Config.API.UseHeader = false
adnsapi.Config.API.UseHeader = false
}
func TestApiHealthCheck(t *testing.T) {
router := setupRouter(false, false)
router, _, _ := setupRouter(false, false)
server := httptest.NewServer(router)
defer server.Close()
e := getExpect(t, server)
e.GET("/health").Expect().Status(http.StatusOK)
}
func TestGetIPListFromHeader(t *testing.T) {
for i, test := range []struct {
input string
output []string
}{
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
} {
res := getIPListFromHeader(test.input)
if len(res) != len(test.output) {
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
} else {
for j, vv := range test.output {
if res[j] != vv {
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
}
}
}
}
}
func TestUpdateAllowedFromIP(t *testing.T) {
_, adnsapi, _ := setupRouter(false, false)
adnsapi.Config.API.UseHeader = false
userWithAllow := acmedns.NewACMETxt()
userWithAllow.AllowFrom = acmedns.Cidrslice{"192.168.1.2/32", "[::1]/128"}
userWithoutAllow := acmedns.NewACMETxt()
for i, test := range []struct {
remoteaddr string
expected bool
}{
{"192.168.1.2:1234", true},
{"192.168.1.1:1234", false},
{"invalid", false},
{"[::1]:4567", true},
} {
newreq, _ := http.NewRequest("GET", "/whatever", nil)
newreq.RemoteAddr = test.remoteaddr
ret := adnsapi.updateAllowedFromIP(newreq, userWithAllow)
if test.expected != ret {
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
}
if !adnsapi.updateAllowedFromIP(newreq, userWithoutAllow) {
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
}
}
}

View File

@ -74,11 +74,11 @@ func (a *AcmednsAPI) getUserFromRequest(r *http.Request) (acmedns.ACMETxt, error
a.Logger.Errorw("Error while trying to get user",
"error", err.Error())
// To protect against timed side channel (never gonna give you up)
correctPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
acmedns.CorrectPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
return acmedns.ACMETxt{}, fmt.Errorf("invalid username: %s", uname)
}
if correctPassword(passwd, dbuser.Password) {
if acmedns.CorrectPassword(passwd, dbuser.Password) {
return dbuser, nil
}
return acmedns.ACMETxt{}, fmt.Errorf("invalid password for user %s", uname)

View File

@ -4,7 +4,6 @@ import (
"fmt"
"github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"regexp"
"strings"
"unicode/utf8"
@ -31,13 +30,6 @@ func validKey(k string) bool {
return false
}
func correctPassword(pw string, hash string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
return true
}
return false
}
func getIPListFromHeader(header string) []string {
iplist := []string{}
for _, v := range strings.Split(header, ",") {

View File

@ -1,6 +1,7 @@
package main
package api
import (
"github.com/acme-dns/acme-dns/pkg/acmedns"
"testing"
"github.com/google/uuid"
@ -103,7 +104,7 @@ func TestCorrectPassword(t *testing.T) {
false},
{"", "", false},
} {
ret := correctPassword(test.pw, test.hash)
ret := acmedns.CorrectPassword(test.pw, test.hash)
if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
}
@ -112,12 +113,12 @@ func TestCorrectPassword(t *testing.T) {
func TestGetValidCIDRMasks(t *testing.T) {
for i, test := range []struct {
input cidrslice
output cidrslice
input acmedns.Cidrslice
output acmedns.Cidrslice
}{
{cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}},
{cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}},
{cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}},
{acmedns.Cidrslice{"10.0.0.1/24"}, acmedns.Cidrslice{"10.0.0.1/24"}},
{acmedns.Cidrslice{"invalid", "127.0.0.1/32"}, acmedns.Cidrslice{"127.0.0.1/32"}},
{acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}},
} {
ret := test.input.ValidEntries()
if len(ret) == len(test.output) {

View File

@ -1,10 +1,12 @@
package main
package database
import (
"database/sql"
"database/sql/driver"
"errors"
"github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/erikstmartin/go-testdb"
"go.uber.org/zap"
"testing"
)
@ -21,42 +23,38 @@ func (r testResult) RowsAffected() (int64, error) {
return r.affectedRows, nil
}
func TestDBInit(t *testing.T) {
fakeDB := new(acmedb)
err := fakeDB.Init("notarealegine", "connectionstring")
if err == nil {
t.Errorf("Was expecting error, didn't get one.")
}
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
c := acmedns.AcmeDnsConfig{}
c.Database.Engine = "sqlite"
c.Database.Connection = ":memory:"
l := zap.NewNop().Sugar()
return c, l
}
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
return testResult{1, 0}, errors.New("Prepared query error")
})
defer testdb.Reset()
errorDB := new(acmedb)
err = errorDB.Init("testdb", "")
if err == nil {
t.Errorf("Was expecting DB initiation error but got none")
}
errorDB.Close()
func fakeDB() acmedns.AcmednsDB {
conf, logger := fakeConfigAndLogger()
db, _ := Init(&conf, logger)
return db
}
func TestRegisterNoCIDR(t *testing.T) {
// Register tests
_, err := DB.Register(cidrslice{})
DB := fakeDB()
_, err := DB.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}
}
func TestRegisterMany(t *testing.T) {
DB := fakeDB()
for i, test := range []struct {
input cidrslice
output cidrslice
input acmedns.Cidrslice
output acmedns.Cidrslice
}{
{cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
{cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}},
{cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
{acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
{acmedns.Cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, acmedns.Cidrslice{}},
{acmedns.Cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, acmedns.Cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
} {
user, err := DB.Register(test.input)
if err != nil {
@ -77,8 +75,9 @@ func TestRegisterMany(t *testing.T) {
}
func TestGetByUsername(t *testing.T) {
DB := fakeDB()
// Create reg to refer to
reg, err := DB.Register(cidrslice{})
reg, err := DB.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}
@ -97,13 +96,14 @@ func TestGetByUsername(t *testing.T) {
}
// regUser password already is a bcrypt hash
if !correctPassword(reg.Password, regUser.Password) {
if !acmedns.CorrectPassword(reg.Password, regUser.Password) {
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
}
}
func TestPrepareErrors(t *testing.T) {
reg, _ := DB.Register(cidrslice{})
DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
tdb, err := sql.Open("testdb", "")
if err != nil {
t.Errorf("Got error: %v", err)
@ -125,7 +125,8 @@ func TestPrepareErrors(t *testing.T) {
}
func TestQueryExecErrors(t *testing.T) {
reg, _ := DB.Register(cidrslice{})
DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
return testResult{1, 0}, errors.New("Prepared query error")
})
@ -156,7 +157,7 @@ func TestQueryExecErrors(t *testing.T) {
t.Errorf("Expected error from exec in GetByDomain, but got none")
}
_, err = DB.Register(cidrslice{})
_, err = DB.Register(acmedns.Cidrslice{})
if err == nil {
t.Errorf("Expected error from exec in Register, but got none")
}
@ -169,7 +170,8 @@ func TestQueryExecErrors(t *testing.T) {
}
func TestQueryScanErrors(t *testing.T) {
reg, _ := DB.Register(cidrslice{})
DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
return testResult{1, 0}, errors.New("Prepared query error")
@ -198,7 +200,8 @@ func TestQueryScanErrors(t *testing.T) {
}
func TestBadDBValues(t *testing.T) {
reg, _ := DB.Register(cidrslice{})
DB := fakeDB()
reg, _ := DB.Register(acmedns.Cidrslice{})
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
@ -228,8 +231,9 @@ func TestBadDBValues(t *testing.T) {
}
func TestGetTXTForDomain(t *testing.T) {
DB := fakeDB()
// Create reg to refer to
reg, err := DB.Register(cidrslice{})
reg, err := DB.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}
@ -276,8 +280,9 @@ func TestGetTXTForDomain(t *testing.T) {
}
func TestUpdate(t *testing.T) {
DB := fakeDB()
// Create reg to refer to
reg, err := DB.Register(cidrslice{})
reg, err := DB.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}

View File

@ -1,20 +1,66 @@
package main
package nameserver
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"testing"
"github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/acme-dns/acme-dns/pkg/database"
"github.com/erikstmartin/go-testdb"
"github.com/miekg/dns"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
"sync"
"testing"
)
type resolver struct {
server string
}
var records = []string{
"auth.example.org. A 192.168.1.100",
"ns1.auth.example.org. A 192.168.1.101",
"cn.example.org CNAME something.example.org.",
"!''b', unparseable ",
"ns2.auth.example.org. A 192.168.1.102",
}
func loggerHasEntryWithMessage(message string, logObserver *observer.ObservedLogs) bool {
if len(logObserver.FilterMessage(message).All()) > 0 {
return true
}
return false
}
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger, *observer.ObservedLogs) {
c := acmedns.AcmeDnsConfig{}
c.Database.Engine = "sqlite"
c.Database.Connection = ":memory:"
obsCore, logObserver := observer.New(zap.DebugLevel)
obsLogger := zap.New(obsCore).Sugar()
return c, obsLogger, logObserver
}
func setupDNS() (acmedns.AcmednsNS, acmedns.AcmednsDB, *observer.ObservedLogs) {
config, logger, logObserver := fakeConfigAndLogger()
config.General.Domain = "auth.example.org"
config.General.Listen = "127.0.0.1:15353"
config.General.Proto = "udp"
config.General.Nsname = "ns1.auth.example.org"
config.General.Nsadmin = "admin.example.org"
config.General.StaticRecords = records
config.General.Debug = false
db, _ := database.Init(&config, logger)
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
server.Domains = make(map[string]Records)
server.Server = &dns.Server{Addr: config.General.Listen, Net: config.General.Proto}
server.ParseRecords()
server.OwnDomain = "auth.example.org."
return &server, db, logObserver
}
func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
msg := new(dns.Msg)
msg.Id = dns.Id()
@ -27,7 +73,6 @@ func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
if in != nil && in.Rcode != dns.RcodeSuccess {
return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode])
}
return in, nil
}
@ -49,6 +94,18 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error {
}
func TestQuestionDBError(t *testing.T) {
config, logger, _ := fakeConfigAndLogger()
config.General.Listen = "127.0.0.1:15353"
config.General.Proto = "udp"
config.General.Domain = "auth.example.org"
config.General.Nsname = "ns1.auth.example.org"
config.General.Nsadmin = "admin.example.org"
config.General.StaticRecords = records
config.General.Debug = false
db, _ := database.Init(&config, logger)
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
server.Domains = make(map[string]Records)
server.ParseRecords()
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error")
@ -60,48 +117,61 @@ func TestQuestionDBError(t *testing.T) {
if err != nil {
t.Errorf("Got error: %v", err)
}
oldDb := DB.GetBackend()
oldDb := db.GetBackend()
DB.SetBackend(tdb)
defer DB.SetBackend(oldDb)
db.SetBackend(tdb)
defer db.SetBackend(oldDb)
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
_, err = dnsserver.answerTXT(q)
_, err = server.answerTXT(q)
if err == nil {
t.Errorf("Expected error but got none")
}
}
func TestParse(t *testing.T) {
var testcfg = DNSConfig{
General: general{
Domain: ")",
Nsname: "ns1.auth.example.org",
Nsadmin: "admin.example.org",
StaticRecords: []string{},
Debug: false,
},
}
dnsserver.ParseRecords(testcfg)
if !loggerHasEntryWithMessage("Error while adding SOA record") {
config, logger, logObserver := fakeConfigAndLogger()
config.General.Listen = "127.0.0.1:15353"
config.General.Proto = "udp"
config.General.Domain = ")"
config.General.Nsname = "ns1.auth.example.org"
config.General.Nsadmin = "admin.example.org"
config.General.StaticRecords = records
config.General.Debug = false
config.General.StaticRecords = []string{}
db, _ := database.Init(&config, logger)
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
server.Domains = make(map[string]Records)
server.ParseRecords()
if !loggerHasEntryWithMessage("Error while adding SOA record", logObserver) {
t.Errorf("Expected SOA parsing to return error, but did not find one")
}
}
func TestResolveA(t *testing.T) {
server, _, _ := setupDNS()
errChan := make(chan error, 1)
waitLock := sync.Mutex{}
waitLock.Lock()
server.SetNotifyStartedFunc(waitLock.Unlock)
go server.Start(errChan)
waitLock.Lock()
resolv := resolver{server: "127.0.0.1:15353"}
answer, err := resolv.lookup("auth.example.org", dns.TypeA)
if err != nil {
t.Errorf("%v", err)
return
}
if len(answer.Answer) == 0 {
t.Error("No answer for DNS query")
return
}
_, err = resolv.lookup("nonexistent.domain.tld", dns.TypeA)
if err == nil {
t.Errorf("Was expecting error because of NXDOMAIN but got none")
return
}
}
@ -195,17 +265,20 @@ func TestAuthoritative(t *testing.T) {
}
}
/*
func TestResolveTXT(t *testing.T) {
_, db, _ := setupDNS()
resolv := resolver{server: "127.0.0.1:15353"}
validTXT := "______________valid_response_______________"
atxt, err := DB.Register(cidrslice{})
atxt, err := db.Register(acmedns.Cidrslice{})
if err != nil {
t.Errorf("Could not initiate db record: [%v]", err)
return
}
atxt.Value = validTXT
err = DB.Update(atxt.ACMETxtPost)
err = db.Update(atxt.ACMETxtPost)
if err != nil {
t.Errorf("Could not update db record: [%v]", err)
return
@ -254,7 +327,7 @@ func TestResolveTXT(t *testing.T) {
}
}
}
}
}*/
func TestCaseInsensitiveResolveA(t *testing.T) {
resolv := resolver{server: "127.0.0.1:15353"}

View File

@ -9,7 +9,6 @@ import (
func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
// handle edns0
opt := r.IsEdns0()
if opt != nil {

View File

@ -1,12 +1,11 @@
package nameserver
import (
"strings"
"github.com/acme-dns/acme-dns/pkg/acmedns"
"github.com/miekg/dns"
"go.uber.org/zap"
"strings"
"sync"
)
// Records is a slice of ResourceRecords
@ -15,21 +14,23 @@ type Records struct {
}
type Nameserver struct {
Config *acmedns.AcmeDnsConfig
DB acmedns.AcmednsDB
Logger *zap.SugaredLogger
Server *dns.Server
OwnDomain string
SOA dns.RR
personalAuthKey string
Domains map[string]Records
errChan chan error
Config *acmedns.AcmeDnsConfig
DB acmedns.AcmednsDB
Logger *zap.SugaredLogger
Server *dns.Server
OwnDomain string
NotifyStartedFunc func()
SOA dns.RR
personalAuthKey string
Domains map[string]Records
errChan chan error
}
func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) []acmedns.AcmednsNS {
dnsservers := make([]acmedns.AcmednsNS, 0)
waitLock := sync.Mutex{}
if strings.HasPrefix(config.General.Proto, "both") {
// Handle the case where DNS server should be started for both udp and tcp
udpProto := "udp"
tcpProto := "tcp"
@ -46,13 +47,22 @@ func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
dnsServerTCP := NewDNSServer(config, db, logger, tcpProto)
dnsservers = append(dnsservers, dnsServerTCP)
dnsServerTCP.ParseRecords()
// wait for the server to get started to proceed
waitLock.Lock()
dnsServerUDP.SetNotifyStartedFunc(waitLock.Unlock)
go dnsServerUDP.Start(errChan)
waitLock.Lock()
dnsServerTCP.SetNotifyStartedFunc(waitLock.Unlock)
go dnsServerTCP.Start(errChan)
waitLock.Lock()
} else {
dnsServer := NewDNSServer(config, db, logger, config.General.Proto)
dnsservers = append(dnsservers, dnsServer)
dnsServer.ParseRecords()
waitLock.Lock()
dnsServer.SetNotifyStartedFunc(waitLock.Unlock)
go dnsServer.Start(errChan)
waitLock.Lock()
}
return dnsservers
}
@ -67,7 +77,6 @@ func NewDNSServer(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
domain = domain + "."
}
server.OwnDomain = strings.ToLower(domain)
server.DB = db
server.personalAuthKey = ""
server.Domains = make(map[string]Records)
return &server
@ -79,8 +88,15 @@ func (n *Nameserver) Start(errorChannel chan error) {
n.Logger.Infow("Starting DNS listener",
"addr", n.Server.Addr,
"proto", n.Server.Net)
if n.NotifyStartedFunc != nil {
n.Server.NotifyStartedFunc = n.NotifyStartedFunc
}
err := n.Server.ListenAndServe()
if err != nil {
errorChannel <- err
}
}
func (n *Nameserver) SetNotifyStartedFunc(fun func()) {
n.Server.NotifyStartedFunc = fun
}