diff --git a/config.cfg b/config.cfg index a180021..f992964 100644 --- a/config.cfg +++ b/config.cfg @@ -18,6 +18,12 @@ records = [ # specify that auth.example.org will resolve any *.auth.example.org records "auth.example.org. NS auth.example.org.", ] +# path to cache SOA serial +serialpath = "soa-serial.save" +# slaves to notify on update and allowed to request AXFR +slaves = [ + # "10.5.1.1" +] # debug messages from CORS etc debug = false diff --git a/pkg/acmedns/interfaces.go b/pkg/acmedns/interfaces.go index bd38b49..7fa9880 100644 --- a/pkg/acmedns/interfaces.go +++ b/pkg/acmedns/interfaces.go @@ -10,6 +10,7 @@ type AcmednsDB interface { Register(cidrslice Cidrslice) (ACMETxt, error) GetByUsername(uuid.UUID) (ACMETxt, error) GetTXTForDomain(string) ([]string, error) + GetTXTForAllDomains() ([]TXTRecord, error) Update(ACMETxtPost) error GetBackend() *sql.DB SetBackend(*sql.DB) @@ -21,4 +22,10 @@ type AcmednsNS interface { SetOwnAuthKey(key string) SetNotifyStartedFunc(func()) ParseRecords() + BumpSerial() error +} + +type TXTRecord struct { + Subdomain string + Value string } diff --git a/pkg/acmedns/types.go b/pkg/acmedns/types.go index 53f73a3..d2c08f0 100644 --- a/pkg/acmedns/types.go +++ b/pkg/acmedns/types.go @@ -25,6 +25,8 @@ type general struct { Nsadmin string Debug bool StaticRecords []string `toml:"records"` + Serialpath string + SlaveHosts []string `toml:"slaves"` } type dbsettings struct { diff --git a/pkg/api/api.go b/pkg/api/api.go index be306a6..8aa157d 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -15,10 +15,11 @@ import ( ) type AcmednsAPI struct { - Config *acmedns.AcmeDnsConfig - DB acmedns.AcmednsDB - Logger *zap.SugaredLogger - errChan chan error + Config *acmedns.AcmeDnsConfig + DB acmedns.AcmednsDB + Logger *zap.SugaredLogger + errChan chan error + dnsServers []acmedns.AcmednsNS } func Init(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.SugaredLogger, errChan chan error) AcmednsAPI { @@ -27,6 +28,8 @@ func Init(config *acmedns.AcmeDnsConfig, db acmedns.AcmednsDB, logger *zap.Sugar } func (a *AcmednsAPI) Start(dnsservers []acmedns.AcmednsNS) { + //we need the dnsservers later to bump serial + a.dnsServers = dnsservers var err error //TODO: do we want to debug log the HTTP server? stderrorlog, err := zap.NewStdLogAt(a.Logger.Desugar(), zap.ErrorLevel) diff --git a/pkg/api/update.go b/pkg/api/update.go index 55c6152..caba19b 100644 --- a/pkg/api/update.go +++ b/pkg/api/update.go @@ -49,6 +49,10 @@ func (a *AcmednsAPI) webUpdatePost(w http.ResponseWriter, r *http.Request, _ htt upd = []byte("{\"txt\": \"" + atxt.Value + "\"}") } } + for _, s := range a.dnsServers { + //bump SOA serial on update (and notify slaves if configured) + s.BumpSerial() + } w.Header().Set("Content-Type", "application/json") w.WriteHeader(updStatus) _, _ = w.Write(upd) diff --git a/pkg/database/db.go b/pkg/database/db.go index 00f60ef..2b0688d 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -263,6 +263,50 @@ func (d *acmednsdb) GetByUsername(u uuid.UUID) (acmedns.ACMETxt, error) { return acmedns.ACMETxt{}, fmt.Errorf("user not found: %s", u.String()) } +func (d *acmednsdb) GetTXTForAllDomains() ([]acmedns.TXTRecord, error) { + d.Mutex.Lock() + defer d.Mutex.Unlock() + + var txts []acmedns.TXTRecord + + getSQL := ` + SELECT Subdomain, Value FROM txt + ` + if d.Config.Database.Engine == "sqlite" { + getSQL = getSQLiteStmt(getSQL) + } + + sm, err := d.DB.Prepare(getSQL) + if err != nil { + return txts, err + } + defer sm.Close() + + rows, err := sm.Query() + if err != nil { + return txts, err + } + defer rows.Close() + + for rows.Next() { + var subdomain string + var value string + + err = rows.Scan(&subdomain, &value) + if err != nil { + return txts, err + } + + d.Logger.Debugw("GetTXTForAllDomains() TXT Record:", subdomain, value) + txts = append(txts, acmedns.TXTRecord{ + Subdomain: subdomain, + Value: value, + }) + } + + return txts, nil +} + func (d *acmednsdb) GetTXTForDomain(domain string) ([]string, error) { d.Mutex.Lock() defer d.Mutex.Unlock() diff --git a/pkg/nameserver/handler.go b/pkg/nameserver/handler.go index 18bf2af..204390f 100644 --- a/pkg/nameserver/handler.go +++ b/pkg/nameserver/handler.go @@ -2,12 +2,48 @@ package nameserver import ( "fmt" + "net" "strings" "github.com/miekg/dns" ) func (n *Nameserver) handleRequest(w dns.ResponseWriter, r *dns.Msg) { + + if len(r.Question) == 1 { + q := r.Question[0] + if q.Qtype == dns.TypeAXFR || q.Qtype == dns.TypeIXFR { // Get remote IP + remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) + if err != nil { + n.Logger.Errorw("Failed to parse remote address", "err", err) + m := new(dns.Msg) + m.SetReply(r) + m.Rcode = dns.RcodeRefused + _ = w.WriteMsg(m) + return + } + + // Check if remote IP is in slave list + allowed := false + for _, slave := range n.Config.General.SlaveHosts { + if remoteIP == slave { + allowed = true + break + } + } + + if !allowed { + n.Logger.Warnw("AXFR/IXFR request denied", "remote", remoteIP) + m := new(dns.Msg) + m.SetReply(r) + m.Rcode = dns.RcodeRefused + _ = w.WriteMsg(m) + return + } + n.handleAXFR(w, r) + return + } + } m := new(dns.Msg) m.SetReply(r) // handle edns0 @@ -71,6 +107,9 @@ func (n *Nameserver) answer(q dns.Question) ([]dns.RR, int, bool, error) { r = append(r, txtRRs...) } } + if q.Qtype == dns.TypeSOA { + r = append(r, n.SOA) + } if len(r) > 0 { // Make sure that we return NOERROR if there were dynamic records for the domain rcode = dns.RcodeSuccess @@ -158,3 +197,77 @@ func (n *Nameserver) getRecord(name string, qtype uint16) ([]dns.RR, error) { } return rr, nil } + +func (n *Nameserver) handleAXFR(w dns.ResponseWriter, r *dns.Msg) { + + if len(r.Question) == 0 { + return + } + + zone := dns.Fqdn(r.Question[0].Name) + + records, ok := n.Domains[zone] + if !ok { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeNameError) + _ = w.WriteMsg(m) + return + } + + // AXFR muss über Transfer laufen + tr := new(dns.Transfer) + + c := make(chan *dns.Envelope) + + go func() { + defer close(c) + + var rr []dns.RR + + // Start SOA + rr = append(rr, n.SOA) + + // NS + rr = append(rr, records.NS...) + + // Andere Records + // rr = append(rr, filterSOA(records.Records)...) + rr = append(rr, records.Records...) + + // TXT Records nur für diese Zone! + txtRecords, err := n.DB.GetTXTForAllDomains() + if err == nil { + for _, rec := range txtRecords { + if rec.Value == "" { + continue + } + + fqdn := dns.Fqdn(rec.Subdomain + "." + zone) + + txtRR := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: fqdn, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 1, + }, + Txt: []string{rec.Value}, + } + + rr = append(rr, txtRR) + + n.Logger.Debugw("handleAXFR TXT Record", "subdomain", rec.Subdomain, "value", rec.Value, "fqdn", fqdn) + rr = append(rr, txtRR) + } + } else { + n.Logger.Errorw("Failed to get TXT records for AXFR", "error", err) + } + + // End SOA + rr = append(rr, n.SOA) + + c <- &dns.Envelope{RR: rr} + }() + + _ = tr.Out(w, r, c) +} diff --git a/pkg/nameserver/initialize.go b/pkg/nameserver/initialize.go index a6bcf6e..3624367 100644 --- a/pkg/nameserver/initialize.go +++ b/pkg/nameserver/initialize.go @@ -14,6 +14,7 @@ import ( // Records is a slice of ResourceRecords type Records struct { Records []dns.RR + NS []dns.RR } type Nameserver struct { @@ -23,7 +24,7 @@ type Nameserver struct { Server *dns.Server OwnDomain string NotifyStartedFunc func() - SOA dns.RR + SOA *dns.SOA mu sync.RWMutex personalAuthKey string Domains map[string]Records diff --git a/pkg/nameserver/parseconfig.go b/pkg/nameserver/parseconfig.go index 86c3379..d63e5ce 100644 --- a/pkg/nameserver/parseconfig.go +++ b/pkg/nameserver/parseconfig.go @@ -2,12 +2,64 @@ package nameserver import ( "fmt" + "os" + "strconv" "strings" "time" "github.com/miekg/dns" ) +func loadSerial(path string) (uint32, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return 0, nil // first start + } + return 0, err + } + + s := strings.TrimSpace(string(data)) + val, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return 0, err + } + + return uint32(val), nil +} + +func saveSerial(path string, serial uint32) error { + tmp := path + ".tmp" + + data := []byte(fmt.Sprintf("%d\n", serial)) + + // write temp file + if err := os.WriteFile(tmp, data, 0644); err != nil { + return err + } + + // atomic replace + return os.Rename(tmp, path) +} + +func nextSerial(old uint32) uint32 { + today := time.Now().Format("20060102") + + oldStr := fmt.Sprintf("%d", old) + + if strings.HasPrefix(oldStr, today) { + return old + 1 + } + + newSerial, _ := strconv.Atoi(today + "00") + + if uint32(newSerial) <= old { + return old + 1 + } + + return uint32(newSerial) +} + // ParseRecords parses a slice of DNS record string func (n *Nameserver) ParseRecords() { for _, v := range n.Config.General.StaticRecords { @@ -22,25 +74,72 @@ func (n *Nameserver) ParseRecords() { n.appendRR(rr) } // Create serial - serial := time.Now().Format("2006010215") - // Add SOA - SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(n.Config.General.Domain), strings.ToLower(n.Config.General.Nsname), strings.ToLower(n.Config.General.Nsadmin), serial) - soarr, err := dns.NewRR(SOAstring) + serial, err := loadSerial(n.Config.General.Serialpath) + if err != nil { - n.Logger.Errorw("Error while adding SOA record", - "error", err.Error(), - "soa", SOAstring) - } else { - n.appendRR(soarr) - n.SOA = soarr + n.Logger.Errorw("Could not load temp serial", + "error", err.Error()) } + if serial == 0 { + serial = uint32(time.Now().Unix()) + } + // Add SOA + //Refresh = 30s → Slaves fragen alle 30s nach Änderungen + //Retry = 10s → Wenn Master nicht erreichbar, probiert der Slave alle 10s erneut + //Expire = 604800s (1w) → Wie lange der Slave die Zone noch behält, falls Master ausfällt + //Minimum TTL = 20s → Resolver cachen die TXT-Einträge nur kurz + //SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 5 10 604800 20", strings.ToLower(n.Config.General.Domain), strings.ToLower(n.Config.General.Nsname), strings.ToLower(n.Config.General.Nsadmin), serial) + n.SOA = &dns.SOA{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(n.Config.General.Domain), + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + Ttl: 5, + }, + Ns: dns.Fqdn(n.Config.General.Nsname), + Mbox: dns.Fqdn(n.Config.General.Nsadmin), + Serial: serial, + Refresh: 5, + Retry: 10, + Expire: 604800, + Minttl: 1, + } +} + +func sendNotify(zone string, slaveAddr string) error { + m := new(dns.Msg) + m.SetNotify(dns.Fqdn(zone)) // Set opcode to NOTIFY + m.Authoritative = true // Must be authoritative + + c := new(dns.Client) + _, _, err := c.Exchange(m, slaveAddr) + return err +} + +func (n *Nameserver) BumpSerial() error { + n.mu.Lock() + defer n.mu.Unlock() + n.SOA.Serial = nextSerial(n.SOA.Serial) + + for _, slave := range n.Config.General.SlaveHosts { + slave := slave + ":53" + if err := sendNotify(n.SOA.Hdr.Name, slave); err != nil { + n.Logger.Errorw("Failed to notify slave", "slave", slave, "err", err) + } else { + n.Logger.Debugw("Notify send to slave", "slave", slave) + } + } + return saveSerial(n.Config.General.Serialpath, n.SOA.Serial) } func (n *Nameserver) appendRR(rr dns.RR) { addDomain := rr.Header().Name _, ok := n.Domains[addDomain] if !ok { - n.Domains[addDomain] = Records{[]dns.RR{rr}} + n.Domains[addDomain] = Records{ + Records: []dns.RR{rr}, // initialisiere Records + NS: []dns.RR{}, // leeres NS-Slice, sonst Fehler + } } else { drecs := n.Domains[addDomain] drecs.Records = append(drecs.Records, rr)