Refactoring - improving coverage (#371)
* Increase code coverage in acmedns * More testing of ReadConfig() and its fallback mechanism * Found that if someone put a '"' double quote into the filename that we configure zap to log to, it would cause the the JSON created to be invalid. I have replaced the JSON string with proper config * Better handling of config options for api.TLS - we now error on an invalid value instead of silently failing. added a basic test for api.setupTLS() (to increase test coverage) * testing nameserver isOwnChallenge and isAuthoritative methods * add a unit test for nameserver answerOwnChallenge * fix linting errors * bump go and golangci-lint versions in github actions * Update golangci-lint.yml Bumping github-actions workflow versions to accommodate some changes in upstream golanci-lint * Bump Golang version to 1.23 (currently the oldest supported version) Bump golanglint-ci to 2.0.2 and migrate the config file. This should resolve the math/rand/v2 issue * bump golanglint-ci action version * Fixing up new golanglint-ci warnings and errors --------- Co-authored-by: Joona Hoikkala <5235109+joohoi@users.noreply.github.com>
This commit is contained in:
parent
d20fae37c9
commit
e0f9745182
10
.github/workflows/golangci-lint.yml
vendored
10
.github/workflows/golangci-lint.yml
vendored
@ -8,14 +8,14 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v3
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: 1.18.4
|
go-version: 1.23
|
||||||
|
|
||||||
- name: Check out code
|
- name: Check out code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Run golangci-lint
|
- name: Run golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v7
|
||||||
with:
|
with:
|
||||||
version: v1.48
|
version: v2.0.2
|
||||||
|
|||||||
@ -1,15 +1,30 @@
|
|||||||
|
version: "2"
|
||||||
|
linters:
|
||||||
|
exclusions:
|
||||||
|
generated: lax
|
||||||
|
presets:
|
||||||
|
- comments
|
||||||
|
- common-false-positives
|
||||||
|
- legacy
|
||||||
|
- std-error-handling
|
||||||
|
paths:
|
||||||
|
- third_party$
|
||||||
|
- builtin$
|
||||||
|
- examples$
|
||||||
issues:
|
issues:
|
||||||
max-issues-per-linter: 0
|
max-issues-per-linter: 0
|
||||||
max-same-issues: 0
|
max-same-issues: 0
|
||||||
linters:
|
formatters:
|
||||||
# Enable specific linter
|
|
||||||
# https://golangci-lint.run/usage/linters/#enabled-by-default
|
|
||||||
enable:
|
enable:
|
||||||
- gofmt
|
- gofmt
|
||||||
- goimports
|
- goimports
|
||||||
linters-settings:
|
settings:
|
||||||
goimports:
|
goimports:
|
||||||
# A comma-separated list of prefixes, which, if set, checks import paths
|
local-prefixes:
|
||||||
# with the given prefixes are grouped after 3rd-party packages.
|
- github.com/acme-dns/acme-dns
|
||||||
# Default: ""
|
exclusions:
|
||||||
local-prefixes: github.com/acme-dns/acme-dns
|
generated: lax
|
||||||
|
paths:
|
||||||
|
- third_party$
|
||||||
|
- builtin$
|
||||||
|
- examples$
|
||||||
|
|||||||
2
go.mod
2
go.mod
@ -1,6 +1,6 @@
|
|||||||
module github.com/joohoi/acme-dns
|
module github.com/joohoi/acme-dns
|
||||||
|
|
||||||
go 1.19
|
go 1.23
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/BurntSushi/toml v1.2.1
|
github.com/BurntSushi/toml v1.2.1
|
||||||
|
|||||||
2
main.go
2
main.go
@ -21,7 +21,7 @@ func main() {
|
|||||||
// Read global config
|
// Read global config
|
||||||
var err error
|
var err error
|
||||||
var logger *zap.Logger
|
var logger *zap.Logger
|
||||||
config, usedConfigFile, err := acmedns.ReadConfig(*configPtr)
|
config, usedConfigFile, err := acmedns.ReadConfig(*configPtr, "./config.cfg")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error: %s\n", err)
|
fmt.Printf("Error: %s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check if IP belongs to an allowed net
|
// AllowedFrom Check if IP belongs to an allowed net
|
||||||
func (a ACMETxt) AllowedFrom(ip string) bool {
|
func (a ACMETxt) AllowedFrom(ip string) bool {
|
||||||
remoteIP := net.ParseIP(ip)
|
remoteIP := net.ParseIP(ip)
|
||||||
// Range not limited
|
// Range not limited
|
||||||
@ -22,7 +22,7 @@ func (a ACMETxt) AllowedFrom(ip string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Go through list (most likely from headers) to check for the IP.
|
// AllowedFromList Go through list (most likely from headers) to check for the IP.
|
||||||
// Reason for this is that some setups use reverse proxy in front of acme-dns
|
// Reason for this is that some setups use reverse proxy in front of acme-dns
|
||||||
func (a ACMETxt) AllowedFromList(ips []string) bool {
|
func (a ACMETxt) AllowedFromList(ips []string) bool {
|
||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
|
|||||||
@ -8,6 +8,13 @@ import (
|
|||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ApiTlsProviderNone = "none"
|
||||||
|
ApiTlsProviderLetsEncrypt = "letsencrypt"
|
||||||
|
ApiTlsProviderLetsEncryptStaging = "letsencryptstaging"
|
||||||
|
ApiTlsProviderCert = "cert"
|
||||||
|
)
|
||||||
|
|
||||||
func FileIsAccessible(fname string) bool {
|
func FileIsAccessible(fname string) bool {
|
||||||
_, err := os.Stat(fname)
|
_, err := os.Stat(fname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -45,24 +52,31 @@ func prepareConfig(conf AcmeDnsConfig) (AcmeDnsConfig, error) {
|
|||||||
conf.API.ACMECacheDir = "api-certs"
|
conf.API.ACMECacheDir = "api-certs"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch conf.API.TLS {
|
||||||
|
case ApiTlsProviderCert, ApiTlsProviderLetsEncrypt, ApiTlsProviderLetsEncryptStaging, ApiTlsProviderNone:
|
||||||
|
// we have a good value
|
||||||
|
default:
|
||||||
|
return conf, fmt.Errorf("invalid value for api.tls, expected one of [%s, %s, %s, %s]", ApiTlsProviderCert, ApiTlsProviderLetsEncrypt, ApiTlsProviderLetsEncryptStaging, ApiTlsProviderNone)
|
||||||
|
}
|
||||||
|
|
||||||
return conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadConfig(configFile string) (AcmeDnsConfig, string, error) {
|
func ReadConfig(configFile, fallback string) (AcmeDnsConfig, string, error) {
|
||||||
var usedConfigFile string
|
var usedConfigFile string
|
||||||
var config AcmeDnsConfig
|
var config AcmeDnsConfig
|
||||||
var err error
|
var err error
|
||||||
if FileIsAccessible(configFile) {
|
if FileIsAccessible(configFile) {
|
||||||
usedConfigFile = configFile
|
usedConfigFile = configFile
|
||||||
config, err = readTomlConfig(configFile)
|
config, err = readTomlConfig(configFile)
|
||||||
} else if FileIsAccessible("./config.cfg") {
|
} else if FileIsAccessible(fallback) {
|
||||||
usedConfigFile = "./config.cfg"
|
usedConfigFile = fallback
|
||||||
config, err = readTomlConfig("./config.cfg")
|
config, err = readTomlConfig(fallback)
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("configuration file not found")
|
err = fmt.Errorf("configuration file not found")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("encountered an error while trying to read configuration file: %s\n", err)
|
err = fmt.Errorf("encountered an error while trying to read configuration file: %w", err)
|
||||||
}
|
}
|
||||||
return config, usedConfigFile, err
|
return config, usedConfigFile, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,14 +1,18 @@
|
|||||||
package acmedns
|
package acmedns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"go.uber.org/zap/zapcore"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
|
func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
|
||||||
var logger *zap.Logger
|
var (
|
||||||
|
logger *zap.Logger
|
||||||
|
zapCfg zap.Config
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
logformat := "console"
|
logformat := "console"
|
||||||
if config.Logconfig.Format == "json" {
|
if config.Logconfig.Format == "json" {
|
||||||
logformat = "json"
|
logformat = "json"
|
||||||
@ -21,23 +25,22 @@ func SetupLogging(config AcmeDnsConfig) (*zap.Logger, error) {
|
|||||||
if config.Logconfig.Logtype == "file" {
|
if config.Logconfig.Logtype == "file" {
|
||||||
errorPath = config.Logconfig.File
|
errorPath = config.Logconfig.File
|
||||||
}
|
}
|
||||||
zapConfigJson := fmt.Sprintf(`{
|
|
||||||
"level": "%s",
|
zapCfg.Level, err = zap.ParseAtomicLevel(config.Logconfig.Level)
|
||||||
"encoding": "%s",
|
if err != nil {
|
||||||
"outputPaths": ["%s"],
|
|
||||||
"errorOutputPaths": ["%s"],
|
|
||||||
"encoderConfig": {
|
|
||||||
"timeKey": "time",
|
|
||||||
"messageKey": "msg",
|
|
||||||
"levelKey": "level",
|
|
||||||
"levelEncoder": "lowercase",
|
|
||||||
"timeEncoder": "iso8601"
|
|
||||||
}
|
|
||||||
}`, config.Logconfig.Level, logformat, outputPath, errorPath)
|
|
||||||
var zapCfg zap.Config
|
|
||||||
if err := json.Unmarshal([]byte(zapConfigJson), &zapCfg); err != nil {
|
|
||||||
return logger, err
|
return logger, err
|
||||||
}
|
}
|
||||||
logger, err := zapCfg.Build()
|
zapCfg.Encoding = logformat
|
||||||
|
zapCfg.OutputPaths = []string{outputPath}
|
||||||
|
zapCfg.ErrorOutputPaths = []string{errorPath}
|
||||||
|
zapCfg.EncoderConfig = zapcore.EncoderConfig{
|
||||||
|
TimeKey: "time",
|
||||||
|
MessageKey: "msg",
|
||||||
|
LevelKey: "level",
|
||||||
|
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||||
|
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger, err = zapCfg.Build()
|
||||||
return logger, err
|
return logger, err
|
||||||
}
|
}
|
||||||
|
|||||||
36
pkg/acmedns/testdata/test_read_fallback_config.toml
vendored
Normal file
36
pkg/acmedns/testdata/test_read_fallback_config.toml
vendored
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
[general]
|
||||||
|
listen = "127.0.0.1:53"
|
||||||
|
protocol = "both"
|
||||||
|
domain = "test.example.org"
|
||||||
|
nsname = "test.example.org"
|
||||||
|
nsadmin = "test.example.org"
|
||||||
|
records = [
|
||||||
|
"test.example.org. A 127.0.0.1",
|
||||||
|
"test.example.org. NS test.example.org.",
|
||||||
|
]
|
||||||
|
debug = true
|
||||||
|
|
||||||
|
[database]
|
||||||
|
engine = "dinosaur"
|
||||||
|
connection = "roar"
|
||||||
|
|
||||||
|
[api]
|
||||||
|
ip = "0.0.0.0"
|
||||||
|
disable_registration = false
|
||||||
|
port = "443"
|
||||||
|
tls = "none"
|
||||||
|
tls_cert_privkey = "/etc/tls/example.org/privkey.pem"
|
||||||
|
tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem"
|
||||||
|
acme_cache_dir = "api-certs"
|
||||||
|
notification_email = ""
|
||||||
|
corsorigins = [
|
||||||
|
"*"
|
||||||
|
]
|
||||||
|
use_header = true
|
||||||
|
header_name = "X-is-gonna-give-it-to-ya"
|
||||||
|
|
||||||
|
[logconfig]
|
||||||
|
loglevel = "info"
|
||||||
|
logtype = "stdout"
|
||||||
|
logfile = "./acme-dns.log"
|
||||||
|
logformat = "json"
|
||||||
@ -8,7 +8,7 @@ type Account struct {
|
|||||||
Subdomain string
|
Subdomain string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DNSConfig holds the config structure
|
// AcmeDnsConfig holds the config structure
|
||||||
type AcmeDnsConfig struct {
|
type AcmeDnsConfig struct {
|
||||||
General general
|
General general
|
||||||
Database dbsettings
|
Database dbsettings
|
||||||
|
|||||||
@ -1,7 +1,10 @@
|
|||||||
package acmedns
|
package acmedns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -46,19 +49,33 @@ func TestSetupLoggingError(t *testing.T) {
|
|||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
format string
|
format string
|
||||||
level string
|
level string
|
||||||
|
file string
|
||||||
errexpected bool
|
errexpected bool
|
||||||
}{
|
}{
|
||||||
{"text", "warn", false},
|
{"text", "warn", "", false},
|
||||||
{"json", "debug", false},
|
{"json", "debug", "", false},
|
||||||
{"text", "info", false},
|
{"text", "info", "", false},
|
||||||
{"json", "error", false},
|
{"json", "error", "", false},
|
||||||
{"text", "something", true},
|
{"text", "something", "", true},
|
||||||
|
{"text", "info", "a path with\" in its name.txt", false},
|
||||||
} {
|
} {
|
||||||
conf.Logconfig.Format = test.format
|
conf.Logconfig.Format = test.format
|
||||||
conf.Logconfig.Level = test.level
|
conf.Logconfig.Level = test.level
|
||||||
|
if test.file != "" {
|
||||||
|
conf.Logconfig.File = test.file
|
||||||
|
conf.Logconfig.Logtype = "file"
|
||||||
|
|
||||||
|
}
|
||||||
_, err := SetupLogging(conf)
|
_, err := SetupLogging(conf)
|
||||||
if test.errexpected && err == nil {
|
if test.errexpected && err == nil {
|
||||||
t.Errorf("Expected error but did not get one for loglevel: %s", err)
|
t.Errorf("Expected error but did not get one for loglevel: %s", err)
|
||||||
|
} else if !test.errexpected && err != nil {
|
||||||
|
t.Errorf("Unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// clean up the file zap creates
|
||||||
|
if test.file != "" {
|
||||||
|
_ = os.Remove(test.file)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -88,7 +105,7 @@ func TestReadConfig(t *testing.T) {
|
|||||||
} {
|
} {
|
||||||
tmpfile, err := os.CreateTemp("", "acmedns")
|
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Could not create temporary file")
|
t.Fatalf("Could not create temporary file: %s", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmpfile.Name())
|
defer os.Remove(tmpfile.Name())
|
||||||
|
|
||||||
@ -99,7 +116,7 @@ func TestReadConfig(t *testing.T) {
|
|||||||
if err := tmpfile.Close(); err != nil {
|
if err := tmpfile.Close(); err != nil {
|
||||||
t.Error("Could not close temporary file")
|
t.Error("Could not close temporary file")
|
||||||
}
|
}
|
||||||
ret, _, _ := ReadConfig(tmpfile.Name())
|
ret, _, _ := ReadConfig(tmpfile.Name(), "")
|
||||||
if ret.General.Listen != test.output.General.Listen {
|
if ret.General.Listen != test.output.General.Listen {
|
||||||
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
||||||
}
|
}
|
||||||
@ -109,10 +126,112 @@ func TestReadConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadConfigFallback(t *testing.T) {
|
||||||
|
var (
|
||||||
|
path string
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
testPath := "testdata/test_read_fallback_config.toml"
|
||||||
|
|
||||||
|
path, err = getNonExistentPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed getting non existant path: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, used, err := ReadConfig(path, testPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read a config file when we should have: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if used != testPath {
|
||||||
|
t.Fatalf("we read from the wrong file. got: %s, want: %s", used, testPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := AcmeDnsConfig{
|
||||||
|
General: general{
|
||||||
|
Listen: "127.0.0.1:53",
|
||||||
|
Proto: "both",
|
||||||
|
Domain: "test.example.org",
|
||||||
|
Nsname: "test.example.org",
|
||||||
|
Nsadmin: "test.example.org",
|
||||||
|
Debug: true,
|
||||||
|
StaticRecords: []string{
|
||||||
|
"test.example.org. A 127.0.0.1",
|
||||||
|
"test.example.org. NS test.example.org.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Database: dbsettings{
|
||||||
|
Engine: "dinosaur",
|
||||||
|
Connection: "roar",
|
||||||
|
},
|
||||||
|
API: httpapi{
|
||||||
|
Domain: "",
|
||||||
|
IP: "0.0.0.0",
|
||||||
|
DisableRegistration: false,
|
||||||
|
AutocertPort: "",
|
||||||
|
Port: "443",
|
||||||
|
TLS: "none",
|
||||||
|
TLSCertPrivkey: "/etc/tls/example.org/privkey.pem",
|
||||||
|
TLSCertFullchain: "/etc/tls/example.org/fullchain.pem",
|
||||||
|
ACMECacheDir: "api-certs",
|
||||||
|
NotificationEmail: "",
|
||||||
|
CorsOrigins: []string{"*"},
|
||||||
|
UseHeader: true,
|
||||||
|
HeaderName: "X-is-gonna-give-it-to-ya",
|
||||||
|
},
|
||||||
|
Logconfig: logconfig{
|
||||||
|
Level: "info",
|
||||||
|
Logtype: "stdout",
|
||||||
|
File: "./acme-dns.log",
|
||||||
|
Format: "json",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(cfg, expected) {
|
||||||
|
t.Errorf("Did not read the config correctly: got %+v, want: %+v", cfg, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNonExistentPath() (string, error) {
|
||||||
|
path := fmt.Sprintf("/some/path/that/should/not/exist/on/any/filesystem/%10d.cfg", rand.Int())
|
||||||
|
|
||||||
|
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("attempted non existant file exists!?: %s", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestReadConfigFallbackError makes sure we error when we do not have a fallback config file
|
||||||
|
func TestReadConfigFallbackError(t *testing.T) {
|
||||||
|
var (
|
||||||
|
badPaths []string
|
||||||
|
i int
|
||||||
|
)
|
||||||
|
for len(badPaths) < 2 && i < 10 {
|
||||||
|
i++
|
||||||
|
|
||||||
|
if path, err := getNonExistentPath(); err == nil {
|
||||||
|
badPaths = append(badPaths, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(badPaths) != 2 {
|
||||||
|
t.Fatalf("did not create exactly 2 bad paths")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := ReadConfig(badPaths[0], badPaths[1])
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Should have failed reading non existant file: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFileCheckPermissionDenied(t *testing.T) {
|
func TestFileCheckPermissionDenied(t *testing.T) {
|
||||||
tmpfile, err := os.CreateTemp("", "acmedns")
|
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Could not create temporary file")
|
t.Fatalf("Could not create temporary file: %s", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmpfile.Name())
|
defer os.Remove(tmpfile.Name())
|
||||||
_ = syscall.Chmod(tmpfile.Name(), 0000)
|
_ = syscall.Chmod(tmpfile.Name(), 0000)
|
||||||
@ -131,7 +250,7 @@ func TestFileCheckNotExists(t *testing.T) {
|
|||||||
func TestFileCheckOK(t *testing.T) {
|
func TestFileCheckOK(t *testing.T) {
|
||||||
tmpfile, err := os.CreateTemp("", "acmedns")
|
tmpfile, err := os.CreateTemp("", "acmedns")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Could not create temporary file")
|
t.Fatalf("Could not create temporary file: %s", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tmpfile.Name())
|
defer os.Remove(tmpfile.Name())
|
||||||
if !FileIsAccessible(tmpfile.Name()) {
|
if !FileIsAccessible(tmpfile.Name()) {
|
||||||
@ -144,9 +263,20 @@ func TestPrepareConfig(t *testing.T) {
|
|||||||
input AcmeDnsConfig
|
input AcmeDnsConfig
|
||||||
shoulderror bool
|
shoulderror bool
|
||||||
}{
|
}{
|
||||||
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: "whatever_too"}}, false},
|
{AcmeDnsConfig{
|
||||||
{AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"}}, true},
|
Database: dbsettings{Engine: "whatever", Connection: "whatever_too"},
|
||||||
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""}}, true},
|
API: httpapi{TLS: ApiTlsProviderNone},
|
||||||
|
}, false},
|
||||||
|
{AcmeDnsConfig{Database: dbsettings{Engine: "", Connection: "whatever_too"},
|
||||||
|
API: httpapi{TLS: ApiTlsProviderNone},
|
||||||
|
}, true},
|
||||||
|
{AcmeDnsConfig{Database: dbsettings{Engine: "whatever", Connection: ""},
|
||||||
|
API: httpapi{TLS: ApiTlsProviderNone},
|
||||||
|
}, true},
|
||||||
|
{AcmeDnsConfig{
|
||||||
|
Database: dbsettings{Engine: "whatever", Connection: "whatever_too"},
|
||||||
|
API: httpapi{TLS: "whatever"},
|
||||||
|
}, true},
|
||||||
} {
|
} {
|
||||||
_, err := prepareConfig(test.input)
|
_, err := prepareConfig(test.input)
|
||||||
if test.shoulderror {
|
if test.shoulderror {
|
||||||
|
|||||||
@ -33,7 +33,6 @@ func (a *AcmednsAPI) Start(dnsservers []acmedns.AcmednsNS) {
|
|||||||
a.errChan <- err
|
a.errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//legolog.Logger = stderrorlog
|
|
||||||
api := httprouter.New()
|
api := httprouter.New()
|
||||||
c := cors.New(cors.Options{
|
c := cors.New(cors.Options{
|
||||||
AllowedOrigins: a.Config.API.CorsOrigins,
|
AllowedOrigins: a.Config.API.CorsOrigins,
|
||||||
@ -59,26 +58,7 @@ func (a *AcmednsAPI) Start(dnsservers []acmedns.AcmednsNS) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch a.Config.API.TLS {
|
switch a.Config.API.TLS {
|
||||||
case "letsencryptstaging":
|
case acmedns.ApiTlsProviderLetsEncrypt, acmedns.ApiTlsProviderLetsEncryptStaging:
|
||||||
magic := a.setupTLS(dnsservers)
|
|
||||||
err = magic.ManageAsync(context.Background(), []string{a.Config.General.Domain})
|
|
||||||
if err != nil {
|
|
||||||
a.errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cfg.GetCertificate = magic.GetCertificate
|
|
||||||
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: host,
|
|
||||||
Handler: c.Handler(api),
|
|
||||||
TLSConfig: cfg,
|
|
||||||
ErrorLog: stderrorlog,
|
|
||||||
}
|
|
||||||
a.Logger.Infow("Listening HTTPS",
|
|
||||||
"host", host,
|
|
||||||
"domain", a.Config.General.Domain)
|
|
||||||
err = srv.ListenAndServeTLS("", "")
|
|
||||||
case "letsencrypt":
|
|
||||||
magic := a.setupTLS(dnsservers)
|
magic := a.setupTLS(dnsservers)
|
||||||
err = magic.ManageAsync(context.Background(), []string{a.Config.General.Domain})
|
err = magic.ManageAsync(context.Background(), []string{a.Config.General.Domain})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -96,7 +76,7 @@ func (a *AcmednsAPI) Start(dnsservers []acmedns.AcmednsNS) {
|
|||||||
"host", host,
|
"host", host,
|
||||||
"domain", a.Config.General.Domain)
|
"domain", a.Config.General.Domain)
|
||||||
err = srv.ListenAndServeTLS("", "")
|
err = srv.ListenAndServeTLS("", "")
|
||||||
case "cert":
|
case acmedns.ApiTlsProviderCert:
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: host,
|
Addr: host,
|
||||||
Handler: c.Handler(api),
|
Handler: c.Handler(api),
|
||||||
@ -126,7 +106,7 @@ func (a *AcmednsAPI) setupTLS(dnsservers []acmedns.AcmednsNS) *certmagic.Config
|
|||||||
certmagic.DefaultACME.DNS01Solver = &provider
|
certmagic.DefaultACME.DNS01Solver = &provider
|
||||||
certmagic.DefaultACME.Agreed = true
|
certmagic.DefaultACME.Agreed = true
|
||||||
certmagic.DefaultACME.Logger = a.Logger.Desugar()
|
certmagic.DefaultACME.Logger = a.Logger.Desugar()
|
||||||
if a.Config.API.TLS == "letsencrypt" {
|
if a.Config.API.TLS == acmedns.ApiTlsProviderLetsEncrypt {
|
||||||
certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA
|
certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA
|
||||||
} else {
|
} else {
|
||||||
certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA
|
certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA
|
||||||
|
|||||||
@ -10,8 +10,10 @@ import (
|
|||||||
|
|
||||||
"github.com/joohoi/acme-dns/pkg/acmedns"
|
"github.com/joohoi/acme-dns/pkg/acmedns"
|
||||||
"github.com/joohoi/acme-dns/pkg/database"
|
"github.com/joohoi/acme-dns/pkg/database"
|
||||||
|
"github.com/joohoi/acme-dns/pkg/nameserver"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/caddyserver/certmagic"
|
||||||
"github.com/gavv/httpexpect"
|
"github.com/gavv/httpexpect"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
@ -63,7 +65,7 @@ func setupRouter(debug bool, noauth bool) (http.Handler, AcmednsAPI, acmedns.Acm
|
|||||||
config, logger := fakeConfigAndLogger()
|
config, logger := fakeConfigAndLogger()
|
||||||
config.API.Domain = ""
|
config.API.Domain = ""
|
||||||
config.API.Port = "8080"
|
config.API.Port = "8080"
|
||||||
config.API.TLS = "none"
|
config.API.TLS = acmedns.ApiTlsProviderNone
|
||||||
config.API.CorsOrigins = []string{"*"}
|
config.API.CorsOrigins = []string{"*"}
|
||||||
config.API.UseHeader = true
|
config.API.UseHeader = true
|
||||||
config.API.HeaderName = "X-Forwarded-For"
|
config.API.HeaderName = "X-Forwarded-For"
|
||||||
@ -498,3 +500,34 @@ func TestUpdateAllowedFromIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetupTLS(t *testing.T) {
|
||||||
|
_, svr, _ := setupRouter(false, false)
|
||||||
|
|
||||||
|
for _, test := range []struct {
|
||||||
|
apiTls string
|
||||||
|
expectedCA string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
apiTls: acmedns.ApiTlsProviderLetsEncrypt,
|
||||||
|
expectedCA: certmagic.LetsEncryptProductionCA,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
apiTls: acmedns.ApiTlsProviderLetsEncryptStaging,
|
||||||
|
expectedCA: certmagic.LetsEncryptStagingCA,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
svr.Config.API.TLS = test.apiTls
|
||||||
|
ns := &nameserver.Nameserver{}
|
||||||
|
magic := svr.setupTLS([]acmedns.AcmednsNS{ns})
|
||||||
|
|
||||||
|
if test.expectedCA != certmagic.DefaultACME.CA {
|
||||||
|
t.Errorf("failed to configure default ACME CA. got %s, want %s", certmagic.DefaultACME.CA, test.expectedCA)
|
||||||
|
}
|
||||||
|
|
||||||
|
if magic.DefaultServerName != svr.Config.General.Domain {
|
||||||
|
t.Errorf("failed to set the correct doman. got: %s, want %s", magic.DefaultServerName, svr.Config.General.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@ -235,14 +235,14 @@ func TestAuthoritative(t *testing.T) {
|
|||||||
if answer.Ns[0].Header().Rrtype != dns.TypeSOA {
|
if answer.Ns[0].Header().Rrtype != dns.TypeSOA {
|
||||||
t.Errorf("Was expecting SOA record as answer for NXDOMAIN but got [%s]", dns.TypeToString[answer.Ns[0].Header().Rrtype])
|
t.Errorf("Was expecting SOA record as answer for NXDOMAIN but got [%s]", dns.TypeToString[answer.Ns[0].Header().Rrtype])
|
||||||
}
|
}
|
||||||
if !answer.MsgHdr.Authoritative {
|
if !answer.Authoritative {
|
||||||
t.Errorf("Was expecting authoritative bit to be set")
|
t.Errorf("Was expecting authoritative bit to be set")
|
||||||
}
|
}
|
||||||
nanswer, _ := resolv.lookup("nonexsitent.nonauth.tld", dns.TypeA)
|
nanswer, _ := resolv.lookup("nonexsitent.nonauth.tld", dns.TypeA)
|
||||||
if len(nanswer.Answer) > 0 {
|
if len(nanswer.Answer) > 0 {
|
||||||
t.Errorf("Didn't expect answers for non authotitative domain query")
|
t.Errorf("Didn't expect answers for non authotitative domain query")
|
||||||
}
|
}
|
||||||
if nanswer.MsgHdr.Authoritative {
|
if nanswer.Authoritative {
|
||||||
t.Errorf("Authoritative bit should not be set for non-authoritative domain.")
|
t.Errorf("Authoritative bit should not be set for non-authoritative domain.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,7 +15,7 @@ func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if opt != nil {
|
if opt != nil {
|
||||||
if opt.Version() != 0 {
|
if opt.Version() != 0 {
|
||||||
// Only EDNS0 is standardized
|
// Only EDNS0 is standardized
|
||||||
m.MsgHdr.Rcode = dns.RcodeBadVers
|
m.Rcode = dns.RcodeBadVers
|
||||||
m.SetEdns0(512, false)
|
m.SetEdns0(512, false)
|
||||||
} else {
|
} else {
|
||||||
// We can safely do this as we know that we're not setting other OPT RRs within acme-dns.
|
// We can safely do this as we know that we're not setting other OPT RRs within acme-dns.
|
||||||
@ -39,13 +39,13 @@ func (n *Nameserver) readQuery(m *dns.Msg) {
|
|||||||
if auth {
|
if auth {
|
||||||
authoritative = auth
|
authoritative = auth
|
||||||
}
|
}
|
||||||
m.MsgHdr.Rcode = rc
|
m.Rcode = rc
|
||||||
m.Answer = append(m.Answer, rr...)
|
m.Answer = append(m.Answer, rr...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.MsgHdr.Authoritative = authoritative
|
m.Authoritative = authoritative
|
||||||
if authoritative {
|
if authoritative {
|
||||||
if m.MsgHdr.Rcode == dns.RcodeNameError {
|
if m.Rcode == dns.RcodeNameError {
|
||||||
m.Ns = append(m.Ns, n.SOA)
|
m.Ns = append(m.Ns, n.SOA)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
150
pkg/nameserver/handler_test.go
Normal file
150
pkg/nameserver/handler_test.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package nameserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNameserver_isOwnChallenge(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
OwnDomain string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "is own challenge",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "some-domain.test.",
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
name: "_acme-challenge.some-domain.test",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "challenge but not for us",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "some-domain.test.",
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
name: "_acme-challenge.some-other-domain.test",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not a challenge",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "domain.test.",
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
name: "domain.test",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other request challenge",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "domain.test.",
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
name: "my-domain.test",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
n := &Nameserver{
|
||||||
|
OwnDomain: tt.fields.OwnDomain,
|
||||||
|
}
|
||||||
|
if got := n.isOwnChallenge(tt.args.name); got != tt.want {
|
||||||
|
t.Errorf("isOwnChallenge() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNameserver_isAuthoritative(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
OwnDomain string
|
||||||
|
Domains map[string]Records
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
q dns.Question
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "is authoritative own domain",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "auth.domain.",
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
q: dns.Question{Name: "auth.domain."},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is authoritative other domain",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "auth.domain.",
|
||||||
|
Domains: map[string]Records{
|
||||||
|
"other-domain.test.": {Records: nil},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
q: dns.Question{Name: "other-domain.test."},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is authoritative sub domain",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "auth.domain.",
|
||||||
|
Domains: map[string]Records{
|
||||||
|
"other-domain.test.": {Records: nil},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
q: dns.Question{Name: "sub.auth.domain."},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is not authoritative own",
|
||||||
|
fields: fields{
|
||||||
|
OwnDomain: "auth.domain.",
|
||||||
|
Domains: map[string]Records{
|
||||||
|
"other-domain.test.": {Records: nil},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
q: dns.Question{Name: "special-auth.domain."},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
n := &Nameserver{
|
||||||
|
OwnDomain: tt.fields.OwnDomain,
|
||||||
|
Domains: tt.fields.Domains,
|
||||||
|
}
|
||||||
|
if got := n.isAuthoritative(tt.args.q); got != tt.want {
|
||||||
|
t.Errorf("isAuthoritative() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
65
pkg/nameserver/validation_test.go
Normal file
65
pkg/nameserver/validation_test.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package nameserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNameserver_answerOwnChallenge(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
personalAuthKey string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
q dns.Question
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want []dns.RR
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "answer own challenge",
|
||||||
|
fields: fields{
|
||||||
|
personalAuthKey: "some key text",
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
q: dns.Question{
|
||||||
|
Name: "something",
|
||||||
|
Qtype: 0,
|
||||||
|
Qclass: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []dns.RR{
|
||||||
|
&dns.TXT{
|
||||||
|
Hdr: dns.RR_Header{Name: "something", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1},
|
||||||
|
Txt: []string{"some key text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
n := &Nameserver{}
|
||||||
|
|
||||||
|
n.SetOwnAuthKey(tt.fields.personalAuthKey)
|
||||||
|
if n.personalAuthKey != tt.fields.personalAuthKey {
|
||||||
|
t.Errorf("failed to set personal auth key: got = %s, want %s", n.personalAuthKey, tt.fields.personalAuthKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := n.answerOwnChallenge(tt.args.q)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("answerOwnChallenge() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("answerOwnChallenge() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user