From 5433444b2f607caf33124f3a800bfc340429d7f6 Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Fri, 11 Nov 2016 16:48:00 +0200 Subject: [PATCH] Initial commit, PoC quality --- acmetxt.go | 27 ++++++++ api.go | 97 +++++++++++++++++++++++++++ config.cfg | 27 ++++++++ config.go | 14 ++++ db.go | 189 +++++++++++++++++++++++++++++++++++++++++++++++++++++ dns.go | 104 +++++++++++++++++++++++++++++ main.go | 89 +++++++++++++++++++++++++ types.go | 26 ++++++++ util.go | 20 ++++++ 9 files changed, 593 insertions(+) create mode 100644 acmetxt.go create mode 100644 api.go create mode 100644 config.cfg create mode 100644 config.go create mode 100644 db.go create mode 100644 dns.go create mode 100644 main.go create mode 100644 types.go create mode 100644 util.go diff --git a/acmetxt.go b/acmetxt.go new file mode 100644 index 0000000..7bb812c --- /dev/null +++ b/acmetxt.go @@ -0,0 +1,27 @@ +package main + +import ( + "github.com/satori/go.uuid" + "time" +) + +// The default database object +type ACMETxt struct { + Username string `json:"username"` + Password string `json:"password"` + ACMETxtPost + LastActive time.Time +} + +type ACMETxtPost struct { + Subdomain string `json:"subdomain"` + Value string `json:"txt"` +} + +func NewACMETxt() ACMETxt { + var a ACMETxt = ACMETxt{} + a.Username = uuid.NewV4().String() + a.Password = uuid.NewV4().String() + a.Subdomain = uuid.NewV4().String() + return a +} diff --git a/api.go b/api.go new file mode 100644 index 0000000..fa1df85 --- /dev/null +++ b/api.go @@ -0,0 +1,97 @@ +package main + +import ( + "errors" + "fmt" + "github.com/kataras/iris" +) + +func GetHandlerMap() map[string]func(*iris.Context) { + return map[string]func(*iris.Context){ + "/register": WebRegisterGet, + } +} + +func PostHandlerMap() map[string]func(*iris.Context) { + return map[string]func(*iris.Context){ + "/register": WebRegisterPost, + "/update": WebUpdatePost, + } +} + +func WebRegisterPost(ctx *iris.Context) { + // Create new user + nu, err := DB.Register() + var reg_json iris.Map + var reg_status int + if err != nil { + errstr := fmt.Sprintf("%v", err) + + reg_json = iris.Map{"username": "", "password": "", "domain": "", "error": errstr} + reg_status = iris.StatusInternalServerError + } else { + reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain} + reg_status = iris.StatusCreated + } + ctx.JSON(reg_status, reg_json) +} + +func WebRegisterGet(ctx *iris.Context) { + // This is placeholder for now + WebRegisterPost(ctx) +} + +func WebUpdatePost(ctx *iris.Context) { + var username, password string + var a ACMETxtPost = ACMETxtPost{} + username = ctx.RequestHeader("X-API-User") + password = ctx.RequestHeader("X-API-Key") + if err := ctx.ReadJSON(&a); err != nil { + // Handle bad post data + WebUpdatePostError(ctx, err, iris.StatusBadRequest) + return + } + // Sanitized by db function + euser, err := DB.GetByUsername(username) + if err != nil { + // DB error + WebUpdatePostError(ctx, err, iris.StatusInternalServerError) + return + } + if len(euser) == 0 { + // User not found + // TODO: do bcrypt to avoid side channel + WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized) + return + } + // Get first (and the only) user + upduser := euser[0] + // Validate password + if upduser.Password != password { + // Invalid password + WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized) + return + } else { + // Do update + if len(a.Value) == 0 { + WebUpdatePostError(ctx, errors.New("missing txt value"), iris.StatusBadRequest) + return + } else { + upduser.Value = a.Value + err = DB.Update(upduser) + if err != nil { + WebUpdatePostError(ctx, err, iris.StatusInternalServerError) + return + } + // All ok + ctx.JSON(iris.StatusOK, iris.Map{"txt": upduser.Value}) + } + } + +} + +func WebUpdatePostError(ctx *iris.Context, err error, status int) { + err_str := fmt.Sprintf("%v", err) + upd_json := iris.Map{"error": err_str} + ctx.JSON(status, upd_json) +} diff --git a/config.cfg b/config.cfg new file mode 100644 index 0000000..f7ca620 --- /dev/null +++ b/config.cfg @@ -0,0 +1,27 @@ +[general] +# domain name to serve th requests off of +domain = "auth.example.org" +# zone name server +nsname = "ns1.auth.example.org" +# admin email address, with @ substituted with . +nsadmin = "admin.example.org" + +# possible values: "letsencrypt", "cert", "false" +tls = "letsencrypt" + +# only used if tls = "cert" +tls_cert_privkey = "/etc/tls/example.org/privkey.pem" +tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem" + +# predefined records that we're serving in addition to the TXT + +records = [ + # default A + "auth.example.org. A 192.168.1.100", + # A + "ns1.auth.example.org. A 192.168.1.100", + "ns2.auth.example.org. A 192.168.1.100", + # NS + "auth.example.org. NS ns1.auth.example.org.", + "auth.example.org. NS ns2.auth.example.org.", +] diff --git a/config.go b/config.go new file mode 100644 index 0000000..463d12a --- /dev/null +++ b/config.go @@ -0,0 +1,14 @@ +package main + +import ( + "errors" + "github.com/BurntSushi/toml" +) + +func ReadConfig(fname string) (DnsConfig, error) { + var conf DnsConfig + if _, err := toml.DecodeFile(fname, &conf); err != nil { + return DnsConfig{}, errors.New("Malformed configuration file") + } + return conf, nil +} diff --git a/db.go b/db.go new file mode 100644 index 0000000..72a65fb --- /dev/null +++ b/db.go @@ -0,0 +1,189 @@ +package main + +import ( + "database/sql" + //"encoding/json" + //"github.com/boltdb/bolt" + _ "github.com/mattn/go-sqlite3" + //"strings" +) + +type Database struct { + DB *sql.DB +} + +var records_table string = ` + CREATE TABLE IF NOT EXISTS records( + Username TEXT UNIQUE NOT NULL PRIMARY KEY, + Password TEXT UNIQUE NOT NULL, + Subdomain TEXT UNIQUE NOT NULL, + Value TEXT, + LastActive DATETIME + );` + +func (d *Database) Init(filename string) error { + db, err := sql.Open("sqlite3", filename) + if err != nil { + return err + } + d.DB = db + _, err = d.DB.Exec(records_table) + if err != nil { + return err + } + return nil +} + +func (d *Database) Register() (ACMETxt, error) { + a := NewACMETxt() + reg_sql := ` + INSERT INTO records( + Username, + Password, + Subdomain, + Value, + LastActive) + values(?, ?, ?, ?, CURRENT_TIMESTAMP)` + sm, err := d.DB.Prepare(reg_sql) + if err != nil { + return a, err + } + defer sm.Close() + _, err = sm.Exec(a.Username, a.Password, a.Subdomain, a.Value) + if err != nil { + return a, err + } + // Do an insert check + /* + id, err := status.LastInsertId() + if err != nil { + return a, err + }*/ + + return a, nil +} + +func (d *Database) GetByUsername(u string) ([]ACMETxt, error) { + u = NormalizeString(u, 36) + log.Debugf("Trying to select by user [%s] from table", u) + var results []ACMETxt + get_sql := ` + SELECT Username, Password, Subdomain, Value + FROM records + WHERE Username=? LIMIT 1 + ` + sm, err := d.DB.Prepare(get_sql) + if err != nil { + return nil, err + } + defer sm.Close() + rows, err := sm.Query(u) + if err != nil { + return nil, err + } + defer rows.Close() + + // It will only be one row though + for rows.Next() { + var a ACMETxt = ACMETxt{} + err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value) + if err != nil { + return nil, err + } + results = append(results, a) + } + return results, nil +} + +func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { + domain = NormalizeString(domain, 36) + log.Debugf("Trying to select domain [%s] from table", domain) + var a []ACMETxt + get_sql := ` + SELECT Username, Password, Subdomain, Value + FROM records + WHERE Subdomain=? LIMIT 1 + ` + sm, err := d.DB.Prepare(get_sql) + if err != nil { + return a, err + } + defer sm.Close() + rows, err := sm.Query(domain) + if err != nil { + return a, err + } + defer rows.Close() + + for rows.Next() { + txt := ACMETxt{} + err = rows.Scan(&txt.Username, &txt.Password, &txt.Subdomain, &txt.Value) + if err != nil { + return a, err + } + a = append(a, txt) + } + return a, nil +} + +func (d *Database) Update(a ACMETxt) error { + // Data in a is already sanitized + log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value) + upd_sql := ` + UPDATE records SET Value=? + WHERE Username=? AND Subdomain=? + ` + sm, err := d.DB.Prepare(upd_sql) + if err != nil { + return err + } + defer sm.Close() + _, err = sm.Exec(a.Value, a.Username, a.Subdomain) + if err != nil { + return err + } + return nil +} + +/* +func addTXT(txt ACMETxt) error { + + err := db.Update(func(tx *bolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists([]byte("domains")) + if err != nil { + return err + } + jtxt, err := json.Marshal(txt) + if err != nil { + return err + } + + // put returns nil if successful, nil return commits db.Update + return bucket.Put([]byte(strings.ToLower(txt.Domain)), jtxt) + }) + return err + +} + +func getTXT(domain string) (ACMETxt, error) { + var atxt ACMETxt + err := db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte("domains")) + value := bucket.Get([]byte(strings.ToLower(domain))) + if len(value) == 0 { + // Not found + log.Debugf("Record for [%s] not found", domain) + atxt = ACMETxt{} + } else { + if err := json.Unmarshal(value, &atxt); err != nil { + return err + } + } + return nil + }) + if err != nil { + return ACMETxt{}, err + } + return atxt, err +} +*/ diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..e5be504 --- /dev/null +++ b/dns.go @@ -0,0 +1,104 @@ +package main + +import ( + "fmt" + "github.com/miekg/dns" + "time" +) + +func readQuery(m *dns.Msg) { + for _, que := range m.Question { + if rr, rc, err := answer(que); err == nil { + m.MsgHdr.Rcode = rc + for _, r := range rr { + m.Answer = append(m.Answer, r) + } + } + } +} + +func answerTXT(q dns.Question) ([]dns.RR, int, error) { + var ra []dns.RR + var rcode int = dns.RcodeNameError + var domain string = q.Name + + atxt, err := DB.GetByDomain(domain) + if err != nil { + log.Errorf("Error while trying to get record [%v]", err) + return ra, dns.RcodeNameError, err + } + for _, v := range atxt { + if len(v.Value) > 0 { + r := new(dns.TXT) + r.Hdr = dns.RR_Header{Name: domain, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} + r.Txt = append(r.Txt, v.Value) + ra = append(ra, r) + rcode = dns.RcodeSuccess + } + } + log.Debugf("Answering TXT question for domain [%s]", domain) + return ra, rcode, nil +} + +func answer(q dns.Question) ([]dns.RR, int, error) { + if q.Qtype == dns.TypeTXT { + return answerTXT(q) + } + var r []dns.RR + var rcode int = dns.RcodeSuccess + var domain string = q.Name + var rtype uint16 = q.Qtype + r, ok := RR.Records[rtype][domain] + if !ok { + rcode = dns.RcodeNameError + } + log.Debugf("Answering [%s] question for domain [%s] with rcode [%s]", dns.TypeToString[rtype], domain, dns.RcodeToString[rcode]) + return r, rcode, nil +} + +func handleRequest(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + + if r.Opcode == dns.OpcodeQuery { + readQuery(m) + } + + w.WriteMsg(m) +} + +// Parse config records +func (r *Records) Parse(recs []string) { + rrmap := make(map[uint16]map[string][]dns.RR) + for _, v := range recs { + rr, err := dns.NewRR(v) + if err != nil { + log.Errorf("Could not parse RR from config: [%v] for RR: [%s]", err, v) + continue + } + // Add parsed RR to the list + rrmap = AppendRR(rrmap, rr) + } + // Create serial + serial := time.Now().Format("2006010215") + // Add SOA + SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", DnsConf.General.Domain, DnsConf.General.Nsname, DnsConf.General.Nsadmin, serial) + soarr, err := dns.NewRR(SOAstring) + if err != nil { + log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring) + } else { + rrmap = AppendRR(rrmap, soarr) + } + r.Records = rrmap +} + +func AppendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR { + _, ok := rrmap[rr.Header().Rrtype] + if !ok { + newrr := make(map[string][]dns.RR) + rrmap[rr.Header().Rrtype] = newrr + } + rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr) + log.Debugf("Adding new record of type [%s] for domain [%s]", dns.TypeToString[rr.Header().Rrtype], rr.Header().Name) + return rrmap +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..02e7de9 --- /dev/null +++ b/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "fmt" + "github.com/kataras/iris" + "github.com/miekg/dns" + "github.com/op/go-logging" + "os" +) + +// Logging config +var logfile_path = "acme-dns.log" +var log = logging.MustGetLogger("acme-dns") + +// Global configuration struct +var DnsConf DnsConfig + +var DB Database + +// Static records +var RR Records + +func main() { + // Setup logging + var stdout_format = logging.MustStringFormatter( + `%{color}%{time:15:04:05.000} %{shortfunc} ▶ %{level:.4s} %{id:03x}%{color:reset} %{message}`, + ) + var file_format = logging.MustStringFormatter( + `%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}`, + ) + // Setup logging - stdout + logStdout := logging.NewLogBackend(os.Stdout, "", 0) + logStdoutFormatter := logging.NewBackendFormatter(logStdout, stdout_format) + // Setup logging - file + // Logging to file + logfh, err := os.OpenFile(logfile_path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + fmt.Printf("Could not open log file %s\n", logfile_path) + os.Exit(1) + } + defer logfh.Close() + logFile := logging.NewLogBackend(logfh, "", 0) + logFileFormatter := logging.NewBackendFormatter(logFile, file_format) + /* To limit logging to a level + logFileLeveled := logging.AddModuleLevel(logFile) + logFileLeveled.SetLevel(logging.ERROR, "") + */ + + // Start logging + logging.SetBackend(logStdoutFormatter, logFileFormatter) + log.Debug("Starting up...") + + // Read global config + if DnsConf, err = ReadConfig("config.cfg"); err != nil { + log.Errorf("Got error %v", err) + os.Exit(1) + } + RR.Parse(DnsConf.General.StaticRecords) + + // Open database + err = DB.Init("acme-dns.db") + if err != nil { + log.Errorf("Could not open database [%v]", err) + os.Exit(1) + } + defer DB.DB.Close() + + // DNS server part + dns.HandleFunc(".", handleRequest) + server := &dns.Server{Addr: ":53", Net: "udp"} + go func() { + err = server.ListenAndServe() + if err != nil { + log.Errorf("%v", err) + os.Exit(1) + } + }() + + // API server + api := iris.New() + for path, handlerfunc := range GetHandlerMap() { + api.Get(path, handlerfunc) + } + for path, handlerfunc := range PostHandlerMap() { + api.Post(path, handlerfunc) + } + api.Listen(":8080") + log.Debugf("Shutting down...") +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..3fd2d35 --- /dev/null +++ b/types.go @@ -0,0 +1,26 @@ +package main + +import ( + "github.com/miekg/dns" +) + +// Static records +type Records struct { + Records map[uint16]map[string][]dns.RR +} + +// Config file main struct +type DnsConfig struct { + General general +} + +// Config file general section +type general struct { + Domain string + Nsname string + Nsadmin string + Tls string + Tls_cert_privkey string + Tls_cert_fullchain string + StaticRecords []string `toml:"records"` +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..175b175 --- /dev/null +++ b/util.go @@ -0,0 +1,20 @@ +package main + +import ( + "regexp" + "unicode/utf8" +) + +func NormalizeString(s string, len int) string { + var ret string + re, err := regexp.Compile("[^A-Za-z\\-0-9]+") + if err != nil { + log.Errorf("%v", err) + return "" + } + ret = re.ReplaceAllString(s, "") + if utf8.RuneCountInString(ret) > len { + ret = ret[0:len] + } + return ret +}