From f71b1772c6089fdd0a72c872ed914f3531e3d483 Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Sat, 26 Nov 2016 10:02:32 +0200 Subject: [PATCH] DNS tests continued --- dns_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 106 insertions(+), 17 deletions(-) diff --git a/dns_test.go b/dns_test.go index 2119ea4..508f708 100644 --- a/dns_test.go +++ b/dns_test.go @@ -2,6 +2,7 @@ package main import ( "errors" + "flag" "fmt" "github.com/miekg/dns" "github.com/op/go-logging" @@ -22,31 +23,45 @@ type resolver struct { server string } -func (r *resolver) lookup(host string, qtype uint16) (string, error) { +func (r *resolver) lookup(host string, qtype uint16) ([]dns.RR, error) { msg := new(dns.Msg) msg.Id = dns.Id() msg.Question = make([]dns.Question, 1) - msg.Question[0] = dns.Question{dns.Fqdn(host), qtype, dns.ClassINET} + msg.Question[0] = dns.Question{Name: dns.Fqdn(host), Qtype: qtype, Qclass: dns.ClassINET} in, err := dns.Exchange(msg, r.server) if err != nil { - return "", errors.New(fmt.Sprintf("Error querying the server [%v]", err)) + return []dns.RR{}, fmt.Errorf("Error querying the server [%v]", err) } if in != nil && in.Rcode != dns.RcodeSuccess { - return "", errors.New(fmt.Sprintf("Recieved error from the server [%s]", dns.RcodeToString[in.Rcode])) + return []dns.RR{}, fmt.Errorf("Recieved error from the server [%s]", dns.RcodeToString[in.Rcode]) } - if len(in.Answer) > 0 { - return in.Answer[0].String(), nil - } - return "", errors.New("No answer") + return in.Answer, nil } -func findRecord(rrstr string, host string, qtype uint16) error { +func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error { + for _, record := range answer { + // We expect only one answer, so no need to loop through the answer slice + if rec, ok := record.(*dns.TXT); ok { + for _, txtValue := range rec.Txt { + if txtValue == cmpTXT { + return nil + } + } + } else { + errmsg := fmt.Sprintf("Got answer of unexpected type [%q]", answer[0]) + return errors.New(errmsg) + } + } + return errors.New("Expected answer not found") +} + +func findRecordFromMemory(rrstr string, host string, qtype uint16) error { var errmsg = "No record found" arr, _ := dns.NewRR(strings.ToLower(rrstr)) - if arr_qt, ok := RR.Records[qtype]; ok { - if arr_hst, ok := arr_qt[host]; ok { - for _, v := range arr_hst { + if arrQt, ok := RR.Records[qtype]; ok { + if arrHst, ok := arrQt[host]; ok { + for _, v := range arrHst { if arr.String() == v.String() { return nil } @@ -61,6 +76,26 @@ func findRecord(rrstr string, host string, qtype uint16) error { } func startDNSServer(addr string) (*dns.Server, resolver) { + + var dbcfg = dbsettings{ + Engine: "sqlite3", + Connection: ":memory:", + } + + var generalcfg = general{ + Domain: "auth.example.org", + Nsname: "ns1.auth.example.org", + Nsadmin: "admin.example.org", + Debug: false, + } + + var dnscfg = DNSConfig{ + Database: dbcfg, + General: generalcfg, + } + + DNSConf = dnscfg + logging.InitForTesting(logging.DEBUG) // DNS server part dns.HandleFunc(".", handleRequest) @@ -77,14 +112,68 @@ func startDNSServer(addr string) (*dns.Server, resolver) { func TestResolveA(t *testing.T) { server, resolver := startDNSServer(testAddr) + defer server.Shutdown() RR.Parse(records) - a, err := resolver.lookup("auth.example.org", dns.TypeA) + answer, err := resolver.lookup("auth.example.org", dns.TypeA) if err != nil { t.Errorf("%v", err) } - err = findRecord(a, "auth.example.org.", dns.TypeA) - if err != nil { - t.Errorf("Answer [%s] did not match the expected, got error: [%s], debug: [%q]", a, err, RR.Records) + + if len(answer) > 0 { + err = findRecordFromMemory(answer[0].String(), "auth.example.org.", dns.TypeA) + if err != nil { + t.Errorf("Answer [%s] did not match the expected, got error: [%s], debug: [%q]", answer[0].String(), err, RR.Records) + } + + } else { + t.Error("No answer for DNS query") + } +} + +func TestResolveTXT(t *testing.T) { + flag.Parse() + if *postgres { + DNSConf.Database.Engine = "postgres" + err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") + if err != nil { + t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") + return + } + } else { + DNSConf.Database.Engine = "sqlite3" + _ = DB.Init("sqlite3", ":memory:") + } + defer DB.DB.Close() + + server, resolver := startDNSServer(testAddr) + defer server.Shutdown() + RR.Parse(records) + + validTXT := "______________valid_response_______________" + + atxt, err := DB.Register() + if err != nil { + t.Errorf("Could not initiate db record: [%v]", err) + return + } + atxt.Value = validTXT + err = DB.Update(atxt) + if err != nil { + t.Errorf("Could not update db record: [%v]", err) + return + } + answer, err := resolver.lookup(atxt.Subdomain+".auth.example.org", dns.TypeTXT) + if err != nil { + t.Errorf("%v", err) + return + } + + if len(answer) > 0 { + err = hasExpectedTXTAnswer(answer, validTXT) + if err != nil { + t.Errorf("%v", err) + } + } else { + t.Error("No answer for DNS query") } - server.Shutdown() }