Refactoring - improving coverage (#371)
* Increase code coverage in acmedns * More testing of ReadConfig() and its fallback mechanism * Found that if someone put a '"' double quote into the filename that we configure zap to log to, it would cause the the JSON created to be invalid. I have replaced the JSON string with proper config * Better handling of config options for api.TLS - we now error on an invalid value instead of silently failing. added a basic test for api.setupTLS() (to increase test coverage) * testing nameserver isOwnChallenge and isAuthoritative methods * add a unit test for nameserver answerOwnChallenge * fix linting errors * bump go and golangci-lint versions in github actions * Update golangci-lint.yml Bumping github-actions workflow versions to accommodate some changes in upstream golanci-lint * Bump Golang version to 1.23 (currently the oldest supported version) Bump golanglint-ci to 2.0.2 and migrate the config file. This should resolve the math/rand/v2 issue * bump golanglint-ci action version * Fixing up new golanglint-ci warnings and errors --------- Co-authored-by: Joona Hoikkala <5235109+joohoi@users.noreply.github.com>
This commit is contained in:
parent
d20fae37c9
commit
e0f9745182
10
.github/workflows/golangci-lint.yml
vendored
10
.github/workflows/golangci-lint.yml
vendored
@ -8,14 +8,14 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: 1.18.4
|
||||
go-version: 1.23
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: v1.48
|
||||
version: v2.0.2
|
||||
|
||||
@ -1,15 +1,30 @@
|
||||
version: "2"
|
||||
linters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
issues:
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
linters:
|
||||
# Enable specific linter
|
||||
# https://golangci-lint.run/usage/linters/#enabled-by-default
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- goimports
|
||||
linters-settings:
|
||||
settings:
|
||||
goimports:
|
||||
# A comma-separated list of prefixes, which, if set, checks import paths
|
||||
# with the given prefixes are grouped after 3rd-party packages.
|
||||
# Default: ""
|
||||
local-prefixes: github.com/acme-dns/acme-dns
|
||||
local-prefixes:
|
||||
- github.com/acme-dns/acme-dns
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
2
go.mod
2
go.mod
@ -1,6 +1,6 @@
|
||||
module github.com/joohoi/acme-dns
|
||||
|
||||
go 1.19
|
||||
go 1.23
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.2.1
|
||||
|
||||
2
main.go
2
main.go
@ -21,7 +21,7 @@ func main() {
|
||||
// Read global config
|
||||
var err error
|
||||
var logger *zap.Logger
|
||||
config, usedConfigFile, err := acmedns.ReadConfig(*configPtr)
|
||||
config, usedConfigFile, err := acmedns.ReadConfig(*configPtr, "./config.cfg")
|
||||
if err != nil {
|
||||
fmt.Printf("Error: %s\n", err)
|
||||
os.Exit(1)
|
||||
|
||||
@ -6,7 +6,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Check if IP belongs to an allowed net
|
||||
// AllowedFrom Check if IP belongs to an allowed net
|
||||
func (a ACMETxt) AllowedFrom(ip string) bool {
|
||||
remoteIP := net.ParseIP(ip)
|
||||
// Range not limited
|
||||
@ -22,7 +22,7 @@ func (a ACMETxt) AllowedFrom(ip string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Go through list (most likely from headers) to check for the IP.
|
||||
// AllowedFromList Go through list (most likely from headers) to check for the IP.
|
||||
// Reason for this is that some setups use reverse proxy in front of acme-dns
|
||||
func (a ACMETxt) AllowedFromList(ips []string) bool {
|
||||
if len(ips) == 0 {
|
||||
|
||||
@ -8,6 +8,13 @@ import (
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiTlsProviderNone = "none"
|
||||
ApiTlsProviderLetsEncrypt = "letsencrypt"
|
||||
ApiTlsProviderLetsEncryptStaging = "letsencryptstaging"
|
||||
ApiTlsProviderCert = "cert"
|
||||
)
|
||||
|
||||
func FileIsAccessible(fname string) bool {
|
||||
_, err := os.Stat(fname)
|
||||
if err != nil {
|
||||
@ -45,24 +52,31 @@ func prepareConfig(conf AcmeDnsConfig) (AcmeDnsConfig, error) {
|
||||
conf.API.ACMECacheDir = "api-certs"
|
||||
}
|
||||
|
||||
switch conf.API.TLS {
|
||||
case ApiTlsProviderCert, ApiTlsProviderLetsEncrypt, ApiTlsProviderLetsEncryptStaging, ApiTlsProviderNone:
|
||||
// we have a good value
|
||||
default:
|
||||
return conf, fmt.Errorf("invalid value for api.tls, expected one of [%s, %s, %s, %s]", ApiTlsProviderCert, ApiTlsProviderLetsEncrypt, ApiTlsProviderLetsEncryptStaging, ApiTlsProviderNone)
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func ReadConfig(configFile string) (AcmeDnsConfig, string, error) {
|
||||
func ReadConfig(configFile, fallback string) (AcmeDnsConfig, string, error) {
|
||||
var usedConfigFile string
|
||||
var config AcmeDnsConfig
|
||||
var err error
|
||||
if FileIsAccessible(configFile) {
|
||||
usedConfigFile = configFile
|
||||
config, err = readTomlConfig(configFile)
|
||||
} else if FileIsAccessible("./config.cfg") {
|
||||
usedConfigFile = "./config.cfg"
|
||||
config, err = readTomlConfig("./config.cfg")
|
||||
} else if FileIsAccessible(fallback) {
|
||||
usedConfigFile = fallback
|
||||
config, err = readTomlConfig(fallback)
|
||||
} else {
|
||||
err = fmt.Errorf("configuration file not found")
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
|
||||
err = fmt.Errorf("encountered an error while trying to read configuration file: %w", err)
|
||||
}
|
||||
return config, usedConfigFile, err
|
||||
}
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
package acmedns
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
|
||||
var logger *zap.Logger
|
||||
var (
|
||||
logger *zap.Logger
|
||||
zapCfg zap.Config
|
||||
err error
|
||||
)
|
||||
|
||||
logformat := "console"
|
||||
if config.Logconfig.Format == "json" {
|
||||
logformat = "json"
|
||||
@ -21,23 +25,22 @@ func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
|
||||
if config.Logconfig.Logtype == "file" {
|
||||
errorPath = config.Logconfig.File
|
||||
}
|
||||
zapConfigJson := fmt.Sprintf(`{
|
||||
"level": "%s",
|
||||
"encoding": "%s",
|
||||
"outputPaths": ["%s"],
|
||||
"errorOutputPaths": ["%s"],
|
||||
"encoderConfig": {
|
||||
"timeKey": "time",
|
||||
"messageKey": "msg",
|
||||
"levelKey": "level",
|
||||
"levelEncoder": "lowercase",
|
||||
"timeEncoder": "iso8601"
|
||||
}
|
||||
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
|
||||
var zapCfg zap.Config
|
||||
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
|
||||
|
||||
zapCfg.Level, err = zap.ParseAtomicLevel(config.Logconfig.Level)
|
||||
if err != nil {
|
||||
return logger, err
|
||||
}
|
||||
logger, err := zapCfg.Build()
|
||||
zapCfg.Encoding = logformat
|
||||
zapCfg.OutputPaths = []string{outputPath}
|
||||
zapCfg.ErrorOutputPaths = []string{errorPath}
|
||||
zapCfg.EncoderConfig = zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
MessageKey: "msg",
|
||||
LevelKey: "level",
|
||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
}
|
||||
|
||||
logger, err = zapCfg.Build()
|
||||
return logger, err
|
||||
}
|
||||
|
||||
36
pkg/acmedns/testdata/test_read_fallback_config.toml
vendored
Normal file
36
pkg/acmedns/testdata/test_read_fallback_config.toml
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
[general]
|
||||
listen = "127.0.0.1:53"
|
||||
protocol = "both"
|
||||
domain = "test.example.org"
|
||||
nsname = "test.example.org"
|
||||
nsadmin = "test.example.org"
|
||||
records = [
|
||||
"test.example.org. A 127.0.0.1",
|
||||
"test.example.org. NS test.example.org.",
|
||||
]
|
||||
debug = true
|
||||
|
||||
[database]
|
||||
engine = "dinosaur"
|
||||
connection = "roar"
|
||||
|
||||
[api]
|
||||
ip = "0.0.0.0"
|
||||
disable_registration = false
|
||||
port = "443"
|
||||
tls = "none"
|
||||
tls_cert_privkey = "/etc/tls/example.org/privkey.pem"
|
||||
tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem"
|
||||
acme_cache_dir = "api-certs"
|
||||
notification_email = ""
|
||||
corsorigins = [
|
||||
"*"
|
||||
]
|
||||
use_header = true
|
||||
header_name = "X-is-gonna-give-it-to-ya"
|
||||
|
||||
[logconfig]
|
||||
loglevel = "info"
|
||||
logtype = "stdout"
|
||||
logfile = "./acme-dns.log"
|
||||
logformat = "json"
|
||||
@ -8,7 +8,7 @@ type Account struct {
|
||||
Subdomain string
|
||||
}
|
||||
|
||||
// DNSConfig holds the config structure
|
||||
// AcmeDnsConfig holds the config structure
|
||||
type AcmeDnsConfig struct {
|
||||
General general
|
||||
Database dbsettings
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
package acmedns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"reflect"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
@ -46,19 +49,33 @@ func TestSetupLoggingError(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
format string
|
||||
level string
|
||||
file string
|
||||
errexpected bool
|
||||
}{
|
||||
{"text", "warn", false},
|
||||
{"json", "debug", false},
|
||||
{"text", "info", false},
|
||||
{"json", "error", false},
|
||||
{"text", "something", true},
|
||||
{"text", "warn", "", false},
|
||||
{"json", "debug", "", false},
|
||||
{"text", "info", "", false},
|
||||
{"json", "error", "", false},
|
||||
{"text", "something", "", true},
|
||||
{"text", "info", "a path with\" in its name.txt", false},
|
||||
} {
|
||||
conf.Logconfig.Format = test.format
|
||||
conf.Logconfig.Level = test.level
|
||||
if test.file != "" {
|
||||
conf.Logconfig.File = test.file
|
||||
conf.Logconfig.Logtype = "file"
|
||||
|
||||
}
|
||||
_, err := SetupLogging(conf)
|
||||
if test.errexpected && err == nil {
|
||||
t.Errorf("Expected error but did not get one for loglevel: %s", err)
|
||||
} else if !test.errexpected && err != nil {
|
||||
t.Errorf("Unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// clean up the file zap creates
|
||||
if test.file != "" {
|
||||
_ = os.Remove(test.file)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -88,7 +105,7 @@ func TestReadConfig(t *testing.T) {
|
||||
} {
|
||||
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||
if err != nil {
|
||||
t.Error("Could not create temporary file")
|
||||
t.Fatalf("Could not create temporary file: %s", err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
@ -99,7 +116,7 @@ func TestReadConfig(t *testing.T) {
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Error("Could not close temporary file")
|
||||
}
|
||||
ret, _, _ := ReadConfig(tmpfile.Name())
|
||||
ret, _, _ := ReadConfig(tmpfile.Name(), "")
|
||||
if ret.General.Listen != test.output.General.Listen {
|
||||
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
||||
}
|
||||
@ -109,10 +126,112 @@ func TestReadConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadConfigFallback(t *testing.T) {
|
||||
var (
|
||||
path string
|
||||
err error
|
||||
)
|
||||
|
||||
testPath := "testdata/test_read_fallback_config.toml"
|
||||
|
||||
path, err = getNonExistentPath()
|
||||
if err != nil {
|
||||
t.Errorf("failed getting non existant path: %s", err)
|
||||
}
|
||||
|
||||
cfg, used, err := ReadConfig(path, testPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read a config file when we should have: %s", err)
|
||||
}
|
||||
|
||||
if used != testPath {
|
||||
t.Fatalf("we read from the wrong file. got: %s, want: %s", used, testPath)
|
||||
}
|
||||
|
||||
expected := AcmeDnsConfig{
|
||||
General: general{
|
||||
Listen: "127.0.0.1:53",
|
||||
Proto: "both",
|
||||
Domain: "test.example.org",
|
||||
Nsname: "test.example.org",
|
||||
Nsadmin: "test.example.org",
|
||||
Debug: true,
|
||||
StaticRecords: []string{
|
||||
"test.example.org. A 127.0.0.1",
|
||||
"test.example.org. NS test.example.org.",
|
||||
},
|
||||
},
|
||||
Database: dbsettings{
|
||||
Engine: "dinosaur",
|
||||
Connection: "roar",
|
||||
},
|
||||
API: httpapi{
|
||||
Domain: "",
|
||||
IP: "0.0.0.0",
|
||||
DisableRegistration: false,
|
||||
AutocertPort: "",
|
||||
Port: "443",
|
||||
TLS: "none",
|
||||
TLSCertPrivkey: "/etc/tls/example.org/privkey.pem",
|
||||
TLSCertFullchain: "/etc/tls/example.org/fullchain.pem",
|
||||
ACMECacheDir: "api-certs",
|
||||
NotificationEmail: "",
|
||||
CorsOrigins: []string{"*"},
|
||||
UseHeader: true,
|
||||
HeaderName: "X-is-gonna-give-it-to-ya",
|
||||
},
|
||||
Logconfig: logconfig{
|
||||
Level: "info",
|
||||
Logtype: "stdout",
|
||||
File: "./acme-dns.log",
|
||||
Format: "json",
|
||||
},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(cfg, expected) {
|
||||
t.Errorf("Did not read the config correctly: got %+v, want: %+v", cfg, expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getNonExistentPath() (string, error) {
|
||||
path := fmt.Sprintf("/some/path/that/should/not/exist/on/any/filesystem/%10d.cfg", rand.Int())
|
||||
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("attempted non existant file exists!?: %s", path)
|
||||
}
|
||||
|
||||
// TestReadConfigFallbackError makes sure we error when we do not have a fallback config file
|
||||
func TestReadConfigFallbackError(t *testing.T) {
|
||||
var (
|
||||
badPaths []string
|
||||
i int
|
||||
)
|
||||
for len(badPaths) < 2 && i < 10 {
|
||||
i++
|
||||
|
||||
if path, err := getNonExistentPath(); err == nil {
|
||||
badPaths = append(badPaths, path)
|
||||
}
|
||||
}
|
||||
|
||||
if len(badPaths) != 2 {
|
||||
t.Fatalf("did not create exactly 2 bad paths")
|
||||
}
|
||||
|
||||
_, _, err := ReadConfig(badPaths[0], badPaths[1])
|
||||
if err == nil {
|
||||
t.Errorf("Should have failed reading non existant file: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileCheckPermissionDenied(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||
if err != nil {
|
||||
t.Error("Could not create temporary file")
|
||||
t.Fatalf("Could not create temporary file: %s", err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
_ = syscall.Chmod(tmpfile.Name(), 0000)
|
||||
@ -131,7 +250,7 @@ func TestFileCheckNotExists(t *testing.T) {
|
||||
func TestFileCheckOK(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||
if err != nil {
|
||||
t.Error("Could not create temporary file")
|
||||
t.Fatalf("Could not create temporary file: %s", err)
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
if !FileIsAccessible(tmpfile.Name()) {
|
||||
@ -144,9 +263,20 @@ func TestPrepareConfig(t *testing.T) {
|
||||
input AcmeDnsConfig
|
||||
shoulderror bool
|
||||
}{
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
|
||||
{AcmeDnsConfig{
|
||||
Database: dbsettings{Engine: "whatever", Connection: "whatever_too"},
|
||||
API: httpapi{TLS: ApiTlsProviderNone},
|
||||
}, false},
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"},
|
||||
API: httpapi{TLS: ApiTlsProviderNone},
|
||||
}, true},
|
||||
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""},
|
||||
API: httpapi{TLS: ApiTlsProviderNone},
|
||||
}, true},
|
||||
{AcmeDnsConfig{
|
||||
Database: dbsettings{Engine: "whatever", Connection: "whatever_too"},
|
||||
API: httpapi{TLS: "whatever"},
|
||||
}, true},
|
||||
} {
|
||||
_, err := prepareConfig(test.input)
|
||||
if test.shoulderror {
|
||||
|
||||
@ -33,7 +33,6 @@ func (a *AcmednsAPI) Start(dnsservers []acmedns.AcmednsNS) {
|
||||
a.errChan <- err
|
||||
return
|
||||
}
|
||||
//legolog.Logger = stderrorlog
|
||||
api := httprouter.New()
|
||||
c := cors.New(cors.Options{
|
||||
AllowedOrigins: a.Config.API.CorsOrigins,
|
||||
@ -59,26 +58,7 @@ func (a *AcmednsAPI) Start(dnsservers []acmedns.AcmednsNS) {
|
||||
}
|
||||
|
||||
switch a.Config.API.TLS {
|
||||
case "letsencryptstaging":
|
||||
magic := a.setupTLS(dnsservers)
|
||||
err = magic.ManageAsync(context.Background(), []string{a.Config.General.Domain})
|
||||
if err != nil {
|
||||
a.errChan <- err
|
||||
return
|
||||
}
|
||||
cfg.GetCertificate = magic.GetCertificate
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: host,
|
||||
Handler: c.Handler(api),
|
||||
TLSConfig: cfg,
|
||||
ErrorLog: stderrorlog,
|
||||
}
|
||||
a.Logger.Infow("Listening HTTPS",
|
||||
"host", host,
|
||||
"domain", a.Config.General.Domain)
|
||||
err = srv.ListenAndServeTLS("", "")
|
||||
case "letsencrypt":
|
||||
case acmedns.ApiTlsProviderLetsEncrypt, acmedns.ApiTlsProviderLetsEncryptStaging:
|
||||
magic := a.setupTLS(dnsservers)
|
||||
err = magic.ManageAsync(context.Background(), []string{a.Config.General.Domain})
|
||||
if err != nil {
|
||||
@ -96,7 +76,7 @@ func (a *AcmednsAPI) Start(dnsservers []acmedns.AcmednsNS) {
|
||||
"host", host,
|
||||
"domain", a.Config.General.Domain)
|
||||
err = srv.ListenAndServeTLS("", "")
|
||||
case "cert":
|
||||
case acmedns.ApiTlsProviderCert:
|
||||
srv := &http.Server{
|
||||
Addr: host,
|
||||
Handler: c.Handler(api),
|
||||
@ -126,7 +106,7 @@ func (a *AcmednsAPI) setupTLS(dnsservers []acmedns.AcmednsNS) *certmagic.Config
|
||||
certmagic.DefaultACME.DNS01Solver = &provider
|
||||
certmagic.DefaultACME.Agreed = true
|
||||
certmagic.DefaultACME.Logger = a.Logger.Desugar()
|
||||
if a.Config.API.TLS == "letsencrypt" {
|
||||
if a.Config.API.TLS == acmedns.ApiTlsProviderLetsEncrypt {
|
||||
certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA
|
||||
} else {
|
||||
certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA
|
||||
|
||||
@ -10,8 +10,10 @@ import (
|
||||
|
||||
"github.com/joohoi/acme-dns/pkg/acmedns"
|
||||
"github.com/joohoi/acme-dns/pkg/database"
|
||||
"github.com/joohoi/acme-dns/pkg/nameserver"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/gavv/httpexpect"
|
||||
"github.com/google/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
@ -63,7 +65,7 @@ func setupRouter(debug bool, noauth bool) (http.Handler, AcmednsAPI, acmedns.Acm
|
||||
config, logger := fakeConfigAndLogger()
|
||||
config.API.Domain = ""
|
||||
config.API.Port = "8080"
|
||||
config.API.TLS = "none"
|
||||
config.API.TLS = acmedns.ApiTlsProviderNone
|
||||
config.API.CorsOrigins = []string{"*"}
|
||||
config.API.UseHeader = true
|
||||
config.API.HeaderName = "X-Forwarded-For"
|
||||
@ -498,3 +500,34 @@ func TestUpdateAllowedFromIP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupTLS(t *testing.T) {
|
||||
_, svr, _ := setupRouter(false, false)
|
||||
|
||||
for _, test := range []struct {
|
||||
apiTls string
|
||||
expectedCA string
|
||||
}{
|
||||
{
|
||||
apiTls: acmedns.ApiTlsProviderLetsEncrypt,
|
||||
expectedCA: certmagic.LetsEncryptProductionCA,
|
||||
},
|
||||
{
|
||||
apiTls: acmedns.ApiTlsProviderLetsEncryptStaging,
|
||||
expectedCA: certmagic.LetsEncryptStagingCA,
|
||||
},
|
||||
} {
|
||||
svr.Config.API.TLS = test.apiTls
|
||||
ns := &nameserver.Nameserver{}
|
||||
magic := svr.setupTLS([]acmedns.AcmednsNS{ns})
|
||||
|
||||
if test.expectedCA != certmagic.DefaultACME.CA {
|
||||
t.Errorf("failed to configure default ACME CA. got %s, want %s", certmagic.DefaultACME.CA, test.expectedCA)
|
||||
}
|
||||
|
||||
if magic.DefaultServerName != svr.Config.General.Domain {
|
||||
t.Errorf("failed to set the correct doman. got: %s, want %s", magic.DefaultServerName, svr.Config.General.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -235,14 +235,14 @@ func TestAuthoritative(t *testing.T) {
|
||||
if answer.Ns[0].Header().Rrtype != dns.TypeSOA {
|
||||
t.Errorf("Was expecting SOA record as answer for NXDOMAIN but got [%s]", dns.TypeToString[answer.Ns[0].Header().Rrtype])
|
||||
}
|
||||
if !answer.MsgHdr.Authoritative {
|
||||
if !answer.Authoritative {
|
||||
t.Errorf("Was expecting authoritative bit to be set")
|
||||
}
|
||||
nanswer, _ := resolv.lookup("nonexsitent.nonauth.tld", dns.TypeA)
|
||||
if len(nanswer.Answer) > 0 {
|
||||
t.Errorf("Didn't expect answers for non authotitative domain query")
|
||||
}
|
||||
if nanswer.MsgHdr.Authoritative {
|
||||
if nanswer.Authoritative {
|
||||
t.Errorf("Authoritative bit should not be set for non-authoritative domain.")
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,7 +15,7 @@ func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt != nil {
|
||||
if opt.Version() != 0 {
|
||||
// Only EDNS0 is standardized
|
||||
m.MsgHdr.Rcode = dns.RcodeBadVers
|
||||
m.Rcode = dns.RcodeBadVers
|
||||
m.SetEdns0(512, false)
|
||||
} else {
|
||||
// We can safely do this as we know that we're not setting other OPT RRs within acme-dns.
|
||||
@ -39,13 +39,13 @@ func (n *Nameserver) readQuery(m *dns.Msg) {
|
||||
if auth {
|
||||
authoritative = auth
|
||||
}
|
||||
m.MsgHdr.Rcode = rc
|
||||
m.Rcode = rc
|
||||
m.Answer = append(m.Answer, rr...)
|
||||
}
|
||||
}
|
||||
m.MsgHdr.Authoritative = authoritative
|
||||
m.Authoritative = authoritative
|
||||
if authoritative {
|
||||
if m.MsgHdr.Rcode == dns.RcodeNameError {
|
||||
if m.Rcode == dns.RcodeNameError {
|
||||
m.Ns = append(m.Ns, n.SOA)
|
||||
}
|
||||
}
|
||||
|
||||
150
pkg/nameserver/handler_test.go
Normal file
150
pkg/nameserver/handler_test.go
Normal file
@ -0,0 +1,150 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestNameserver_isOwnChallenge(t *testing.T) {
|
||||
type fields struct {
|
||||
OwnDomain string
|
||||
}
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "is own challenge",
|
||||
fields: fields{
|
||||
OwnDomain: "some-domain.test.",
|
||||
},
|
||||
args: args{
|
||||
name: "_acme-challenge.some-domain.test",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "challenge but not for us",
|
||||
fields: fields{
|
||||
OwnDomain: "some-domain.test.",
|
||||
},
|
||||
args: args{
|
||||
name: "_acme-challenge.some-other-domain.test",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "not a challenge",
|
||||
fields: fields{
|
||||
OwnDomain: "domain.test.",
|
||||
},
|
||||
args: args{
|
||||
name: "domain.test",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "other request challenge",
|
||||
fields: fields{
|
||||
OwnDomain: "domain.test.",
|
||||
},
|
||||
args: args{
|
||||
name: "my-domain.test",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := &Nameserver{
|
||||
OwnDomain: tt.fields.OwnDomain,
|
||||
}
|
||||
if got := n.isOwnChallenge(tt.args.name); got != tt.want {
|
||||
t.Errorf("isOwnChallenge() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameserver_isAuthoritative(t *testing.T) {
|
||||
type fields struct {
|
||||
OwnDomain string
|
||||
Domains map[string]Records
|
||||
}
|
||||
type args struct {
|
||||
q dns.Question
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "is authoritative own domain",
|
||||
fields: fields{
|
||||
OwnDomain: "auth.domain.",
|
||||
},
|
||||
args: args{
|
||||
q: dns.Question{Name: "auth.domain."},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "is authoritative other domain",
|
||||
fields: fields{
|
||||
OwnDomain: "auth.domain.",
|
||||
Domains: map[string]Records{
|
||||
"other-domain.test.": {Records: nil},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
q: dns.Question{Name: "other-domain.test."},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "is authoritative sub domain",
|
||||
fields: fields{
|
||||
OwnDomain: "auth.domain.",
|
||||
Domains: map[string]Records{
|
||||
"other-domain.test.": {Records: nil},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
q: dns.Question{Name: "sub.auth.domain."},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "is not authoritative own",
|
||||
fields: fields{
|
||||
OwnDomain: "auth.domain.",
|
||||
Domains: map[string]Records{
|
||||
"other-domain.test.": {Records: nil},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
q: dns.Question{Name: "special-auth.domain."},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := &Nameserver{
|
||||
OwnDomain: tt.fields.OwnDomain,
|
||||
Domains: tt.fields.Domains,
|
||||
}
|
||||
if got := n.isAuthoritative(tt.args.q); got != tt.want {
|
||||
t.Errorf("isAuthoritative() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
65
pkg/nameserver/validation_test.go
Normal file
65
pkg/nameserver/validation_test.go
Normal file
@ -0,0 +1,65 @@
|
||||
package nameserver
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestNameserver_answerOwnChallenge(t *testing.T) {
|
||||
type fields struct {
|
||||
personalAuthKey string
|
||||
}
|
||||
type args struct {
|
||||
q dns.Question
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []dns.RR
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "answer own challenge",
|
||||
fields: fields{
|
||||
personalAuthKey: "some key text",
|
||||
},
|
||||
args: args{
|
||||
q: dns.Question{
|
||||
Name: "something",
|
||||
Qtype: 0,
|
||||
Qclass: 0,
|
||||
},
|
||||
},
|
||||
want: []dns.RR{
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: "something", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1},
|
||||
Txt: []string{"some key text"},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := &Nameserver{}
|
||||
|
||||
n.SetOwnAuthKey(tt.fields.personalAuthKey)
|
||||
if n.personalAuthKey != tt.fields.personalAuthKey {
|
||||
t.Errorf("failed to set personal auth key: got = %s, want %s", n.personalAuthKey, tt.fields.personalAuthKey)
|
||||
return
|
||||
}
|
||||
|
||||
got, err := n.answerOwnChallenge(tt.args.q)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("answerOwnChallenge() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("answerOwnChallenge() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user