From 4c437c05064e3898694696964de7b312cef8697a Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Mon, 28 Nov 2016 22:46:24 +0200 Subject: [PATCH] Added protocol selection to DNS server --- config.cfg | 2 ++ dns_test.go | 9 ++------- main.go | 2 +- main_test.go | 12 ++++++++++-- types.go | 1 + util.go | 13 +++---------- 6 files changed, 19 insertions(+), 20 deletions(-) diff --git a/config.cfg b/config.cfg index 5cf6580..b0d593f 100644 --- a/config.cfg +++ b/config.cfg @@ -1,6 +1,8 @@ [general] # dns interface listen = ":53" +# protocol, "udp", "udp4", "udp6" or "tcp", "tcp4", "tcp6" +protocol = "udp" # domain name to serve th requests off of domain = "auth.example.org" # zone name server diff --git a/dns_test.go b/dns_test.go index 9a573ee..19c4d3e 100644 --- a/dns_test.go +++ b/dns_test.go @@ -5,7 +5,6 @@ import ( "database/sql/driver" "errors" "fmt" - log "github.com/Sirupsen/logrus" "github.com/erikstmartin/go-testdb" "github.com/miekg/dns" "strings" @@ -107,13 +106,9 @@ func TestParse(t *testing.T) { Debug: false, } var testRR Records - loghook.Reset() testRR.Parse(testcfg) - if len(loghook.Entries) != 1 { - t.Errorf("Expected exactly one logged line, instead there was %d line(s)", len(loghook.Entries)) - } - if loghook.LastEntry().Level != log.ErrorLevel { - t.Error("Expected error level of ERROR from last message") + if !loggerHasEntryWithMessage("Error while adding SOA record") { + t.Errorf("Expected SOA parsing to return error, but did not find one") } } diff --git a/main.go b/main.go index 21b9f31..d114ad3 100644 --- a/main.go +++ b/main.go @@ -30,7 +30,7 @@ func main() { defer DB.Close() // DNS server - startDNS(DNSConf.General.Listen) + startDNS(DNSConf.General.Listen, DNSConf.General.Proto) // HTTP API startHTTPAPI() diff --git a/main_test.go b/main_test.go index d5288b3..980d0d1 100644 --- a/main_test.go +++ b/main_test.go @@ -42,8 +42,7 @@ func TestMain(m *testing.M) { _ = newDb.Init("sqlite3", ":memory:") } DB = newDb - - server := startDNS("0.0.0.0:15353") + server := startDNS("0.0.0.0:15353", "udp") exitval := m.Run() server.Shutdown() DB.Close() @@ -84,3 +83,12 @@ func setupTestLogger() { log.SetOutput(ioutil.Discard) log.AddHook(loghook) } + +func loggerHasEntryWithMessage(message string) bool { + for _, v := range loghook.Entries { + if v.Message == message { + return true + } + } + return false +} diff --git a/types.go b/types.go index f71baa2..d7239a2 100644 --- a/types.go +++ b/types.go @@ -35,6 +35,7 @@ type authMiddleware struct{} // Config file general section type general struct { Listen string + Proto string `toml:"protocol"` Domain string Nsname string Nsadmin string diff --git a/util.go b/util.go index 67999cf..7f51d06 100644 --- a/util.go +++ b/util.go @@ -7,7 +7,6 @@ import ( "github.com/miekg/dns" "github.com/satori/go.uuid" "math/big" - "os" "regexp" "strings" ) @@ -72,16 +71,10 @@ func setupLogging(format string, level string) { // TODO: file logging } -func startDNS(listen string) *dns.Server { +func startDNS(listen string, proto string) *dns.Server { // DNS server part dns.HandleFunc(".", handleRequest) - server := &dns.Server{Addr: listen, Net: "udp"} - go func() { - err := server.ListenAndServe() - if err != nil { - log.Errorf("%v", err) - os.Exit(1) - } - }() + server := &dns.Server{Addr: listen, Net: proto} + go server.ListenAndServe() return server }