diff --git a/api.go b/api.go index 1c7dec8..2cec374 100644 --- a/api.go +++ b/api.go @@ -6,20 +6,8 @@ import ( "github.com/kataras/iris" ) -func GetHandlerMap() map[string]func(*iris.Context) { - return map[string]func(*iris.Context){ - "/register": WebRegisterGet, - } -} - -func PostHandlerMap() map[string]func(*iris.Context) { - return map[string]func(*iris.Context){ - "/register": WebRegisterPost, - "/update": WebUpdatePost, - } -} - -func (a AuthMiddleware) Serve(ctx *iris.Context) { +// Serve is an authentication middlware function used to authenticate update requests +func (a authMiddleware) Serve(ctx *iris.Context) { usernameStr := ctx.RequestHeader("X-Api-User") password := ctx.RequestHeader("X-Api-Key") postData := ACMETxt{} @@ -44,7 +32,7 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) { ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"}) } -func WebRegisterPost(ctx *iris.Context) { +func webRegisterPost(ctx *iris.Context) { // Create new user nu, err := DB.Register() var regJSON iris.Map @@ -63,25 +51,25 @@ func WebRegisterPost(ctx *iris.Context) { ctx.JSON(regStatus, regJSON) } -func WebRegisterGet(ctx *iris.Context) { +func webRegisterGet(ctx *iris.Context) { // This is placeholder for now - WebRegisterPost(ctx) + webRegisterPost(ctx) } -func WebUpdatePost(ctx *iris.Context) { +func webUpdatePost(ctx *iris.Context) { // User auth done in middleware a := ACMETxt{} userStr := ctx.RequestHeader("X-API-User") username, err := getValidUsername(userStr) if err != nil { log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr) - WebUpdatePostError(ctx, err, iris.StatusUnauthorized) + webUpdatePostError(ctx, err, iris.StatusUnauthorized) return } if err := ctx.ReadJSON(&a); err != nil { // Handle bad post data log.Warningf("Could not unmarshal: [%v]", err) - WebUpdatePostError(ctx, err, iris.StatusBadRequest) + webUpdatePostError(ctx, err, iris.StatusBadRequest) return } a.Username = username @@ -90,18 +78,18 @@ func WebUpdatePost(ctx *iris.Context) { err := DB.Update(a) if err != nil { log.Warningf("Error trying to update [%v]", err) - WebUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError) + webUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError) return } ctx.JSON(iris.StatusOK, iris.Map{"txt": a.Value}) } else { log.Warningf("Bad data, subdomain: [%s], txt: [%s]", a.Subdomain, a.Value) - WebUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest) + webUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest) return } } -func WebUpdatePostError(ctx *iris.Context, err error, status int) { +func webUpdatePostError(ctx *iris.Context, err error, status int) { errStr := fmt.Sprintf("%v", err) updJSON := iris.Map{"error": errStr} ctx.JSON(status, updJSON) diff --git a/db.go b/db.go index e41a232..bf85b56 100644 --- a/db.go +++ b/db.go @@ -11,7 +11,7 @@ import ( "time" ) -type Database struct { +type database struct { DB *sql.DB } @@ -34,7 +34,7 @@ func getSQLiteStmt(s string) string { return re.ReplaceAllString(s, "?") } -func (d *Database) Init(engine string, connection string) error { +func (d *database) Init(engine string, connection string) error { db, err := sql.Open(engine, connection) if err != nil { return err @@ -48,7 +48,7 @@ func (d *Database) Init(engine string, connection string) error { return nil } -func (d *Database) Register() (ACMETxt, error) { +func (d *database) Register() (ACMETxt, error) { a, err := newACMETxt() if err != nil { return ACMETxt{}, err @@ -78,7 +78,7 @@ func (d *Database) Register() (ACMETxt, error) { return a, nil } -func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) { +func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) { var results []ACMETxt getSQL := ` SELECT Username, Password, Subdomain, Value, LastActive @@ -120,7 +120,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) { return ACMETxt{}, errors.New("no user") } -func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { +func (d *database) GetByDomain(domain string) ([]ACMETxt, error) { domain = sanitizeString(domain) log.Debugf("Trying to select domain [%s] from table", domain) var a []ACMETxt @@ -155,7 +155,7 @@ func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { return a, nil } -func (d *Database) Update(a ACMETxt) error { +func (d *database) Update(a ACMETxt) error { // Data in a is already sanitized log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value) timenow := time.Now().Unix() diff --git a/dns.go b/dns.go index 9a665dd..684673e 100644 --- a/dns.go +++ b/dns.go @@ -20,7 +20,7 @@ func readQuery(m *dns.Msg) { func answerTXT(q dns.Question) ([]dns.RR, int, error) { var ra []dns.RR - var rcode int = dns.RcodeNameError + var rcode = dns.RcodeNameError var domain = strings.ToLower(q.Name) atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain)) @@ -46,9 +46,9 @@ func answer(q dns.Question) ([]dns.RR, int, error) { return answerTXT(q) } var r []dns.RR - var rcode int = dns.RcodeSuccess + var rcode = dns.RcodeSuccess var domain = strings.ToLower(q.Name) - var rtype uint16 = q.Qtype + var rtype = q.Qtype r, ok := RR.Records[rtype][domain] if !ok { rcode = dns.RcodeNameError @@ -78,7 +78,7 @@ func (r *Records) Parse(recs []string) { continue } // Add parsed RR to the list - rrmap = AppendRR(rrmap, rr) + rrmap = appendRR(rrmap, rr) } // Create serial serial := time.Now().Format("2006010215") @@ -88,12 +88,12 @@ func (r *Records) Parse(recs []string) { if err != nil { log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring) } else { - rrmap = AppendRR(rrmap, soarr) + rrmap = appendRR(rrmap, soarr) } r.Records = rrmap } -func AppendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR { +func appendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR { _, ok := rrmap[rr.Header().Rrtype] if !ok { newrr := make(map[string][]dns.RR) diff --git a/main.go b/main.go index 6fac7e8..b2709d8 100644 --- a/main.go +++ b/main.go @@ -12,12 +12,13 @@ import ( // Logging config var log = logging.MustGetLogger("acme-dns") -// Global configuration struct +// DNSConf is global configuration struct var DNSConf DNSConfig -var DB Database +// DB is used to access the database functions in acme-dns +var DB database -// Static records +// RR holds the static DNS records var RR Records func main() { @@ -63,10 +64,10 @@ func main() { Debug: DNSConf.General.Debug, }) api.Use(crs) - var ForceAuth = AuthMiddleware{} - api.Get("/register", WebRegisterGet) - api.Post("/register", WebRegisterPost) - api.Post("/update", ForceAuth.Serve, WebUpdatePost) + var ForceAuth = authMiddleware{} + api.Get("/register", webRegisterGet) + api.Post("/register", webRegisterPost) + api.Post("/update", ForceAuth.Serve, webUpdatePost) // TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com")) switch DNSConf.API.TLS { case "letsencrypt": diff --git a/types.go b/types.go index 105a1bc..91dec01 100644 --- a/types.go +++ b/types.go @@ -5,12 +5,12 @@ import ( "github.com/satori/go.uuid" ) -// Static records +// Records is for static records type Records struct { Records map[uint16]map[string][]dns.RR } -// Config file main struct +// DNSConfig holds the config structure type DNSConfig struct { General general Database dbsettings @@ -19,7 +19,7 @@ type DNSConfig struct { } // Auth middleware -type AuthMiddleware struct{} +type authMiddleware struct{} // Config file general section type general struct { @@ -53,7 +53,7 @@ type logconfig struct { Format string `toml:"logformat"` } -// The default object +// ACMETxt is the default structure for the user controlled record type ACMETxt struct { Username uuid.UUID Password string @@ -61,6 +61,7 @@ type ACMETxt struct { LastActive int64 } +// ACMETxtPost holds the DNS part of the ACMETxt struct type ACMETxtPost struct { Subdomain string `json:"subdomain"` Value string `json:"txt"`