Refactoring and comments

This commit is contained in:
Joona Hoikkala 2016-11-23 18:07:38 +02:00
parent ba63bad793
commit 670c20f904
5 changed files with 36 additions and 46 deletions

34
api.go
View File

@ -6,20 +6,8 @@ import (
"github.com/kataras/iris" "github.com/kataras/iris"
) )
func GetHandlerMap() map[string]func(*iris.Context) { // Serve is an authentication middlware function used to authenticate update requests
return map[string]func(*iris.Context){ func (a authMiddleware) Serve(ctx *iris.Context) {
"/register": WebRegisterGet,
}
}
func PostHandlerMap() map[string]func(*iris.Context) {
return map[string]func(*iris.Context){
"/register": WebRegisterPost,
"/update": WebUpdatePost,
}
}
func (a AuthMiddleware) Serve(ctx *iris.Context) {
usernameStr := ctx.RequestHeader("X-Api-User") usernameStr := ctx.RequestHeader("X-Api-User")
password := ctx.RequestHeader("X-Api-Key") password := ctx.RequestHeader("X-Api-Key")
postData := ACMETxt{} postData := ACMETxt{}
@ -44,7 +32,7 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"}) ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
} }
func WebRegisterPost(ctx *iris.Context) { func webRegisterPost(ctx *iris.Context) {
// Create new user // Create new user
nu, err := DB.Register() nu, err := DB.Register()
var regJSON iris.Map var regJSON iris.Map
@ -63,25 +51,25 @@ func WebRegisterPost(ctx *iris.Context) {
ctx.JSON(regStatus, regJSON) ctx.JSON(regStatus, regJSON)
} }
func WebRegisterGet(ctx *iris.Context) { func webRegisterGet(ctx *iris.Context) {
// This is placeholder for now // This is placeholder for now
WebRegisterPost(ctx) webRegisterPost(ctx)
} }
func WebUpdatePost(ctx *iris.Context) { func webUpdatePost(ctx *iris.Context) {
// User auth done in middleware // User auth done in middleware
a := ACMETxt{} a := ACMETxt{}
userStr := ctx.RequestHeader("X-API-User") userStr := ctx.RequestHeader("X-API-User")
username, err := getValidUsername(userStr) username, err := getValidUsername(userStr)
if err != nil { if err != nil {
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr) log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr)
WebUpdatePostError(ctx, err, iris.StatusUnauthorized) webUpdatePostError(ctx, err, iris.StatusUnauthorized)
return return
} }
if err := ctx.ReadJSON(&a); err != nil { if err := ctx.ReadJSON(&a); err != nil {
// Handle bad post data // Handle bad post data
log.Warningf("Could not unmarshal: [%v]", err) log.Warningf("Could not unmarshal: [%v]", err)
WebUpdatePostError(ctx, err, iris.StatusBadRequest) webUpdatePostError(ctx, err, iris.StatusBadRequest)
return return
} }
a.Username = username a.Username = username
@ -90,18 +78,18 @@ func WebUpdatePost(ctx *iris.Context) {
err := DB.Update(a) err := DB.Update(a)
if err != nil { if err != nil {
log.Warningf("Error trying to update [%v]", err) log.Warningf("Error trying to update [%v]", err)
WebUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError) webUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError)
return return
} }
ctx.JSON(iris.StatusOK, iris.Map{"txt": a.Value}) ctx.JSON(iris.StatusOK, iris.Map{"txt": a.Value})
} else { } else {
log.Warningf("Bad data, subdomain: [%s], txt: [%s]", a.Subdomain, a.Value) log.Warningf("Bad data, subdomain: [%s], txt: [%s]", a.Subdomain, a.Value)
WebUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest) webUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest)
return return
} }
} }
func WebUpdatePostError(ctx *iris.Context, err error, status int) { func webUpdatePostError(ctx *iris.Context, err error, status int) {
errStr := fmt.Sprintf("%v", err) errStr := fmt.Sprintf("%v", err)
updJSON := iris.Map{"error": errStr} updJSON := iris.Map{"error": errStr}
ctx.JSON(status, updJSON) ctx.JSON(status, updJSON)

12
db.go
View File

@ -11,7 +11,7 @@ import (
"time" "time"
) )
type Database struct { type database struct {
DB *sql.DB DB *sql.DB
} }
@ -34,7 +34,7 @@ func getSQLiteStmt(s string) string {
return re.ReplaceAllString(s, "?") return re.ReplaceAllString(s, "?")
} }
func (d *Database) Init(engine string, connection string) error { func (d *database) Init(engine string, connection string) error {
db, err := sql.Open(engine, connection) db, err := sql.Open(engine, connection)
if err != nil { if err != nil {
return err return err
@ -48,7 +48,7 @@ func (d *Database) Init(engine string, connection string) error {
return nil return nil
} }
func (d *Database) Register() (ACMETxt, error) { func (d *database) Register() (ACMETxt, error) {
a, err := newACMETxt() a, err := newACMETxt()
if err != nil { if err != nil {
return ACMETxt{}, err return ACMETxt{}, err
@ -78,7 +78,7 @@ func (d *Database) Register() (ACMETxt, error) {
return a, nil return a, nil
} }
func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) { func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
var results []ACMETxt var results []ACMETxt
getSQL := ` getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive SELECT Username, Password, Subdomain, Value, LastActive
@ -120,7 +120,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
return ACMETxt{}, errors.New("no user") return ACMETxt{}, errors.New("no user")
} }
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
domain = sanitizeString(domain) domain = sanitizeString(domain)
log.Debugf("Trying to select domain [%s] from table", domain) log.Debugf("Trying to select domain [%s] from table", domain)
var a []ACMETxt var a []ACMETxt
@ -155,7 +155,7 @@ func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
return a, nil return a, nil
} }
func (d *Database) Update(a ACMETxt) error { func (d *database) Update(a ACMETxt) error {
// Data in a is already sanitized // Data in a is already sanitized
log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value) log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value)
timenow := time.Now().Unix() timenow := time.Now().Unix()

12
dns.go
View File

@ -20,7 +20,7 @@ func readQuery(m *dns.Msg) {
func answerTXT(q dns.Question) ([]dns.RR, int, error) { func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var ra []dns.RR var ra []dns.RR
var rcode int = dns.RcodeNameError var rcode = dns.RcodeNameError
var domain = strings.ToLower(q.Name) var domain = strings.ToLower(q.Name)
atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain)) atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain))
@ -46,9 +46,9 @@ func answer(q dns.Question) ([]dns.RR, int, error) {
return answerTXT(q) return answerTXT(q)
} }
var r []dns.RR var r []dns.RR
var rcode int = dns.RcodeSuccess var rcode = dns.RcodeSuccess
var domain = strings.ToLower(q.Name) var domain = strings.ToLower(q.Name)
var rtype uint16 = q.Qtype var rtype = q.Qtype
r, ok := RR.Records[rtype][domain] r, ok := RR.Records[rtype][domain]
if !ok { if !ok {
rcode = dns.RcodeNameError rcode = dns.RcodeNameError
@ -78,7 +78,7 @@ func (r *Records) Parse(recs []string) {
continue continue
} }
// Add parsed RR to the list // Add parsed RR to the list
rrmap = AppendRR(rrmap, rr) rrmap = appendRR(rrmap, rr)
} }
// Create serial // Create serial
serial := time.Now().Format("2006010215") serial := time.Now().Format("2006010215")
@ -88,12 +88,12 @@ func (r *Records) Parse(recs []string) {
if err != nil { if err != nil {
log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring) log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring)
} else { } else {
rrmap = AppendRR(rrmap, soarr) rrmap = appendRR(rrmap, soarr)
} }
r.Records = rrmap r.Records = rrmap
} }
func AppendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR { func appendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR {
_, ok := rrmap[rr.Header().Rrtype] _, ok := rrmap[rr.Header().Rrtype]
if !ok { if !ok {
newrr := make(map[string][]dns.RR) newrr := make(map[string][]dns.RR)

15
main.go
View File

@ -12,12 +12,13 @@ import (
// Logging config // Logging config
var log = logging.MustGetLogger("acme-dns") var log = logging.MustGetLogger("acme-dns")
// Global configuration struct // DNSConf is global configuration struct
var DNSConf DNSConfig var DNSConf DNSConfig
var DB Database // DB is used to access the database functions in acme-dns
var DB database
// Static records // RR holds the static DNS records
var RR Records var RR Records
func main() { func main() {
@ -63,10 +64,10 @@ func main() {
Debug: DNSConf.General.Debug, Debug: DNSConf.General.Debug,
}) })
api.Use(crs) api.Use(crs)
var ForceAuth = AuthMiddleware{} var ForceAuth = authMiddleware{}
api.Get("/register", WebRegisterGet) api.Get("/register", webRegisterGet)
api.Post("/register", WebRegisterPost) api.Post("/register", webRegisterPost)
api.Post("/update", ForceAuth.Serve, WebUpdatePost) api.Post("/update", ForceAuth.Serve, webUpdatePost)
// TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com")) // TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com"))
switch DNSConf.API.TLS { switch DNSConf.API.TLS {
case "letsencrypt": case "letsencrypt":

View File

@ -5,12 +5,12 @@ import (
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
) )
// Static records // Records is for static records
type Records struct { type Records struct {
Records map[uint16]map[string][]dns.RR Records map[uint16]map[string][]dns.RR
} }
// Config file main struct // DNSConfig holds the config structure
type DNSConfig struct { type DNSConfig struct {
General general General general
Database dbsettings Database dbsettings
@ -19,7 +19,7 @@ type DNSConfig struct {
} }
// Auth middleware // Auth middleware
type AuthMiddleware struct{} type authMiddleware struct{}
// Config file general section // Config file general section
type general struct { type general struct {
@ -53,7 +53,7 @@ type logconfig struct {
Format string `toml:"logformat"` Format string `toml:"logformat"`
} }
// The default object // ACMETxt is the default structure for the user controlled record
type ACMETxt struct { type ACMETxt struct {
Username uuid.UUID Username uuid.UUID
Password string Password string
@ -61,6 +61,7 @@ type ACMETxt struct {
LastActive int64 LastActive int64
} }
// ACMETxtPost holds the DNS part of the ACMETxt struct
type ACMETxtPost struct { type ACMETxtPost struct {
Subdomain string `json:"subdomain"` Subdomain string `json:"subdomain"`
Value string `json:"txt"` Value string `json:"txt"`