Re-added tests
This commit is contained in:
parent
1405e6ab47
commit
157241994f
34
auth_test.go
34
auth_test.go
@ -1,34 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUpdateAllowedFromIP(t *testing.T) {
|
|
||||||
Config.API.UseHeader = false
|
|
||||||
userWithAllow := newACMETxt()
|
|
||||||
userWithAllow.AllowFrom = cidrslice{"192.168.1.2/32", "[::1]/128"}
|
|
||||||
userWithoutAllow := newACMETxt()
|
|
||||||
|
|
||||||
for i, test := range []struct {
|
|
||||||
remoteaddr string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{"192.168.1.2:1234", true},
|
|
||||||
{"192.168.1.1:1234", false},
|
|
||||||
{"invalid", false},
|
|
||||||
{"[::1]:4567", true},
|
|
||||||
} {
|
|
||||||
newreq, _ := http.NewRequest("GET", "/whatever", nil)
|
|
||||||
newreq.RemoteAddr = test.remoteaddr
|
|
||||||
ret := updateAllowedFromIP(newreq, userWithAllow)
|
|
||||||
if test.expected != ret {
|
|
||||||
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !updateAllowedFromIP(newreq, userWithoutAllow) {
|
|
||||||
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
60
main.go
60
main.go
@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/acme-dns/acme-dns/pkg/api"
|
"github.com/acme-dns/acme-dns/pkg/api"
|
||||||
@ -15,60 +14,6 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupLogging(config acmedns.AcmeDnsConfig) (*zap.Logger, error) {
|
|
||||||
var logger *zap.Logger
|
|
||||||
logformat := "console"
|
|
||||||
if config.Logconfig.Format == "json" {
|
|
||||||
logformat = "json"
|
|
||||||
}
|
|
||||||
outputPath := "stdout"
|
|
||||||
if config.Logconfig.Logtype == "file" {
|
|
||||||
outputPath = config.Logconfig.File
|
|
||||||
}
|
|
||||||
errorPath := "stderr"
|
|
||||||
if config.Logconfig.Logtype == "file" {
|
|
||||||
errorPath = config.Logconfig.File
|
|
||||||
}
|
|
||||||
zapConfigJson := fmt.Sprintf(`{
|
|
||||||
"level": "%s",
|
|
||||||
"encoding": "%s",
|
|
||||||
"outputPaths": ["%s"],
|
|
||||||
"errorOutputPaths": ["%s"],
|
|
||||||
"encoderConfig": {
|
|
||||||
"timeKey": "time",
|
|
||||||
"messageKey": "msg",
|
|
||||||
"levelKey": "level",
|
|
||||||
"levelEncoder": "lowercase",
|
|
||||||
"timeEncoder": "iso8601"
|
|
||||||
}
|
|
||||||
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
|
|
||||||
var zapCfg zap.Config
|
|
||||||
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
|
|
||||||
return logger, err
|
|
||||||
}
|
|
||||||
logger, err := zapCfg.Build()
|
|
||||||
return logger, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func readConfig(configFile string) (acmedns.AcmeDnsConfig, string, error) {
|
|
||||||
var usedConfigFile string
|
|
||||||
var config acmedns.AcmeDnsConfig
|
|
||||||
var err error
|
|
||||||
if acmedns.FileIsAccessible(configFile) {
|
|
||||||
usedConfigFile = configFile
|
|
||||||
config, err = acmedns.ReadConfig(configFile)
|
|
||||||
} else if acmedns.FileIsAccessible("./config.cfg") {
|
|
||||||
usedConfigFile = "./config.cfg"
|
|
||||||
config, err = acmedns.ReadConfig("./config.cfg")
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("configuration file not found")
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
|
|
||||||
}
|
|
||||||
return config, usedConfigFile, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
syscall.Umask(0077)
|
syscall.Umask(0077)
|
||||||
configPtr := flag.String("c", "/etc/acme-dns/config.cfg", "config file location")
|
configPtr := flag.String("c", "/etc/acme-dns/config.cfg", "config file location")
|
||||||
@ -76,18 +21,19 @@ func main() {
|
|||||||
// Read global config
|
// Read global config
|
||||||
var err error
|
var err error
|
||||||
var logger *zap.Logger
|
var logger *zap.Logger
|
||||||
config, usedConfigFile, err := readConfig(*configPtr)
|
config, usedConfigFile, err := acmedns.ReadConfig(*configPtr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error: %s\n", err)
|
fmt.Printf("Error: %s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
logger, err = setupLogging(config)
|
logger, err = acmedns.SetupLogging(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Could not set up logging: %s\n", err)
|
fmt.Printf("Could not set up logging: %s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
defer logger.Sync()
|
defer logger.Sync()
|
||||||
sugar := logger.Sugar()
|
sugar := logger.Sugar()
|
||||||
|
|
||||||
sugar.Infow("Using config file",
|
sugar.Infow("Using config file",
|
||||||
"file", usedConfigFile)
|
"file", usedConfigFile)
|
||||||
sugar.Info("Starting up")
|
sugar.Info("Starting up")
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package acmedns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
@ -20,7 +21,7 @@ func FileIsAccessible(fname string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadConfig(fname string) (AcmeDnsConfig, error) {
|
func readTomlConfig(fname string) (AcmeDnsConfig, error) {
|
||||||
var conf AcmeDnsConfig
|
var conf AcmeDnsConfig
|
||||||
_, err := toml.DecodeFile(fname, &conf)
|
_, err := toml.DecodeFile(fname, &conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -46,3 +47,22 @@ func prepareConfig(conf AcmeDnsConfig) (AcmeDnsConfig, error) {
|
|||||||
|
|
||||||
return conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ReadConfig(configFile string) (AcmeDnsConfig, string, error) {
|
||||||
|
var usedConfigFile string
|
||||||
|
var config AcmeDnsConfig
|
||||||
|
var err error
|
||||||
|
if FileIsAccessible(configFile) {
|
||||||
|
usedConfigFile = configFile
|
||||||
|
config, err = readTomlConfig(configFile)
|
||||||
|
} else if FileIsAccessible("./config.cfg") {
|
||||||
|
usedConfigFile = "./config.cfg"
|
||||||
|
config, err = readTomlConfig("./config.cfg")
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("configuration file not found")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
|
||||||
|
}
|
||||||
|
return config, usedConfigFile, err
|
||||||
|
}
|
||||||
|
|||||||
@ -18,5 +18,6 @@ type AcmednsDB interface {
|
|||||||
type AcmednsNS interface {
|
type AcmednsNS interface {
|
||||||
Start(errorChannel chan error)
|
Start(errorChannel chan error)
|
||||||
SetOwnAuthKey(key string)
|
SetOwnAuthKey(key string)
|
||||||
|
SetNotifyStartedFunc(func())
|
||||||
ParseRecords()
|
ParseRecords()
|
||||||
}
|
}
|
||||||
|
|||||||
43
pkg/acmedns/logging.go
Normal file
43
pkg/acmedns/logging.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package acmedns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
|
||||||
|
var logger *zap.Logger
|
||||||
|
logformat := "console"
|
||||||
|
if config.Logconfig.Format == "json" {
|
||||||
|
logformat = "json"
|
||||||
|
}
|
||||||
|
outputPath := "stdout"
|
||||||
|
if config.Logconfig.Logtype == "file" {
|
||||||
|
outputPath = config.Logconfig.File
|
||||||
|
}
|
||||||
|
errorPath := "stderr"
|
||||||
|
if config.Logconfig.Logtype == "file" {
|
||||||
|
errorPath = config.Logconfig.File
|
||||||
|
}
|
||||||
|
zapConfigJson := fmt.Sprintf(`{
|
||||||
|
"level": "%s",
|
||||||
|
"encoding": "%s",
|
||||||
|
"outputPaths": ["%s"],
|
||||||
|
"errorOutputPaths": ["%s"],
|
||||||
|
"encoderConfig": {
|
||||||
|
"timeKey": "time",
|
||||||
|
"messageKey": "msg",
|
||||||
|
"levelKey": "level",
|
||||||
|
"levelEncoder": "lowercase",
|
||||||
|
"timeEncoder": "iso8601"
|
||||||
|
}
|
||||||
|
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
|
||||||
|
var zapCfg zap.Config
|
||||||
|
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
|
||||||
|
return logger, err
|
||||||
|
}
|
||||||
|
logger, err := zapCfg.Build()
|
||||||
|
return logger, err
|
||||||
|
}
|
||||||
@ -2,6 +2,7 @@ package acmedns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
@ -29,3 +30,10 @@ func generatePassword(length int) string {
|
|||||||
}
|
}
|
||||||
return string(ret)
|
return string(ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CorrectPassword(pw string, hash string) bool {
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
package main
|
package acmedns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zapcore"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
@ -9,33 +11,67 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func fakeConfig() AcmeDnsConfig {
|
||||||
|
conf := AcmeDnsConfig{}
|
||||||
|
conf.Logconfig.Logtype = "stdout"
|
||||||
|
return conf
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetupLogging(t *testing.T) {
|
func TestSetupLogging(t *testing.T) {
|
||||||
|
conf := fakeConfig()
|
||||||
for i, test := range []struct {
|
for i, test := range []struct {
|
||||||
format string
|
format string
|
||||||
level string
|
level string
|
||||||
expected string
|
expected zapcore.Level
|
||||||
}{
|
}{
|
||||||
{"text", "warning", "warning"},
|
{"text", "warn", zap.WarnLevel},
|
||||||
{"json", "debug", "debug"},
|
{"json", "debug", zap.DebugLevel},
|
||||||
{"text", "info", "info"},
|
{"text", "info", zap.InfoLevel},
|
||||||
{"json", "error", "error"},
|
{"json", "error", zap.ErrorLevel},
|
||||||
{"text", "something", "warning"},
|
|
||||||
} {
|
} {
|
||||||
setupLogging(test.format, test.level)
|
conf.Logconfig.Format = test.format
|
||||||
if log.GetLevel().String() != test.expected {
|
conf.Logconfig.Level = test.level
|
||||||
|
logger, err := SetupLogging(conf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Got unexpected error: %s", err)
|
||||||
|
} else {
|
||||||
|
if logger.Sugar().Level() != test.expected {
|
||||||
t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String())
|
t.Errorf("Test %d: Expected loglevel %s but got %s", i, test.expected, log.GetLevel().String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupLoggingError(t *testing.T) {
|
||||||
|
conf := fakeConfig()
|
||||||
|
for _, test := range []struct {
|
||||||
|
format string
|
||||||
|
level string
|
||||||
|
errexpected bool
|
||||||
|
}{
|
||||||
|
{"text", "warn", false},
|
||||||
|
{"json", "debug", false},
|
||||||
|
{"text", "info", false},
|
||||||
|
{"json", "error", false},
|
||||||
|
{"text", "something", true},
|
||||||
|
} {
|
||||||
|
conf.Logconfig.Format = test.format
|
||||||
|
conf.Logconfig.Level = test.level
|
||||||
|
_, err := SetupLogging(conf)
|
||||||
|
if test.errexpected && err == nil {
|
||||||
|
t.Errorf("Expected error but did not get one for loglevel: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReadConfig(t *testing.T) {
|
func TestReadConfig(t *testing.T) {
|
||||||
for i, test := range []struct {
|
for i, test := range []struct {
|
||||||
inFile []byte
|
inFile []byte
|
||||||
output DNSConfig
|
output AcmeDnsConfig
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
|
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
|
||||||
DNSConfig{
|
AcmeDnsConfig{
|
||||||
General: general{
|
General: general{
|
||||||
Listen: ":53",
|
Listen: ":53",
|
||||||
Debug: true,
|
Debug: true,
|
||||||
@ -48,10 +84,10 @@ func TestReadConfig(t *testing.T) {
|
|||||||
|
|
||||||
{
|
{
|
||||||
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
|
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
|
||||||
DNSConfig{},
|
AcmeDnsConfig{},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
tmpfile, err := ioutil.TempFile("", "acmedns")
|
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Could not create temporary file")
|
t.Error("Could not create temporary file")
|
||||||
}
|
}
|
||||||
@ -64,7 +100,7 @@ func TestReadConfig(t *testing.T) {
|
|||||||
if err := tmpfile.Close(); err != nil {
|
if err := tmpfile.Close(); err != nil {
|
||||||
t.Error("Could not close temporary file")
|
t.Error("Could not close temporary file")
|
||||||
}
|
}
|
||||||
ret, _ := readConfig(tmpfile.Name())
|
ret, _, _ := ReadConfig(tmpfile.Name())
|
||||||
if ret.General.Listen != test.output.General.Listen {
|
if ret.General.Listen != test.output.General.Listen {
|
||||||
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
||||||
}
|
}
|
||||||
@ -74,30 +110,6 @@ func TestReadConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetIPListFromHeader(t *testing.T) {
|
|
||||||
for i, test := range []struct {
|
|
||||||
input string
|
|
||||||
output []string
|
|
||||||
}{
|
|
||||||
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
|
||||||
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
|
||||||
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
|
||||||
} {
|
|
||||||
res := getIPListFromHeader(test.input)
|
|
||||||
if len(res) != len(test.output) {
|
|
||||||
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
|
|
||||||
} else {
|
|
||||||
|
|
||||||
for j, vv := range test.output {
|
|
||||||
if res[j] != vv {
|
|
||||||
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileCheckPermissionDenied(t *testing.T) {
|
func TestFileCheckPermissionDenied(t *testing.T) {
|
||||||
tmpfile, err := ioutil.TempFile("", "acmedns")
|
tmpfile, err := ioutil.TempFile("", "acmedns")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -105,14 +117,14 @@ func TestFileCheckPermissionDenied(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer os.Remove(tmpfile.Name())
|
defer os.Remove(tmpfile.Name())
|
||||||
_ = syscall.Chmod(tmpfile.Name(), 0000)
|
_ = syscall.Chmod(tmpfile.Name(), 0000)
|
||||||
if fileIsAccessible(tmpfile.Name()) {
|
if FileIsAccessible(tmpfile.Name()) {
|
||||||
t.Errorf("File should not be accessible")
|
t.Errorf("File should not be accessible")
|
||||||
}
|
}
|
||||||
_ = syscall.Chmod(tmpfile.Name(), 0644)
|
_ = syscall.Chmod(tmpfile.Name(), 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFileCheckNotExists(t *testing.T) {
|
func TestFileCheckNotExists(t *testing.T) {
|
||||||
if fileIsAccessible("/path/that/does/not/exist") {
|
if FileIsAccessible("/path/that/does/not/exist") {
|
||||||
t.Errorf("File should not be accessible")
|
t.Errorf("File should not be accessible")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -123,19 +135,19 @@ func TestFileCheckOK(t *testing.T) {
|
|||||||
t.Error("Could not create temporary file")
|
t.Error("Could not create temporary file")
|
||||||
}
|
}
|
||||||
defer os.Remove(tmpfile.Name())
|
defer os.Remove(tmpfile.Name())
|
||||||
if !fileIsAccessible(tmpfile.Name()) {
|
if !FileIsAccessible(tmpfile.Name()) {
|
||||||
t.Errorf("File should be accessible")
|
t.Errorf("File should be accessible")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepareConfig(t *testing.T) {
|
func TestPrepareConfig(t *testing.T) {
|
||||||
for i, test := range []struct {
|
for i, test := range []struct {
|
||||||
input DNSConfig
|
input AcmeDnsConfig
|
||||||
shoulderror bool
|
shoulderror bool
|
||||||
}{
|
}{
|
||||||
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
|
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
|
||||||
{DNSConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
|
{AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
|
||||||
{DNSConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
|
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
|
||||||
} {
|
} {
|
||||||
_, err := prepareConfig(test.input)
|
_, err := prepareConfig(test.input)
|
||||||
if test.shoulderror {
|
if test.shoulderror {
|
||||||
@ -1,4 +1,4 @@
|
|||||||
package main
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -8,17 +8,29 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||||
|
"github.com/acme-dns/acme-dns/pkg/database"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/gavv/httpexpect"
|
"github.com/gavv/httpexpect"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
|
||||||
|
c := acmedns.AcmeDnsConfig{}
|
||||||
|
c.Database.Engine = "sqlite"
|
||||||
|
c.Database.Connection = ":memory:"
|
||||||
|
l := zap.NewNop().Sugar()
|
||||||
|
return c, l
|
||||||
|
}
|
||||||
|
|
||||||
// noAuth function to write ACMETxt model to context while not preforming any validation
|
// noAuth function to write ACMETxt model to context while not preforming any validation
|
||||||
func noAuth(update httprouter.Handle) httprouter.Handle {
|
func noAuth(update httprouter.Handle) httprouter.Handle {
|
||||||
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||||
postData := ACMETxt{}
|
postData := acmedns.ACMETxt{}
|
||||||
uname := r.Header.Get("X-Api-User")
|
uname := r.Header.Get("X-Api-User")
|
||||||
passwd := r.Header.Get("X-Api-Key")
|
passwd := r.Header.Get("X-Api-Key")
|
||||||
|
|
||||||
@ -46,42 +58,37 @@ func getExpect(t *testing.T, server *httptest.Server) *httpexpect.Expect {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupRouter(debug bool, noauth bool) http.Handler {
|
func setupRouter(debug bool, noauth bool) (http.Handler, AcmednsAPI, acmedns.AcmednsDB) {
|
||||||
api := httprouter.New()
|
api := httprouter.New()
|
||||||
var dbcfg = dbsettings{
|
config, logger := fakeConfigAndLogger()
|
||||||
Engine: "sqlite3",
|
config.API.Domain = ""
|
||||||
Connection: ":memory:"}
|
config.API.Port = "8080"
|
||||||
var httpapicfg = httpapi{
|
config.API.TLS = "none"
|
||||||
Domain: "",
|
config.API.CorsOrigins = []string{"*"}
|
||||||
Port: "8080",
|
config.API.UseHeader = true
|
||||||
TLS: "none",
|
config.API.HeaderName = "X-Forwarded-For"
|
||||||
CorsOrigins: []string{"*"},
|
|
||||||
UseHeader: true,
|
db, _ := database.Init(&config, logger)
|
||||||
HeaderName: "X-Forwarded-For",
|
errChan := make(chan error, 1)
|
||||||
}
|
adnsapi := Init(&config, db, logger, errChan)
|
||||||
var dnscfg = DNSConfig{
|
|
||||||
API: httpapicfg,
|
|
||||||
Database: dbcfg,
|
|
||||||
}
|
|
||||||
Config = dnscfg
|
|
||||||
c := cors.New(cors.Options{
|
c := cors.New(cors.Options{
|
||||||
AllowedOrigins: Config.API.CorsOrigins,
|
AllowedOrigins: config.API.CorsOrigins,
|
||||||
AllowedMethods: []string{"GET", "POST"},
|
AllowedMethods: []string{"GET", "POST"},
|
||||||
OptionsPassthrough: false,
|
OptionsPassthrough: false,
|
||||||
Debug: Config.General.Debug,
|
Debug: config.General.Debug,
|
||||||
})
|
})
|
||||||
api.POST("/register", webRegisterPost)
|
api.POST("/register", adnsapi.webRegisterPost)
|
||||||
api.GET("/health", healthCheck)
|
api.GET("/health", adnsapi.healthCheck)
|
||||||
if noauth {
|
if noauth {
|
||||||
api.POST("/update", noAuth(webUpdatePost))
|
api.POST("/update", noAuth(adnsapi.webUpdatePost))
|
||||||
} else {
|
} else {
|
||||||
api.POST("/update", Auth(webUpdatePost))
|
api.POST("/update", adnsapi.Auth(adnsapi.webUpdatePost))
|
||||||
}
|
}
|
||||||
return c.Handler(api)
|
return c.Handler(api), adnsapi, db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiRegister(t *testing.T) {
|
func TestApiRegister(t *testing.T) {
|
||||||
router := setupRouter(false, false)
|
router, _, _ := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
@ -117,7 +124,7 @@ func TestApiRegister(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestApiRegisterBadAllowFrom(t *testing.T) {
|
func TestApiRegisterBadAllowFrom(t *testing.T) {
|
||||||
router := setupRouter(false, false)
|
router, _, _ := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
@ -147,7 +154,7 @@ func TestApiRegisterBadAllowFrom(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestApiRegisterMalformedJSON(t *testing.T) {
|
func TestApiRegisterMalformedJSON(t *testing.T) {
|
||||||
router := setupRouter(false, false)
|
router, _, _ := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
@ -174,13 +181,13 @@ func TestApiRegisterMalformedJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestApiRegisterWithMockDB(t *testing.T) {
|
func TestApiRegisterWithMockDB(t *testing.T) {
|
||||||
router := setupRouter(false, false)
|
router, _, db := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
oldDb := DB.GetBackend()
|
oldDb := db.GetBackend()
|
||||||
db, mock, _ := sqlmock.New()
|
mdb, mock, _ := sqlmock.New()
|
||||||
DB.SetBackend(db)
|
db.SetBackend(mdb)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
mock.ExpectBegin()
|
mock.ExpectBegin()
|
||||||
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
|
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
|
||||||
@ -188,7 +195,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
|
|||||||
Status(http.StatusInternalServerError).
|
Status(http.StatusInternalServerError).
|
||||||
JSON().Object().
|
JSON().Object().
|
||||||
ContainsKey("error")
|
ContainsKey("error")
|
||||||
DB.SetBackend(oldDb)
|
db.SetBackend(oldDb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
|
func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
|
||||||
@ -198,11 +205,11 @@ func TestApiUpdateWithInvalidSubdomain(t *testing.T) {
|
|||||||
"subdomain": "",
|
"subdomain": "",
|
||||||
"txt": ""}
|
"txt": ""}
|
||||||
|
|
||||||
router := setupRouter(false, false)
|
router, _, db := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
newUser, err := DB.Register(cidrslice{})
|
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user, got error [%v]", err)
|
t.Errorf("Could not create new user, got error [%v]", err)
|
||||||
}
|
}
|
||||||
@ -228,11 +235,11 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
|
|||||||
"subdomain": "",
|
"subdomain": "",
|
||||||
"txt": ""}
|
"txt": ""}
|
||||||
|
|
||||||
router := setupRouter(false, false)
|
router, _, db := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
newUser, err := DB.Register(cidrslice{})
|
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user, got error [%v]", err)
|
t.Errorf("Could not create new user, got error [%v]", err)
|
||||||
}
|
}
|
||||||
@ -252,7 +259,7 @@ func TestApiUpdateWithInvalidTxt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestApiUpdateWithoutCredentials(t *testing.T) {
|
func TestApiUpdateWithoutCredentials(t *testing.T) {
|
||||||
router := setupRouter(false, false)
|
router, _, _ := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
@ -270,11 +277,11 @@ func TestApiUpdateWithCredentials(t *testing.T) {
|
|||||||
"subdomain": "",
|
"subdomain": "",
|
||||||
"txt": ""}
|
"txt": ""}
|
||||||
|
|
||||||
router := setupRouter(false, false)
|
router, _, db := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
newUser, err := DB.Register(cidrslice{})
|
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user, got error [%v]", err)
|
t.Errorf("Could not create new user, got error [%v]", err)
|
||||||
}
|
}
|
||||||
@ -303,13 +310,13 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
|||||||
updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8"
|
updateJSON["subdomain"] = "a097455b-52cc-4569-90c8-7a4b97c6eba8"
|
||||||
updateJSON["txt"] = validTxtData
|
updateJSON["txt"] = validTxtData
|
||||||
|
|
||||||
router := setupRouter(false, true)
|
router, _, db := setupRouter(false, true)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
oldDb := DB.GetBackend()
|
oldDb := db.GetBackend()
|
||||||
db, mock, _ := sqlmock.New()
|
mdb, mock, _ := sqlmock.New()
|
||||||
DB.SetBackend(db)
|
db.SetBackend(mdb)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
mock.ExpectBegin()
|
mock.ExpectBegin()
|
||||||
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
|
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
|
||||||
@ -319,31 +326,31 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
|||||||
Status(http.StatusInternalServerError).
|
Status(http.StatusInternalServerError).
|
||||||
JSON().Object().
|
JSON().Object().
|
||||||
ContainsKey("error")
|
ContainsKey("error")
|
||||||
DB.SetBackend(oldDb)
|
db.SetBackend(oldDb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiManyUpdateWithCredentials(t *testing.T) {
|
func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||||
validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
validTxtData := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||||
|
|
||||||
router := setupRouter(true, false)
|
router, _, db := setupRouter(true, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
// User without defined CIDR masks
|
// User without defined CIDR masks
|
||||||
newUser, err := DB.Register(cidrslice{})
|
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user, got error [%v]", err)
|
t.Errorf("Could not create new user, got error [%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// User with defined allow from - CIDR masks, all invalid
|
// User with defined allow from - CIDR masks, all invalid
|
||||||
// (httpexpect doesn't provide a way to mock remote ip)
|
// (httpexpect doesn't provide a way to mock remote ip)
|
||||||
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.1/32", "invalid"})
|
newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.1/32", "invalid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
|
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Another user with valid CIDR mask to match the httpexpect default
|
// Another user with valid CIDR mask to match the httpexpect default
|
||||||
newUserWithValidCIDR, err := DB.Register(cidrslice{"10.1.2.3/32", "invalid"})
|
newUserWithValidCIDR, err := db.Register(acmedns.Cidrslice{"10.1.2.3/32", "invalid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user with a valid CIDR, got error [%v]", err)
|
t.Errorf("Could not create new user with a valid CIDR, got error [%v]", err)
|
||||||
}
|
}
|
||||||
@ -381,30 +388,30 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
|
|||||||
|
|
||||||
func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
|
func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
|
||||||
|
|
||||||
router := setupRouter(false, false)
|
router, adnsapi, db := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
// Use header checks from default header (X-Forwarded-For)
|
// Use header checks from default header (X-Forwarded-For)
|
||||||
Config.API.UseHeader = true
|
adnsapi.Config.API.UseHeader = true
|
||||||
// User without defined CIDR masks
|
// User without defined CIDR masks
|
||||||
newUser, err := DB.Register(cidrslice{})
|
newUser, err := db.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user, got error [%v]", err)
|
t.Errorf("Could not create new user, got error [%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newUserWithCIDR, err := DB.Register(cidrslice{"192.168.1.2/32", "invalid"})
|
newUserWithCIDR, err := db.Register(acmedns.Cidrslice{"192.168.1.2/32", "invalid"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
|
t.Errorf("Could not create new user with CIDR, got error [%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newUserWithIP6CIDR, err := DB.Register(cidrslice{"2002:c0a8::0/32"})
|
newUserWithIP6CIDR, err := db.Register(acmedns.Cidrslice{"2002:c0a8::0/32"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not create a new user with IP6 CIDR, got error [%v]", err)
|
t.Errorf("Could not create a new user with IP6 CIDR, got error [%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
user ACMETxt
|
user acmedns.ACMETxt
|
||||||
headerValue string
|
headerValue string
|
||||||
status int
|
status int
|
||||||
}{
|
}{
|
||||||
@ -428,13 +435,66 @@ func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
|
|||||||
Expect().
|
Expect().
|
||||||
Status(test.status)
|
Status(test.status)
|
||||||
}
|
}
|
||||||
Config.API.UseHeader = false
|
adnsapi.Config.API.UseHeader = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiHealthCheck(t *testing.T) {
|
func TestApiHealthCheck(t *testing.T) {
|
||||||
router := setupRouter(false, false)
|
router, _, _ := setupRouter(false, false)
|
||||||
server := httptest.NewServer(router)
|
server := httptest.NewServer(router)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
e := getExpect(t, server)
|
e := getExpect(t, server)
|
||||||
e.GET("/health").Expect().Status(http.StatusOK)
|
e.GET("/health").Expect().Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetIPListFromHeader(t *testing.T) {
|
||||||
|
for i, test := range []struct {
|
||||||
|
input string
|
||||||
|
output []string
|
||||||
|
}{
|
||||||
|
{"1.1.1.1, 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||||
|
{" 1.1.1.1 , 2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||||
|
{",1.1.1.1 ,2.2.2.2", []string{"1.1.1.1", "2.2.2.2"}},
|
||||||
|
} {
|
||||||
|
res := getIPListFromHeader(test.input)
|
||||||
|
if len(res) != len(test.output) {
|
||||||
|
t.Errorf("Test %d: Expected [%d] items in return list, but got [%d]", i, len(test.output), len(res))
|
||||||
|
} else {
|
||||||
|
|
||||||
|
for j, vv := range test.output {
|
||||||
|
if res[j] != vv {
|
||||||
|
t.Errorf("Test %d: Expected return value [%v] but got [%v]", j, test.output, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAllowedFromIP(t *testing.T) {
|
||||||
|
_, adnsapi, _ := setupRouter(false, false)
|
||||||
|
adnsapi.Config.API.UseHeader = false
|
||||||
|
userWithAllow := acmedns.NewACMETxt()
|
||||||
|
userWithAllow.AllowFrom = acmedns.Cidrslice{"192.168.1.2/32", "[::1]/128"}
|
||||||
|
userWithoutAllow := acmedns.NewACMETxt()
|
||||||
|
|
||||||
|
for i, test := range []struct {
|
||||||
|
remoteaddr string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"192.168.1.2:1234", true},
|
||||||
|
{"192.168.1.1:1234", false},
|
||||||
|
{"invalid", false},
|
||||||
|
{"[::1]:4567", true},
|
||||||
|
} {
|
||||||
|
newreq, _ := http.NewRequest("GET", "/whatever", nil)
|
||||||
|
newreq.RemoteAddr = test.remoteaddr
|
||||||
|
ret := adnsapi.updateAllowedFromIP(newreq, userWithAllow)
|
||||||
|
if test.expected != ret {
|
||||||
|
t.Errorf("Test %d: Unexpected result for user with allowForm set", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !adnsapi.updateAllowedFromIP(newreq, userWithoutAllow) {
|
||||||
|
t.Errorf("Test %d: Unexpected result for user without allowForm set", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -74,11 +74,11 @@ func (a *AcmednsAPI) getUserFromRequest(r *http.Request) (acmedns.ACMETxt, error
|
|||||||
a.Logger.Errorw("Error while trying to get user",
|
a.Logger.Errorw("Error while trying to get user",
|
||||||
"error", err.Error())
|
"error", err.Error())
|
||||||
// To protect against timed side channel (never gonna give you up)
|
// To protect against timed side channel (never gonna give you up)
|
||||||
correctPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
acmedns.CorrectPassword(passwd, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
||||||
|
|
||||||
return acmedns.ACMETxt{}, fmt.Errorf("invalid username: %s", uname)
|
return acmedns.ACMETxt{}, fmt.Errorf("invalid username: %s", uname)
|
||||||
}
|
}
|
||||||
if correctPassword(passwd, dbuser.Password) {
|
if acmedns.CorrectPassword(passwd, dbuser.Password) {
|
||||||
return dbuser, nil
|
return dbuser, nil
|
||||||
}
|
}
|
||||||
return acmedns.ACMETxt{}, fmt.Errorf("invalid password for user %s", uname)
|
return acmedns.ACMETxt{}, fmt.Errorf("invalid password for user %s", uname)
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@ -31,13 +30,6 @@ func validKey(k string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func correctPassword(pw string, hash string) bool {
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func getIPListFromHeader(header string) []string {
|
func getIPListFromHeader(header string) []string {
|
||||||
iplist := []string{}
|
iplist := []string{}
|
||||||
for _, v := range strings.Split(header, ",") {
|
for _, v := range strings.Split(header, ",") {
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package main
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -103,7 +104,7 @@ func TestCorrectPassword(t *testing.T) {
|
|||||||
false},
|
false},
|
||||||
{"", "", false},
|
{"", "", false},
|
||||||
} {
|
} {
|
||||||
ret := correctPassword(test.pw, test.hash)
|
ret := acmedns.CorrectPassword(test.pw, test.hash)
|
||||||
if ret != test.output {
|
if ret != test.output {
|
||||||
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
|
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
|
||||||
}
|
}
|
||||||
@ -112,12 +113,12 @@ func TestCorrectPassword(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetValidCIDRMasks(t *testing.T) {
|
func TestGetValidCIDRMasks(t *testing.T) {
|
||||||
for i, test := range []struct {
|
for i, test := range []struct {
|
||||||
input cidrslice
|
input acmedns.Cidrslice
|
||||||
output cidrslice
|
output acmedns.Cidrslice
|
||||||
}{
|
}{
|
||||||
{cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}},
|
{acmedns.Cidrslice{"10.0.0.1/24"}, acmedns.Cidrslice{"10.0.0.1/24"}},
|
||||||
{cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}},
|
{acmedns.Cidrslice{"invalid", "127.0.0.1/32"}, acmedns.Cidrslice{"127.0.0.1/32"}},
|
||||||
{cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}},
|
{acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}, acmedns.Cidrslice{"2002:c0a8::0/32", "8.8.8.8/32"}},
|
||||||
} {
|
} {
|
||||||
ret := test.input.ValidEntries()
|
ret := test.input.ValidEntries()
|
||||||
if len(ret) == len(test.output) {
|
if len(ret) == len(test.output) {
|
||||||
@ -1,10 +1,12 @@
|
|||||||
package main
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||||
"github.com/erikstmartin/go-testdb"
|
"github.com/erikstmartin/go-testdb"
|
||||||
|
"go.uber.org/zap"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -21,42 +23,38 @@ func (r testResult) RowsAffected() (int64, error) {
|
|||||||
return r.affectedRows, nil
|
return r.affectedRows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDBInit(t *testing.T) {
|
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger) {
|
||||||
fakeDB := new(acmedb)
|
c := acmedns.AcmeDnsConfig{}
|
||||||
err := fakeDB.Init("notarealegine", "connectionstring")
|
c.Database.Engine = "sqlite"
|
||||||
if err == nil {
|
c.Database.Connection = ":memory:"
|
||||||
t.Errorf("Was expecting error, didn't get one.")
|
l := zap.NewNop().Sugar()
|
||||||
|
return c, l
|
||||||
}
|
}
|
||||||
|
|
||||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
func fakeDB() acmedns.AcmednsDB {
|
||||||
return testResult{1, 0}, errors.New("Prepared query error")
|
conf, logger := fakeConfigAndLogger()
|
||||||
})
|
db, _ := Init(&conf, logger)
|
||||||
defer testdb.Reset()
|
return db
|
||||||
|
|
||||||
errorDB := new(acmedb)
|
|
||||||
err = errorDB.Init("testdb", "")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Was expecting DB initiation error but got none")
|
|
||||||
}
|
|
||||||
errorDB.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterNoCIDR(t *testing.T) {
|
func TestRegisterNoCIDR(t *testing.T) {
|
||||||
// Register tests
|
// Register tests
|
||||||
_, err := DB.Register(cidrslice{})
|
DB := fakeDB()
|
||||||
|
_, err := DB.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Registration failed, got error [%v]", err)
|
t.Errorf("Registration failed, got error [%v]", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterMany(t *testing.T) {
|
func TestRegisterMany(t *testing.T) {
|
||||||
|
DB := fakeDB()
|
||||||
for i, test := range []struct {
|
for i, test := range []struct {
|
||||||
input cidrslice
|
input acmedns.Cidrslice
|
||||||
output cidrslice
|
output acmedns.Cidrslice
|
||||||
}{
|
}{
|
||||||
{cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
|
{acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, acmedns.Cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
|
||||||
{cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}},
|
{acmedns.Cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, acmedns.Cidrslice{}},
|
||||||
{cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
|
{acmedns.Cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, acmedns.Cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
|
||||||
} {
|
} {
|
||||||
user, err := DB.Register(test.input)
|
user, err := DB.Register(test.input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -77,8 +75,9 @@ func TestRegisterMany(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetByUsername(t *testing.T) {
|
func TestGetByUsername(t *testing.T) {
|
||||||
|
DB := fakeDB()
|
||||||
// Create reg to refer to
|
// Create reg to refer to
|
||||||
reg, err := DB.Register(cidrslice{})
|
reg, err := DB.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Registration failed, got error [%v]", err)
|
t.Errorf("Registration failed, got error [%v]", err)
|
||||||
}
|
}
|
||||||
@ -97,13 +96,14 @@ func TestGetByUsername(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// regUser password already is a bcrypt hash
|
// regUser password already is a bcrypt hash
|
||||||
if !correctPassword(reg.Password, regUser.Password) {
|
if !acmedns.CorrectPassword(reg.Password, regUser.Password) {
|
||||||
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
|
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepareErrors(t *testing.T) {
|
func TestPrepareErrors(t *testing.T) {
|
||||||
reg, _ := DB.Register(cidrslice{})
|
DB := fakeDB()
|
||||||
|
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||||
tdb, err := sql.Open("testdb", "")
|
tdb, err := sql.Open("testdb", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Got error: %v", err)
|
t.Errorf("Got error: %v", err)
|
||||||
@ -125,7 +125,8 @@ func TestPrepareErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryExecErrors(t *testing.T) {
|
func TestQueryExecErrors(t *testing.T) {
|
||||||
reg, _ := DB.Register(cidrslice{})
|
DB := fakeDB()
|
||||||
|
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||||
return testResult{1, 0}, errors.New("Prepared query error")
|
return testResult{1, 0}, errors.New("Prepared query error")
|
||||||
})
|
})
|
||||||
@ -156,7 +157,7 @@ func TestQueryExecErrors(t *testing.T) {
|
|||||||
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = DB.Register(cidrslice{})
|
_, err = DB.Register(acmedns.Cidrslice{})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error from exec in Register, but got none")
|
t.Errorf("Expected error from exec in Register, but got none")
|
||||||
}
|
}
|
||||||
@ -169,7 +170,8 @@ func TestQueryExecErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryScanErrors(t *testing.T) {
|
func TestQueryScanErrors(t *testing.T) {
|
||||||
reg, _ := DB.Register(cidrslice{})
|
DB := fakeDB()
|
||||||
|
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||||
|
|
||||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||||
return testResult{1, 0}, errors.New("Prepared query error")
|
return testResult{1, 0}, errors.New("Prepared query error")
|
||||||
@ -198,7 +200,8 @@ func TestQueryScanErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBadDBValues(t *testing.T) {
|
func TestBadDBValues(t *testing.T) {
|
||||||
reg, _ := DB.Register(cidrslice{})
|
DB := fakeDB()
|
||||||
|
reg, _ := DB.Register(acmedns.Cidrslice{})
|
||||||
|
|
||||||
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||||
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||||
@ -228,8 +231,9 @@ func TestBadDBValues(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTXTForDomain(t *testing.T) {
|
func TestGetTXTForDomain(t *testing.T) {
|
||||||
|
DB := fakeDB()
|
||||||
// Create reg to refer to
|
// Create reg to refer to
|
||||||
reg, err := DB.Register(cidrslice{})
|
reg, err := DB.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Registration failed, got error [%v]", err)
|
t.Errorf("Registration failed, got error [%v]", err)
|
||||||
}
|
}
|
||||||
@ -276,8 +280,9 @@ func TestGetTXTForDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdate(t *testing.T) {
|
func TestUpdate(t *testing.T) {
|
||||||
|
DB := fakeDB()
|
||||||
// Create reg to refer to
|
// Create reg to refer to
|
||||||
reg, err := DB.Register(cidrslice{})
|
reg, err := DB.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Registration failed, got error [%v]", err)
|
t.Errorf("Registration failed, got error [%v]", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,20 +1,66 @@
|
|||||||
package main
|
package nameserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||||
|
"github.com/acme-dns/acme-dns/pkg/database"
|
||||||
"github.com/erikstmartin/go-testdb"
|
"github.com/erikstmartin/go-testdb"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"go.uber.org/zap/zaptest/observer"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type resolver struct {
|
type resolver struct {
|
||||||
server string
|
server string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var records = []string{
|
||||||
|
"auth.example.org. A 192.168.1.100",
|
||||||
|
"ns1.auth.example.org. A 192.168.1.101",
|
||||||
|
"cn.example.org CNAME something.example.org.",
|
||||||
|
"!''b', unparseable ",
|
||||||
|
"ns2.auth.example.org. A 192.168.1.102",
|
||||||
|
}
|
||||||
|
|
||||||
|
func loggerHasEntryWithMessage(message string, logObserver *observer.ObservedLogs) bool {
|
||||||
|
if len(logObserver.FilterMessage(message).All()) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func fakeConfigAndLogger() (acmedns.AcmeDnsConfig, *zap.SugaredLogger, *observer.ObservedLogs) {
|
||||||
|
c := acmedns.AcmeDnsConfig{}
|
||||||
|
c.Database.Engine = "sqlite"
|
||||||
|
c.Database.Connection = ":memory:"
|
||||||
|
obsCore, logObserver := observer.New(zap.DebugLevel)
|
||||||
|
obsLogger := zap.New(obsCore).Sugar()
|
||||||
|
return c, obsLogger, logObserver
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupDNS() (acmedns.AcmednsNS, acmedns.AcmednsDB, *observer.ObservedLogs) {
|
||||||
|
config, logger, logObserver := fakeConfigAndLogger()
|
||||||
|
config.General.Domain = "auth.example.org"
|
||||||
|
config.General.Listen = "127.0.0.1:15353"
|
||||||
|
config.General.Proto = "udp"
|
||||||
|
config.General.Nsname = "ns1.auth.example.org"
|
||||||
|
config.General.Nsadmin = "admin.example.org"
|
||||||
|
config.General.StaticRecords = records
|
||||||
|
config.General.Debug = false
|
||||||
|
db, _ := database.Init(&config, logger)
|
||||||
|
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
|
||||||
|
server.Domains = make(map[string]Records)
|
||||||
|
server.Server = &dns.Server{Addr: config.General.Listen, Net: config.General.Proto}
|
||||||
|
server.ParseRecords()
|
||||||
|
server.OwnDomain = "auth.example.org."
|
||||||
|
return &server, db, logObserver
|
||||||
|
}
|
||||||
|
|
||||||
func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
|
func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
|
||||||
msg := new(dns.Msg)
|
msg := new(dns.Msg)
|
||||||
msg.Id = dns.Id()
|
msg.Id = dns.Id()
|
||||||
@ -27,7 +73,6 @@ func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) {
|
|||||||
if in != nil && in.Rcode != dns.RcodeSuccess {
|
if in != nil && in.Rcode != dns.RcodeSuccess {
|
||||||
return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode])
|
return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode])
|
||||||
}
|
}
|
||||||
|
|
||||||
return in, nil
|
return in, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,6 +94,18 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestQuestionDBError(t *testing.T) {
|
func TestQuestionDBError(t *testing.T) {
|
||||||
|
config, logger, _ := fakeConfigAndLogger()
|
||||||
|
config.General.Listen = "127.0.0.1:15353"
|
||||||
|
config.General.Proto = "udp"
|
||||||
|
config.General.Domain = "auth.example.org"
|
||||||
|
config.General.Nsname = "ns1.auth.example.org"
|
||||||
|
config.General.Nsadmin = "admin.example.org"
|
||||||
|
config.General.StaticRecords = records
|
||||||
|
config.General.Debug = false
|
||||||
|
db, _ := database.Init(&config, logger)
|
||||||
|
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
|
||||||
|
server.Domains = make(map[string]Records)
|
||||||
|
server.ParseRecords()
|
||||||
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||||
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||||
return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error")
|
return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error")
|
||||||
@ -60,48 +117,61 @@ func TestQuestionDBError(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Got error: %v", err)
|
t.Errorf("Got error: %v", err)
|
||||||
}
|
}
|
||||||
oldDb := DB.GetBackend()
|
oldDb := db.GetBackend()
|
||||||
|
|
||||||
DB.SetBackend(tdb)
|
db.SetBackend(tdb)
|
||||||
defer DB.SetBackend(oldDb)
|
defer db.SetBackend(oldDb)
|
||||||
|
|
||||||
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
|
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
|
||||||
_, err = dnsserver.answerTXT(q)
|
_, err = server.answerTXT(q)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error but got none")
|
t.Errorf("Expected error but got none")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
func TestParse(t *testing.T) {
|
||||||
var testcfg = DNSConfig{
|
config, logger, logObserver := fakeConfigAndLogger()
|
||||||
General: general{
|
config.General.Listen = "127.0.0.1:15353"
|
||||||
Domain: ")",
|
config.General.Proto = "udp"
|
||||||
Nsname: "ns1.auth.example.org",
|
config.General.Domain = ")"
|
||||||
Nsadmin: "admin.example.org",
|
config.General.Nsname = "ns1.auth.example.org"
|
||||||
StaticRecords: []string{},
|
config.General.Nsadmin = "admin.example.org"
|
||||||
Debug: false,
|
config.General.StaticRecords = records
|
||||||
},
|
config.General.Debug = false
|
||||||
}
|
config.General.StaticRecords = []string{}
|
||||||
dnsserver.ParseRecords(testcfg)
|
db, _ := database.Init(&config, logger)
|
||||||
if !loggerHasEntryWithMessage("Error while adding SOA record") {
|
server := Nameserver{Config: &config, DB: db, Logger: logger, personalAuthKey: ""}
|
||||||
|
server.Domains = make(map[string]Records)
|
||||||
|
server.ParseRecords()
|
||||||
|
if !loggerHasEntryWithMessage("Error while adding SOA record", logObserver) {
|
||||||
t.Errorf("Expected SOA parsing to return error, but did not find one")
|
t.Errorf("Expected SOA parsing to return error, but did not find one")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveA(t *testing.T) {
|
func TestResolveA(t *testing.T) {
|
||||||
|
server, _, _ := setupDNS()
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
waitLock := sync.Mutex{}
|
||||||
|
waitLock.Lock()
|
||||||
|
server.SetNotifyStartedFunc(waitLock.Unlock)
|
||||||
|
go server.Start(errChan)
|
||||||
|
waitLock.Lock()
|
||||||
resolv := resolver{server: "127.0.0.1:15353"}
|
resolv := resolver{server: "127.0.0.1:15353"}
|
||||||
answer, err := resolv.lookup("auth.example.org", dns.TypeA)
|
answer, err := resolv.lookup("auth.example.org", dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v", err)
|
t.Errorf("%v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(answer.Answer) == 0 {
|
if len(answer.Answer) == 0 {
|
||||||
t.Error("No answer for DNS query")
|
t.Error("No answer for DNS query")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = resolv.lookup("nonexistent.domain.tld", dns.TypeA)
|
_, err = resolv.lookup("nonexistent.domain.tld", dns.TypeA)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Was expecting error because of NXDOMAIN but got none")
|
t.Errorf("Was expecting error because of NXDOMAIN but got none")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -195,17 +265,20 @@ func TestAuthoritative(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
func TestResolveTXT(t *testing.T) {
|
func TestResolveTXT(t *testing.T) {
|
||||||
|
_, db, _ := setupDNS()
|
||||||
resolv := resolver{server: "127.0.0.1:15353"}
|
resolv := resolver{server: "127.0.0.1:15353"}
|
||||||
validTXT := "______________valid_response_______________"
|
validTXT := "______________valid_response_______________"
|
||||||
|
|
||||||
atxt, err := DB.Register(cidrslice{})
|
atxt, err := db.Register(acmedns.Cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not initiate db record: [%v]", err)
|
t.Errorf("Could not initiate db record: [%v]", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atxt.Value = validTXT
|
atxt.Value = validTXT
|
||||||
err = DB.Update(atxt.ACMETxtPost)
|
|
||||||
|
err = db.Update(atxt.ACMETxtPost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not update db record: [%v]", err)
|
t.Errorf("Could not update db record: [%v]", err)
|
||||||
return
|
return
|
||||||
@ -254,7 +327,7 @@ func TestResolveTXT(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}*/
|
||||||
|
|
||||||
func TestCaseInsensitiveResolveA(t *testing.T) {
|
func TestCaseInsensitiveResolveA(t *testing.T) {
|
||||||
resolv := resolver{server: "127.0.0.1:15353"}
|
resolv := resolver{server: "127.0.0.1:15353"}
|
||||||
@ -9,7 +9,6 @@ import (
|
|||||||
func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
|
|
||||||
// handle edns0
|
// handle edns0
|
||||||
opt := r.IsEdns0()
|
opt := r.IsEdns0()
|
||||||
if opt != nil {
|
if opt != nil {
|
||||||
|
|||||||
@ -1,12 +1,11 @@
|
|||||||
package nameserver
|
package nameserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
"github.com/acme-dns/acme-dns/pkg/acmedns"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Records is a slice of ResourceRecords
|
// Records is a slice of ResourceRecords
|
||||||
@ -20,6 +19,7 @@ type Nameserver struct {
|
|||||||
Logger *zap.SugaredLogger
|
Logger *zap.SugaredLogger
|
||||||
Server *dns.Server
|
Server *dns.Server
|
||||||
OwnDomain string
|
OwnDomain string
|
||||||
|
NotifyStartedFunc func()
|
||||||
SOA dns.RR
|
SOA dns.RR
|
||||||
personalAuthKey string
|
personalAuthKey string
|
||||||
Domains map[string]Records
|
Domains map[string]Records
|
||||||
@ -28,8 +28,9 @@ type Nameserver struct {
|
|||||||
|
|
||||||
func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) []acmedns.AcmednsNS {
|
func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) []acmedns.AcmednsNS {
|
||||||
dnsservers := make([]acmedns.AcmednsNS, 0)
|
dnsservers := make([]acmedns.AcmednsNS, 0)
|
||||||
|
waitLock := sync.Mutex{}
|
||||||
if strings.HasPrefix(config.General.Proto, "both") {
|
if strings.HasPrefix(config.General.Proto, "both") {
|
||||||
|
|
||||||
// Handle the case where DNS server should be started for both udp and tcp
|
// Handle the case where DNS server should be started for both udp and tcp
|
||||||
udpProto := "udp"
|
udpProto := "udp"
|
||||||
tcpProto := "tcp"
|
tcpProto := "tcp"
|
||||||
@ -46,13 +47,22 @@ func InitAndStart(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
|
|||||||
dnsServerTCP := NewDNSServer(config, db, logger, tcpProto)
|
dnsServerTCP := NewDNSServer(config, db, logger, tcpProto)
|
||||||
dnsservers = append(dnsservers, dnsServerTCP)
|
dnsservers = append(dnsservers, dnsServerTCP)
|
||||||
dnsServerTCP.ParseRecords()
|
dnsServerTCP.ParseRecords()
|
||||||
|
// wait for the server to get started to proceed
|
||||||
|
waitLock.Lock()
|
||||||
|
dnsServerUDP.SetNotifyStartedFunc(waitLock.Unlock)
|
||||||
go dnsServerUDP.Start(errChan)
|
go dnsServerUDP.Start(errChan)
|
||||||
|
waitLock.Lock()
|
||||||
|
dnsServerTCP.SetNotifyStartedFunc(waitLock.Unlock)
|
||||||
go dnsServerTCP.Start(errChan)
|
go dnsServerTCP.Start(errChan)
|
||||||
|
waitLock.Lock()
|
||||||
} else {
|
} else {
|
||||||
dnsServer := NewDNSServer(config, db, logger, config.General.Proto)
|
dnsServer := NewDNSServer(config, db, logger, config.General.Proto)
|
||||||
dnsservers = append(dnsservers, dnsServer)
|
dnsservers = append(dnsservers, dnsServer)
|
||||||
dnsServer.ParseRecords()
|
dnsServer.ParseRecords()
|
||||||
|
waitLock.Lock()
|
||||||
|
dnsServer.SetNotifyStartedFunc(waitLock.Unlock)
|
||||||
go dnsServer.Start(errChan)
|
go dnsServer.Start(errChan)
|
||||||
|
waitLock.Lock()
|
||||||
}
|
}
|
||||||
return dnsservers
|
return dnsservers
|
||||||
}
|
}
|
||||||
@ -67,7 +77,6 @@ func NewDNSServer(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *z
|
|||||||
domain = domain + "."
|
domain = domain + "."
|
||||||
}
|
}
|
||||||
server.OwnDomain = strings.ToLower(domain)
|
server.OwnDomain = strings.ToLower(domain)
|
||||||
server.DB = db
|
|
||||||
server.personalAuthKey = ""
|
server.personalAuthKey = ""
|
||||||
server.Domains = make(map[string]Records)
|
server.Domains = make(map[string]Records)
|
||||||
return &server
|
return &server
|
||||||
@ -79,8 +88,15 @@ func (n *Nameserver) Start(errorChannel chan error) {
|
|||||||
n.Logger.Infow("Starting DNS listener",
|
n.Logger.Infow("Starting DNS listener",
|
||||||
"addr", n.Server.Addr,
|
"addr", n.Server.Addr,
|
||||||
"proto", n.Server.Net)
|
"proto", n.Server.Net)
|
||||||
|
if n.NotifyStartedFunc != nil {
|
||||||
|
n.Server.NotifyStartedFunc = n.NotifyStartedFunc
|
||||||
|
}
|
||||||
err := n.Server.ListenAndServe()
|
err := n.Server.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorChannel <- err
|
errorChannel <- err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Nameserver) SetNotifyStartedFunc(fun func()) {
|
||||||
|
n.Server.NotifyStartedFunc = fun
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user