Support for multiple TXT records per subdomain (#29)

* Support for multiple TXT records per subdomain and database upgrade functionality

* Linter fixes

* Make sure the database upgrade routine works for PostgreSQL

* Move subdomain query outside of the upgrade transaction
This commit is contained in:
Joona Hoikkala 2018-01-22 09:53:07 +02:00 committed by GitHub
parent ba695134ce
commit 733245fb3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 199 additions and 71 deletions

View File

@ -12,8 +12,7 @@ type ACMETxt struct {
Username uuid.UUID Username uuid.UUID
Password string Password string
ACMETxtPost ACMETxtPost
LastActive int64 AllowFrom cidrslice
AllowFrom cidrslice
} }
// ACMETxtPost holds the DNS part of the ACMETxt struct // ACMETxtPost holds the DNS part of the ACMETxt struct

193
db.go
View File

@ -4,7 +4,9 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"regexp" "regexp"
"strconv"
"time" "time"
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -14,16 +16,38 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
var recordsTable = ` // DBVersion shows the database version this code uses. This is used for update checks.
var DBVersion = 1
var acmeTable = `
CREATE TABLE IF NOT EXISTS acmedns(
Name TEXT,
Value TEXT
);`
var userTable = `
CREATE TABLE IF NOT EXISTS records( CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY, Username TEXT UNIQUE NOT NULL PRIMARY KEY,
Password TEXT UNIQUE NOT NULL, Password TEXT UNIQUE NOT NULL,
Subdomain TEXT UNIQUE NOT NULL, Subdomain TEXT UNIQUE NOT NULL,
Value TEXT,
LastActive INT,
AllowFrom TEXT AllowFrom TEXT
);` );`
var txtTable = `
CREATE TABLE IF NOT EXISTS txt(
Subdomain TEXT NOT NULL,
Value TEXT NOT NULL DEFAULT '',
LastUpdate INT
);`
var txtTablePG = `
CREATE TABLE IF NOT EXISTS txt(
rowid SERIAL,
Subdomain TEXT NOT NULL,
Value TEXT NOT NULL DEFAULT '',
LastUpdate INT
);`
// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?" // getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?"
func getSQLiteStmt(s string) string { func getSQLiteStmt(s string) string {
re, _ := regexp.Compile("\\$[0-9]") re, _ := regexp.Compile("\\$[0-9]")
@ -38,44 +62,151 @@ func (d *acmedb) Init(engine string, connection string) error {
return err return err
} }
d.DB = db d.DB = db
//d.DB.SetMaxOpenConns(1) // Check version first to try to catch old versions without version string
_, err = d.DB.Exec(recordsTable) var versionString string
_ = d.DB.QueryRow("SELECT Value FROM acmedns WHERE Name='db_version'").Scan(&versionString)
if versionString == "" {
versionString = "0"
}
_, err = d.DB.Exec(acmeTable)
_, err = d.DB.Exec(userTable)
if Config.Database.Engine == "sqlite3" {
_, err = d.DB.Exec(txtTable)
} else {
_, err = d.DB.Exec(txtTablePG)
}
// If everything is fine, handle db upgrade tasks
if err == nil {
err = d.checkDBUpgrades(versionString)
}
if err == nil {
if versionString == "0" {
// No errors so we should now be in version 1
insversion := fmt.Sprintf("INSERT INTO acmedns (Name, Value) values('db_version', '%d')", DBVersion)
_, err = db.Exec(insversion)
}
}
return err
}
func (d *acmedb) checkDBUpgrades(versionString string) error {
var err error
version, err := strconv.Atoi(versionString)
if err != nil { if err != nil {
return err return err
} }
if version != DBVersion {
return d.handleDBUpgrades(version)
}
return nil return nil
}
func (d *acmedb) handleDBUpgrades(version int) error {
if version == 0 {
return d.handleDBUpgradeTo1()
}
return nil
}
func (d *acmedb) handleDBUpgradeTo1() error {
var err error
var subdomains []string
rows, err := d.DB.Query("SELECT Subdomain FROM records")
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade")
return err
}
defer rows.Close()
for rows.Next() {
var subdomain string
err = rows.Scan(&subdomain)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while reading values")
return err
}
subdomains = append(subdomains, subdomain)
}
err = rows.Err()
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
return err
}
tx, err := d.DB.Begin()
// Rollback if errored, commit if not
defer func() {
if err != nil {
tx.Rollback()
return
}
tx.Commit()
}()
_, _ = tx.Exec("DELETE FROM txt")
for _, subdomain := range subdomains {
if subdomain != "" {
// Insert two rows for each subdomain to txt table
err = d.NewTXTValuesInTransaction(tx, subdomain)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
return err
}
}
}
// SQLite doesn't support dropping columns
if Config.Database.Engine != "sqlite3" {
_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS Value")
_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS LastActive")
}
_, err = tx.Exec("UPDATE acmedns SET Value='1' WHERE Name='db_version'")
return err
}
// Create two rows for subdomain to the txt table
func (d *acmedb) NewTXTValuesInTransaction(tx *sql.Tx, subdomain string) error {
var err error
instr := fmt.Sprintf("INSERT INTO txt (Subdomain, LastUpdate) values('%s', 0)", subdomain)
_, err = tx.Exec(instr)
_, err = tx.Exec(instr)
return err
} }
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) { func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
var err error
tx, err := d.DB.Begin()
// Rollback if errored, commit if not
defer func() {
if err != nil {
tx.Rollback()
return
}
tx.Commit()
}()
a := newACMETxt() a := newACMETxt()
a.AllowFrom = cidrslice(afrom.ValidEntries()) a.AllowFrom = cidrslice(afrom.ValidEntries())
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10) passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
timenow := time.Now().Unix()
regSQL := ` regSQL := `
INSERT INTO records( INSERT INTO records(
Username, Username,
Password, Password,
Subdomain, Subdomain,
Value,
LastActive,
AllowFrom) AllowFrom)
values($1, $2, $3, '', $4, $5)` values($1, $2, $3, $4)`
if Config.Database.Engine == "sqlite3" { if Config.Database.Engine == "sqlite3" {
regSQL = getSQLiteStmt(regSQL) regSQL = getSQLiteStmt(regSQL)
} }
sm, err := d.DB.Prepare(regSQL) sm, err := tx.Prepare(regSQL)
if err != nil { if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare") log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
return a, errors.New("SQL error") return a, errors.New("SQL error")
} }
defer sm.Close() defer sm.Close()
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON()) _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, a.AllowFrom.JSON())
if err != nil { if err == nil {
return a, err err = d.NewTXTValuesInTransaction(tx, a.Subdomain)
} }
return a, nil return a, err
} }
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) { func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
@ -83,7 +214,7 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
defer d.Unlock() defer d.Unlock()
var results []ACMETxt var results []ACMETxt
getSQL := ` getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom SELECT Username, Password, Subdomain, AllowFrom
FROM records FROM records
WHERE Username=$1 LIMIT 1 WHERE Username=$1 LIMIT 1
` `
@ -116,15 +247,13 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
return ACMETxt{}, errors.New("no user") return ACMETxt{}, errors.New("no user")
} }
func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) { func (d *acmedb) GetTXTForDomain(domain string) ([]string, error) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
domain = sanitizeString(domain) domain = sanitizeString(domain)
var a []ACMETxt var txts []string
getSQL := ` getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom SELECT Value FROM txt WHERE Subdomain=$1 LIMIT 2
FROM records
WHERE Subdomain=$1 LIMIT 1
` `
if Config.Database.Engine == "sqlite3" { if Config.Database.Engine == "sqlite3" {
getSQL = getSQLiteStmt(getSQL) getSQL = getSQLiteStmt(getSQL)
@ -132,33 +261,37 @@ func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
sm, err := d.DB.Prepare(getSQL) sm, err := d.DB.Prepare(getSQL)
if err != nil { if err != nil {
return a, err return txts, err
} }
defer sm.Close() defer sm.Close()
rows, err := sm.Query(domain) rows, err := sm.Query(domain)
if err != nil { if err != nil {
return a, err return txts, err
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
txt, err := getModelFromRow(rows) var rtxt string
err = rows.Scan(&rtxt)
if err != nil { if err != nil {
return a, err return txts, err
} }
a = append(a, txt) txts = append(txts, rtxt)
} }
return a, nil return txts, nil
} }
func (d *acmedb) Update(a ACMETxt) error { func (d *acmedb) Update(a ACMETxt) error {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
var err error
// Data in a is already sanitized // Data in a is already sanitized
timenow := time.Now().Unix() timenow := time.Now().Unix()
updSQL := ` updSQL := `
UPDATE records SET Value=$1, LastActive=$2 UPDATE txt SET Value=$1, LastUpdate=$2
WHERE Username=$3 AND Subdomain=$4 WHERE rowid=(
SELECT rowid FROM txt WHERE Subdomain=$3 ORDER BY LastUpdate LIMIT 1)
` `
if Config.Database.Engine == "sqlite3" { if Config.Database.Engine == "sqlite3" {
updSQL = getSQLiteStmt(updSQL) updSQL = getSQLiteStmt(updSQL)
@ -169,7 +302,7 @@ func (d *acmedb) Update(a ACMETxt) error {
return err return err
} }
defer sm.Close() defer sm.Close()
_, err = sm.Exec(a.Value, timenow, a.Username, a.Subdomain) _, err = sm.Exec(a.Value, timenow, a.Subdomain)
if err != nil { if err != nil {
return err return err
} }
@ -183,8 +316,6 @@ func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
&txt.Username, &txt.Username,
&txt.Password, &txt.Password,
&txt.Subdomain, &txt.Subdomain,
&txt.Value,
&txt.LastActive,
&afrom) &afrom)
if err != nil { if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error") log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")

View File

@ -118,7 +118,7 @@ func TestPrepareErrors(t *testing.T) {
t.Errorf("Expected error, but didn't get one") t.Errorf("Expected error, but didn't get one")
} }
_, err = DB.GetByDomain(reg.Subdomain) _, err = DB.GetTXTForDomain(reg.Subdomain)
if err == nil { if err == nil {
t.Errorf("Expected error, but didn't get one") t.Errorf("Expected error, but didn't get one")
} }
@ -151,7 +151,7 @@ func TestQueryExecErrors(t *testing.T) {
t.Errorf("Expected error from exec, but got none") t.Errorf("Expected error from exec, but got none")
} }
_, err = DB.GetByDomain(reg.Subdomain) _, err = DB.GetTXTForDomain(reg.Subdomain)
if err == nil { if err == nil {
t.Errorf("Expected error from exec in GetByDomain, but got none") t.Errorf("Expected error from exec in GetByDomain, but got none")
} }
@ -195,11 +195,6 @@ func TestQueryScanErrors(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("Expected error from scan in, but got none") t.Errorf("Expected error from scan in, but got none")
} }
_, err = DB.GetByDomain(reg.Subdomain)
if err == nil {
t.Errorf("Expected error from scan in GetByDomain, but got none")
}
} }
func TestBadDBValues(t *testing.T) { func TestBadDBValues(t *testing.T) {
@ -226,46 +221,55 @@ func TestBadDBValues(t *testing.T) {
t.Errorf("Expected error from scan in, but got none") t.Errorf("Expected error from scan in, but got none")
} }
_, err = DB.GetByDomain(reg.Subdomain) _, err = DB.GetTXTForDomain(reg.Subdomain)
if err == nil { if err == nil {
t.Errorf("Expected error from scan in GetByDomain, but got none") t.Errorf("Expected error from scan in GetByDomain, but got none")
} }
} }
func TestGetByDomain(t *testing.T) { func TestGetTXTForDomain(t *testing.T) {
var regDomain = ACMETxt{}
// Create reg to refer to // Create reg to refer to
reg, err := DB.Register(cidrslice{}) reg, err := DB.Register(cidrslice{})
if err != nil { if err != nil {
t.Errorf("Registration failed, got error [%v]", err) t.Errorf("Registration failed, got error [%v]", err)
} }
regDomainSlice, err := DB.GetByDomain(reg.Subdomain) txtval1 := "___validation_token_received_from_the_ca___"
txtval2 := "___validation_token_received_YEAH_the_ca___"
reg.Value = txtval1
_ = DB.Update(reg)
reg.Value = txtval2
_ = DB.Update(reg)
regDomainSlice, err := DB.GetTXTForDomain(reg.Subdomain)
if err != nil { if err != nil {
t.Errorf("Could not get test user, got error [%v]", err) t.Errorf("Could not get test user, got error [%v]", err)
} }
if len(regDomainSlice) == 0 { if len(regDomainSlice) == 0 {
t.Errorf("No rows returned for GetByDomain [%s]", reg.Subdomain) t.Errorf("No rows returned for GetTXTForDomain [%s]", reg.Subdomain)
} else {
regDomain = regDomainSlice[0]
} }
if reg.Username != regDomain.Username { var val1found = false
t.Errorf("GetByUsername username [%q] did not match the original [%q]", regDomain.Username, reg.Username) var val2found = false
for _, v := range regDomainSlice {
if v == txtval1 {
val1found = true
}
if v == txtval2 {
val2found = true
}
} }
if !val1found {
if reg.Subdomain != regDomain.Subdomain { t.Errorf("No TXT value found for val1")
t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regDomain.Subdomain, reg.Subdomain)
} }
if !val2found {
// regDomain password already is a bcrypt hash t.Errorf("No TXT value found for val2")
if !correctPassword(reg.Password, regDomain.Password) {
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password)
} }
// Not found // Not found
regNotfound, _ := DB.GetByDomain("does-not-exist") regNotfound, _ := DB.GetTXTForDomain("does-not-exist")
if len(regNotfound) > 0 { if len(regNotfound) > 0 {
t.Errorf("No records should be returned.") t.Errorf("No records should be returned.")
} }
@ -294,12 +298,4 @@ func TestUpdate(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("DB Update failed, got error: [%v]", err) t.Errorf("DB Update failed, got error: [%v]", err)
} }
updUser, err := DB.GetByUsername(regUser.Username)
if err != nil {
t.Errorf("GetByUsername threw error [%v]", err)
}
if updUser.Value != validTXT {
t.Errorf("Update failed, fetched value [%s] does not match the update value [%s]", updUser.Value, validTXT)
}
} }

8
dns.go
View File

@ -2,8 +2,8 @@ package main
import ( import (
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"strings" "strings"
"time" "time"
) )
@ -23,16 +23,16 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var ra []dns.RR var ra []dns.RR
rcode := dns.RcodeNameError rcode := dns.RcodeNameError
subdomain := sanitizeDomainQuestion(q.Name) subdomain := sanitizeDomainQuestion(q.Name)
atxt, err := DB.GetByDomain(subdomain) atxt, err := DB.GetTXTForDomain(subdomain)
if err != nil { if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
return ra, dns.RcodeNameError, err return ra, dns.RcodeNameError, err
} }
for _, v := range atxt { for _, v := range atxt {
if len(v.Value) > 0 { if len(v) > 0 {
r := new(dns.TXT) r := new(dns.TXT)
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
r.Txt = append(r.Txt, v.Value) r.Txt = append(r.Txt, v)
ra = append(ra, r) ra = append(ra, r)
rcode = dns.RcodeSuccess rcode = dns.RcodeSuccess
} }

View File

@ -36,6 +36,8 @@ func main() {
if err != nil { if err != nil {
log.Errorf("Could not open database [%v]", err) log.Errorf("Could not open database [%v]", err)
os.Exit(1) os.Exit(1)
} else {
log.Info("Connected to database")
} }
DB = newDB DB = newDB
defer DB.Close() defer DB.Close()

View File

@ -79,7 +79,7 @@ type database interface {
Init(string, string) error Init(string, string) error
Register(cidrslice) (ACMETxt, error) Register(cidrslice) (ACMETxt, error)
GetByUsername(uuid.UUID) (ACMETxt, error) GetByUsername(uuid.UUID) (ACMETxt, error)
GetByDomain(string) ([]ACMETxt, error) GetTXTForDomain(string) ([]string, error)
Update(ACMETxt) error Update(ACMETxt) error
GetBackend() *sql.DB GetBackend() *sql.DB
SetBackend(*sql.DB) SetBackend(*sql.DB)