commit 5433444b2f607caf33124f3a800bfc340429d7f6 Author: Joona Hoikkala Date: Fri Nov 11 16:48:00 2016 +0200 Initial commit, PoC quality diff --git a/acmetxt.go b/acmetxt.go new file mode 100644 index 0000000..7bb812c --- /dev/null +++ b/acmetxt.go @@ -0,0 +1,27 @@ +package main + +import ( + "github.com/satori/go.uuid" + "time" +) + +// The default database object +type ACMETxt struct { + Username string `json:"username"` + Password string `json:"password"` + ACMETxtPost + LastActive time.Time +} + +type ACMETxtPost struct { + Subdomain string `json:"subdomain"` + Value string `json:"txt"` +} + +func NewACMETxt() ACMETxt { + var a ACMETxt = ACMETxt{} + a.Username = uuid.NewV4().String() + a.Password = uuid.NewV4().String() + a.Subdomain = uuid.NewV4().String() + return a +} diff --git a/api.go b/api.go new file mode 100644 index 0000000..fa1df85 --- /dev/null +++ b/api.go @@ -0,0 +1,97 @@ +package main + +import ( + "errors" + "fmt" + "github.com/kataras/iris" +) + +func GetHandlerMap() map[string]func(*iris.Context) { + return map[string]func(*iris.Context){ + "/register": WebRegisterGet, + } +} + +func PostHandlerMap() map[string]func(*iris.Context) { + return map[string]func(*iris.Context){ + "/register": WebRegisterPost, + "/update": WebUpdatePost, + } +} + +func WebRegisterPost(ctx *iris.Context) { + // Create new user + nu, err := DB.Register() + var reg_json iris.Map + var reg_status int + if err != nil { + errstr := fmt.Sprintf("%v", err) + + reg_json = iris.Map{"username": "", "password": "", "domain": "", "error": errstr} + reg_status = iris.StatusInternalServerError + } else { + reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain} + reg_status = iris.StatusCreated + } + ctx.JSON(reg_status, reg_json) +} + +func WebRegisterGet(ctx *iris.Context) { + // This is placeholder for now + WebRegisterPost(ctx) +} + +func WebUpdatePost(ctx *iris.Context) { + var username, password string + var a ACMETxtPost = ACMETxtPost{} + username = ctx.RequestHeader("X-API-User") + password = ctx.RequestHeader("X-API-Key") + if err := ctx.ReadJSON(&a); err != nil { + // Handle bad post data + WebUpdatePostError(ctx, err, iris.StatusBadRequest) + return + } + // Sanitized by db function + euser, err := DB.GetByUsername(username) + if err != nil { + // DB error + WebUpdatePostError(ctx, err, iris.StatusInternalServerError) + return + } + if len(euser) == 0 { + // User not found + // TODO: do bcrypt to avoid side channel + WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized) + return + } + // Get first (and the only) user + upduser := euser[0] + // Validate password + if upduser.Password != password { + // Invalid password + WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized) + return + } else { + // Do update + if len(a.Value) == 0 { + WebUpdatePostError(ctx, errors.New("missing txt value"), iris.StatusBadRequest) + return + } else { + upduser.Value = a.Value + err = DB.Update(upduser) + if err != nil { + WebUpdatePostError(ctx, err, iris.StatusInternalServerError) + return + } + // All ok + ctx.JSON(iris.StatusOK, iris.Map{"txt": upduser.Value}) + } + } + +} + +func WebUpdatePostError(ctx *iris.Context, err error, status int) { + err_str := fmt.Sprintf("%v", err) + upd_json := iris.Map{"error": err_str} + ctx.JSON(status, upd_json) +} diff --git a/config.cfg b/config.cfg new file mode 100644 index 0000000..f7ca620 --- /dev/null +++ b/config.cfg @@ -0,0 +1,27 @@ +[general] +# domain name to serve th requests off of +domain = "auth.example.org" +# zone name server +nsname = "ns1.auth.example.org" +# admin email address, with @ substituted with . +nsadmin = "admin.example.org" + +# possible values: "letsencrypt", "cert", "false" +tls = "letsencrypt" + +# only used if tls = "cert" +tls_cert_privkey = "/etc/tls/example.org/privkey.pem" +tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem" + +# predefined records that we're serving in addition to the TXT + +records = [ + # default A + "auth.example.org. A 192.168.1.100", + # A + "ns1.auth.example.org. A 192.168.1.100", + "ns2.auth.example.org. A 192.168.1.100", + # NS + "auth.example.org. NS ns1.auth.example.org.", + "auth.example.org. NS ns2.auth.example.org.", +] diff --git a/config.go b/config.go new file mode 100644 index 0000000..463d12a --- /dev/null +++ b/config.go @@ -0,0 +1,14 @@ +package main + +import ( + "errors" + "github.com/BurntSushi/toml" +) + +func ReadConfig(fname string) (DnsConfig, error) { + var conf DnsConfig + if _, err := toml.DecodeFile(fname, &conf); err != nil { + return DnsConfig{}, errors.New("Malformed configuration file") + } + return conf, nil +} diff --git a/db.go b/db.go new file mode 100644 index 0000000..72a65fb --- /dev/null +++ b/db.go @@ -0,0 +1,189 @@ +package main + +import ( + "database/sql" + //"encoding/json" + //"github.com/boltdb/bolt" + _ "github.com/mattn/go-sqlite3" + //"strings" +) + +type Database struct { + DB *sql.DB +} + +var records_table string = ` + CREATE TABLE IF NOT EXISTS records( + Username TEXT UNIQUE NOT NULL PRIMARY KEY, + Password TEXT UNIQUE NOT NULL, + Subdomain TEXT UNIQUE NOT NULL, + Value TEXT, + LastActive DATETIME + );` + +func (d *Database) Init(filename string) error { + db, err := sql.Open("sqlite3", filename) + if err != nil { + return err + } + d.DB = db + _, err = d.DB.Exec(records_table) + if err != nil { + return err + } + return nil +} + +func (d *Database) Register() (ACMETxt, error) { + a := NewACMETxt() + reg_sql := ` + INSERT INTO records( + Username, + Password, + Subdomain, + Value, + LastActive) + values(?, ?, ?, ?, CURRENT_TIMESTAMP)` + sm, err := d.DB.Prepare(reg_sql) + if err != nil { + return a, err + } + defer sm.Close() + _, err = sm.Exec(a.Username, a.Password, a.Subdomain, a.Value) + if err != nil { + return a, err + } + // Do an insert check + /* + id, err := status.LastInsertId() + if err != nil { + return a, err + }*/ + + return a, nil +} + +func (d *Database) GetByUsername(u string) ([]ACMETxt, error) { + u = NormalizeString(u, 36) + log.Debugf("Trying to select by user [%s] from table", u) + var results []ACMETxt + get_sql := ` + SELECT Username, Password, Subdomain, Value + FROM records + WHERE Username=? LIMIT 1 + ` + sm, err := d.DB.Prepare(get_sql) + if err != nil { + return nil, err + } + defer sm.Close() + rows, err := sm.Query(u) + if err != nil { + return nil, err + } + defer rows.Close() + + // It will only be one row though + for rows.Next() { + var a ACMETxt = ACMETxt{} + err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value) + if err != nil { + return nil, err + } + results = append(results, a) + } + return results, nil +} + +func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { + domain = NormalizeString(domain, 36) + log.Debugf("Trying to select domain [%s] from table", domain) + var a []ACMETxt + get_sql := ` + SELECT Username, Password, Subdomain, Value + FROM records + WHERE Subdomain=? LIMIT 1 + ` + sm, err := d.DB.Prepare(get_sql) + if err != nil { + return a, err + } + defer sm.Close() + rows, err := sm.Query(domain) + if err != nil { + return a, err + } + defer rows.Close() + + for rows.Next() { + txt := ACMETxt{} + err = rows.Scan(&txt.Username, &txt.Password, &txt.Subdomain, &txt.Value) + if err != nil { + return a, err + } + a = append(a, txt) + } + return a, nil +} + +func (d *Database) Update(a ACMETxt) error { + // Data in a is already sanitized + log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value) + upd_sql := ` + UPDATE records SET Value=? + WHERE Username=? AND Subdomain=? + ` + sm, err := d.DB.Prepare(upd_sql) + if err != nil { + return err + } + defer sm.Close() + _, err = sm.Exec(a.Value, a.Username, a.Subdomain) + if err != nil { + return err + } + return nil +} + +/* +func addTXT(txt ACMETxt) error { + + err := db.Update(func(tx *bolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists([]byte("domains")) + if err != nil { + return err + } + jtxt, err := json.Marshal(txt) + if err != nil { + return err + } + + // put returns nil if successful, nil return commits db.Update + return bucket.Put([]byte(strings.ToLower(txt.Domain)), jtxt) + }) + return err + +} + +func getTXT(domain string) (ACMETxt, error) { + var atxt ACMETxt + err := db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte("domains")) + value := bucket.Get([]byte(strings.ToLower(domain))) + if len(value) == 0 { + // Not found + log.Debugf("Record for [%s] not found", domain) + atxt = ACMETxt{} + } else { + if err := json.Unmarshal(value, &atxt); err != nil { + return err + } + } + return nil + }) + if err != nil { + return ACMETxt{}, err + } + return atxt, err +} +*/ diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..e5be504 --- /dev/null +++ b/dns.go @@ -0,0 +1,104 @@ +package main + +import ( + "fmt" + "github.com/miekg/dns" + "time" +) + +func readQuery(m *dns.Msg) { + for _, que := range m.Question { + if rr, rc, err := answer(que); err == nil { + m.MsgHdr.Rcode = rc + for _, r := range rr { + m.Answer = append(m.Answer, r) + } + } + } +} + +func answerTXT(q dns.Question) ([]dns.RR, int, error) { + var ra []dns.RR + var rcode int = dns.RcodeNameError + var domain string = q.Name + + atxt, err := DB.GetByDomain(domain) + if err != nil { + log.Errorf("Error while trying to get record [%v]", err) + return ra, dns.RcodeNameError, err + } + for _, v := range atxt { + if len(v.Value) > 0 { + r := new(dns.TXT) + r.Hdr = dns.RR_Header{Name: domain, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} + r.Txt = append(r.Txt, v.Value) + ra = append(ra, r) + rcode = dns.RcodeSuccess + } + } + log.Debugf("Answering TXT question for domain [%s]", domain) + return ra, rcode, nil +} + +func answer(q dns.Question) ([]dns.RR, int, error) { + if q.Qtype == dns.TypeTXT { + return answerTXT(q) + } + var r []dns.RR + var rcode int = dns.RcodeSuccess + var domain string = q.Name + var rtype uint16 = q.Qtype + r, ok := RR.Records[rtype][domain] + if !ok { + rcode = dns.RcodeNameError + } + log.Debugf("Answering [%s] question for domain [%s] with rcode [%s]", dns.TypeToString[rtype], domain, dns.RcodeToString[rcode]) + return r, rcode, nil +} + +func handleRequest(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + + if r.Opcode == dns.OpcodeQuery { + readQuery(m) + } + + w.WriteMsg(m) +} + +// Parse config records +func (r *Records) Parse(recs []string) { + rrmap := make(map[uint16]map[string][]dns.RR) + for _, v := range recs { + rr, err := dns.NewRR(v) + if err != nil { + log.Errorf("Could not parse RR from config: [%v] for RR: [%s]", err, v) + continue + } + // Add parsed RR to the list + rrmap = AppendRR(rrmap, rr) + } + // Create serial + serial := time.Now().Format("2006010215") + // Add SOA + SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", DnsConf.General.Domain, DnsConf.General.Nsname, DnsConf.General.Nsadmin, serial) + soarr, err := dns.NewRR(SOAstring) + if err != nil { + log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring) + } else { + rrmap = AppendRR(rrmap, soarr) + } + r.Records = rrmap +} + +func AppendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR { + _, ok := rrmap[rr.Header().Rrtype] + if !ok { + newrr := make(map[string][]dns.RR) + rrmap[rr.Header().Rrtype] = newrr + } + rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr) + log.Debugf("Adding new record of type [%s] for domain [%s]", dns.TypeToString[rr.Header().Rrtype], rr.Header().Name) + return rrmap +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..02e7de9 --- /dev/null +++ b/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "fmt" + "github.com/kataras/iris" + "github.com/miekg/dns" + "github.com/op/go-logging" + "os" +) + +// Logging config +var logfile_path = "acme-dns.log" +var log = logging.MustGetLogger("acme-dns") + +// Global configuration struct +var DnsConf DnsConfig + +var DB Database + +// Static records +var RR Records + +func main() { + // Setup logging + var stdout_format = logging.MustStringFormatter( + `%{color}%{time:15:04:05.000} %{shortfunc} ▶ %{level:.4s} %{id:03x}%{color:reset} %{message}`, + ) + var file_format = logging.MustStringFormatter( + `%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}`, + ) + // Setup logging - stdout + logStdout := logging.NewLogBackend(os.Stdout, "", 0) + logStdoutFormatter := logging.NewBackendFormatter(logStdout, stdout_format) + // Setup logging - file + // Logging to file + logfh, err := os.OpenFile(logfile_path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + fmt.Printf("Could not open log file %s\n", logfile_path) + os.Exit(1) + } + defer logfh.Close() + logFile := logging.NewLogBackend(logfh, "", 0) + logFileFormatter := logging.NewBackendFormatter(logFile, file_format) + /* To limit logging to a level + logFileLeveled := logging.AddModuleLevel(logFile) + logFileLeveled.SetLevel(logging.ERROR, "") + */ + + // Start logging + logging.SetBackend(logStdoutFormatter, logFileFormatter) + log.Debug("Starting up...") + + // Read global config + if DnsConf, err = ReadConfig("config.cfg"); err != nil { + log.Errorf("Got error %v", err) + os.Exit(1) + } + RR.Parse(DnsConf.General.StaticRecords) + + // Open database + err = DB.Init("acme-dns.db") + if err != nil { + log.Errorf("Could not open database [%v]", err) + os.Exit(1) + } + defer DB.DB.Close() + + // DNS server part + dns.HandleFunc(".", handleRequest) + server := &dns.Server{Addr: ":53", Net: "udp"} + go func() { + err = server.ListenAndServe() + if err != nil { + log.Errorf("%v", err) + os.Exit(1) + } + }() + + // API server + api := iris.New() + for path, handlerfunc := range GetHandlerMap() { + api.Get(path, handlerfunc) + } + for path, handlerfunc := range PostHandlerMap() { + api.Post(path, handlerfunc) + } + api.Listen(":8080") + log.Debugf("Shutting down...") +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..3fd2d35 --- /dev/null +++ b/types.go @@ -0,0 +1,26 @@ +package main + +import ( + "github.com/miekg/dns" +) + +// Static records +type Records struct { + Records map[uint16]map[string][]dns.RR +} + +// Config file main struct +type DnsConfig struct { + General general +} + +// Config file general section +type general struct { + Domain string + Nsname string + Nsadmin string + Tls string + Tls_cert_privkey string + Tls_cert_fullchain string + StaticRecords []string `toml:"records"` +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..175b175 --- /dev/null +++ b/util.go @@ -0,0 +1,20 @@ +package main + +import ( + "regexp" + "unicode/utf8" +) + +func NormalizeString(s string, len int) string { + var ret string + re, err := regexp.Compile("[^A-Za-z\\-0-9]+") + if err != nil { + log.Errorf("%v", err) + return "" + } + ret = re.ReplaceAllString(s, "") + if utf8.RuneCountInString(ret) > len { + ret = ret[0:len] + } + return ret +}