Support for multiple TXT records per subdomain (#29)
* Support for multiple TXT records per subdomain and database upgrade functionality * Linter fixes * Make sure the database upgrade routine works for PostgreSQL * Move subdomain query outside of the upgrade transaction
This commit is contained in:
parent
ba695134ce
commit
733245fb3d
@ -12,8 +12,7 @@ type ACMETxt struct {
|
||||
Username uuid.UUID
|
||||
Password string
|
||||
ACMETxtPost
|
||||
LastActive int64
|
||||
AllowFrom cidrslice
|
||||
AllowFrom cidrslice
|
||||
}
|
||||
|
||||
// ACMETxtPost holds the DNS part of the ACMETxt struct
|
||||
|
||||
193
db.go
193
db.go
@ -4,7 +4,9 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
@ -14,16 +16,38 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var recordsTable = `
|
||||
// DBVersion shows the database version this code uses. This is used for update checks.
|
||||
var DBVersion = 1
|
||||
|
||||
var acmeTable = `
|
||||
CREATE TABLE IF NOT EXISTS acmedns(
|
||||
Name TEXT,
|
||||
Value TEXT
|
||||
);`
|
||||
|
||||
var userTable = `
|
||||
CREATE TABLE IF NOT EXISTS records(
|
||||
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
||||
Password TEXT UNIQUE NOT NULL,
|
||||
Subdomain TEXT UNIQUE NOT NULL,
|
||||
Value TEXT,
|
||||
LastActive INT,
|
||||
AllowFrom TEXT
|
||||
);`
|
||||
|
||||
var txtTable = `
|
||||
CREATE TABLE IF NOT EXISTS txt(
|
||||
Subdomain TEXT NOT NULL,
|
||||
Value TEXT NOT NULL DEFAULT '',
|
||||
LastUpdate INT
|
||||
);`
|
||||
|
||||
var txtTablePG = `
|
||||
CREATE TABLE IF NOT EXISTS txt(
|
||||
rowid SERIAL,
|
||||
Subdomain TEXT NOT NULL,
|
||||
Value TEXT NOT NULL DEFAULT '',
|
||||
LastUpdate INT
|
||||
);`
|
||||
|
||||
// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?"
|
||||
func getSQLiteStmt(s string) string {
|
||||
re, _ := regexp.Compile("\\$[0-9]")
|
||||
@ -38,44 +62,151 @@ func (d *acmedb) Init(engine string, connection string) error {
|
||||
return err
|
||||
}
|
||||
d.DB = db
|
||||
//d.DB.SetMaxOpenConns(1)
|
||||
_, err = d.DB.Exec(recordsTable)
|
||||
// Check version first to try to catch old versions without version string
|
||||
var versionString string
|
||||
_ = d.DB.QueryRow("SELECT Value FROM acmedns WHERE Name='db_version'").Scan(&versionString)
|
||||
if versionString == "" {
|
||||
versionString = "0"
|
||||
}
|
||||
_, err = d.DB.Exec(acmeTable)
|
||||
_, err = d.DB.Exec(userTable)
|
||||
if Config.Database.Engine == "sqlite3" {
|
||||
_, err = d.DB.Exec(txtTable)
|
||||
} else {
|
||||
_, err = d.DB.Exec(txtTablePG)
|
||||
}
|
||||
// If everything is fine, handle db upgrade tasks
|
||||
if err == nil {
|
||||
err = d.checkDBUpgrades(versionString)
|
||||
}
|
||||
if err == nil {
|
||||
if versionString == "0" {
|
||||
// No errors so we should now be in version 1
|
||||
insversion := fmt.Sprintf("INSERT INTO acmedns (Name, Value) values('db_version', '%d')", DBVersion)
|
||||
_, err = db.Exec(insversion)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *acmedb) checkDBUpgrades(versionString string) error {
|
||||
var err error
|
||||
version, err := strconv.Atoi(versionString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if version != DBVersion {
|
||||
return d.handleDBUpgrades(version)
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (d *acmedb) handleDBUpgrades(version int) error {
|
||||
if version == 0 {
|
||||
return d.handleDBUpgradeTo1()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *acmedb) handleDBUpgradeTo1() error {
|
||||
var err error
|
||||
var subdomains []string
|
||||
rows, err := d.DB.Query("SELECT Subdomain FROM records")
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade")
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var subdomain string
|
||||
err = rows.Scan(&subdomain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while reading values")
|
||||
return err
|
||||
}
|
||||
subdomains = append(subdomains, subdomain)
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
|
||||
return err
|
||||
}
|
||||
tx, err := d.DB.Begin()
|
||||
// Rollback if errored, commit if not
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
}()
|
||||
_, _ = tx.Exec("DELETE FROM txt")
|
||||
for _, subdomain := range subdomains {
|
||||
if subdomain != "" {
|
||||
// Insert two rows for each subdomain to txt table
|
||||
err = d.NewTXTValuesInTransaction(tx, subdomain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
// SQLite doesn't support dropping columns
|
||||
if Config.Database.Engine != "sqlite3" {
|
||||
_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS Value")
|
||||
_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS LastActive")
|
||||
}
|
||||
_, err = tx.Exec("UPDATE acmedns SET Value='1' WHERE Name='db_version'")
|
||||
return err
|
||||
}
|
||||
|
||||
// Create two rows for subdomain to the txt table
|
||||
func (d *acmedb) NewTXTValuesInTransaction(tx *sql.Tx, subdomain string) error {
|
||||
var err error
|
||||
instr := fmt.Sprintf("INSERT INTO txt (Subdomain, LastUpdate) values('%s', 0)", subdomain)
|
||||
_, err = tx.Exec(instr)
|
||||
_, err = tx.Exec(instr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
var err error
|
||||
tx, err := d.DB.Begin()
|
||||
// Rollback if errored, commit if not
|
||||
defer func() {
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
}()
|
||||
a := newACMETxt()
|
||||
a.AllowFrom = cidrslice(afrom.ValidEntries())
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
||||
timenow := time.Now().Unix()
|
||||
regSQL := `
|
||||
INSERT INTO records(
|
||||
Username,
|
||||
Password,
|
||||
Subdomain,
|
||||
Value,
|
||||
LastActive,
|
||||
AllowFrom)
|
||||
values($1, $2, $3, '', $4, $5)`
|
||||
values($1, $2, $3, $4)`
|
||||
if Config.Database.Engine == "sqlite3" {
|
||||
regSQL = getSQLiteStmt(regSQL)
|
||||
}
|
||||
sm, err := d.DB.Prepare(regSQL)
|
||||
sm, err := tx.Prepare(regSQL)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
|
||||
return a, errors.New("SQL error")
|
||||
}
|
||||
defer sm.Close()
|
||||
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON())
|
||||
if err != nil {
|
||||
return a, err
|
||||
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, a.AllowFrom.JSON())
|
||||
if err == nil {
|
||||
err = d.NewTXTValuesInTransaction(tx, a.Subdomain)
|
||||
}
|
||||
return a, nil
|
||||
return a, err
|
||||
}
|
||||
|
||||
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
@ -83,7 +214,7 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
defer d.Unlock()
|
||||
var results []ACMETxt
|
||||
getSQL := `
|
||||
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
|
||||
SELECT Username, Password, Subdomain, AllowFrom
|
||||
FROM records
|
||||
WHERE Username=$1 LIMIT 1
|
||||
`
|
||||
@ -116,15 +247,13 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
return ACMETxt{}, errors.New("no user")
|
||||
}
|
||||
|
||||
func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
func (d *acmedb) GetTXTForDomain(domain string) ([]string, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
domain = sanitizeString(domain)
|
||||
var a []ACMETxt
|
||||
var txts []string
|
||||
getSQL := `
|
||||
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
|
||||
FROM records
|
||||
WHERE Subdomain=$1 LIMIT 1
|
||||
SELECT Value FROM txt WHERE Subdomain=$1 LIMIT 2
|
||||
`
|
||||
if Config.Database.Engine == "sqlite3" {
|
||||
getSQL = getSQLiteStmt(getSQL)
|
||||
@ -132,33 +261,37 @@ func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
|
||||
sm, err := d.DB.Prepare(getSQL)
|
||||
if err != nil {
|
||||
return a, err
|
||||
return txts, err
|
||||
}
|
||||
defer sm.Close()
|
||||
rows, err := sm.Query(domain)
|
||||
if err != nil {
|
||||
return a, err
|
||||
return txts, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
txt, err := getModelFromRow(rows)
|
||||
var rtxt string
|
||||
err = rows.Scan(&rtxt)
|
||||
if err != nil {
|
||||
return a, err
|
||||
return txts, err
|
||||
}
|
||||
a = append(a, txt)
|
||||
txts = append(txts, rtxt)
|
||||
}
|
||||
return a, nil
|
||||
return txts, nil
|
||||
}
|
||||
|
||||
func (d *acmedb) Update(a ACMETxt) error {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
var err error
|
||||
// Data in a is already sanitized
|
||||
timenow := time.Now().Unix()
|
||||
|
||||
updSQL := `
|
||||
UPDATE records SET Value=$1, LastActive=$2
|
||||
WHERE Username=$3 AND Subdomain=$4
|
||||
UPDATE txt SET Value=$1, LastUpdate=$2
|
||||
WHERE rowid=(
|
||||
SELECT rowid FROM txt WHERE Subdomain=$3 ORDER BY LastUpdate LIMIT 1)
|
||||
`
|
||||
if Config.Database.Engine == "sqlite3" {
|
||||
updSQL = getSQLiteStmt(updSQL)
|
||||
@ -169,7 +302,7 @@ func (d *acmedb) Update(a ACMETxt) error {
|
||||
return err
|
||||
}
|
||||
defer sm.Close()
|
||||
_, err = sm.Exec(a.Value, timenow, a.Username, a.Subdomain)
|
||||
_, err = sm.Exec(a.Value, timenow, a.Subdomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -183,8 +316,6 @@ func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
|
||||
&txt.Username,
|
||||
&txt.Password,
|
||||
&txt.Subdomain,
|
||||
&txt.Value,
|
||||
&txt.LastActive,
|
||||
&afrom)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
|
||||
|
||||
62
db_test.go
62
db_test.go
@ -118,7 +118,7 @@ func TestPrepareErrors(t *testing.T) {
|
||||
t.Errorf("Expected error, but didn't get one")
|
||||
}
|
||||
|
||||
_, err = DB.GetByDomain(reg.Subdomain)
|
||||
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, but didn't get one")
|
||||
}
|
||||
@ -151,7 +151,7 @@ func TestQueryExecErrors(t *testing.T) {
|
||||
t.Errorf("Expected error from exec, but got none")
|
||||
}
|
||||
|
||||
_, err = DB.GetByDomain(reg.Subdomain)
|
||||
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
||||
}
|
||||
@ -195,11 +195,6 @@ func TestQueryScanErrors(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from scan in, but got none")
|
||||
}
|
||||
|
||||
_, err = DB.GetByDomain(reg.Subdomain)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadDBValues(t *testing.T) {
|
||||
@ -226,46 +221,55 @@ func TestBadDBValues(t *testing.T) {
|
||||
t.Errorf("Expected error from scan in, but got none")
|
||||
}
|
||||
|
||||
_, err = DB.GetByDomain(reg.Subdomain)
|
||||
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByDomain(t *testing.T) {
|
||||
var regDomain = ACMETxt{}
|
||||
|
||||
func TestGetTXTForDomain(t *testing.T) {
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
|
||||
regDomainSlice, err := DB.GetByDomain(reg.Subdomain)
|
||||
txtval1 := "___validation_token_received_from_the_ca___"
|
||||
txtval2 := "___validation_token_received_YEAH_the_ca___"
|
||||
|
||||
reg.Value = txtval1
|
||||
_ = DB.Update(reg)
|
||||
|
||||
reg.Value = txtval2
|
||||
_ = DB.Update(reg)
|
||||
|
||||
regDomainSlice, err := DB.GetTXTForDomain(reg.Subdomain)
|
||||
if err != nil {
|
||||
t.Errorf("Could not get test user, got error [%v]", err)
|
||||
}
|
||||
if len(regDomainSlice) == 0 {
|
||||
t.Errorf("No rows returned for GetByDomain [%s]", reg.Subdomain)
|
||||
} else {
|
||||
regDomain = regDomainSlice[0]
|
||||
t.Errorf("No rows returned for GetTXTForDomain [%s]", reg.Subdomain)
|
||||
}
|
||||
|
||||
if reg.Username != regDomain.Username {
|
||||
t.Errorf("GetByUsername username [%q] did not match the original [%q]", regDomain.Username, reg.Username)
|
||||
var val1found = false
|
||||
var val2found = false
|
||||
for _, v := range regDomainSlice {
|
||||
if v == txtval1 {
|
||||
val1found = true
|
||||
}
|
||||
if v == txtval2 {
|
||||
val2found = true
|
||||
}
|
||||
}
|
||||
|
||||
if reg.Subdomain != regDomain.Subdomain {
|
||||
t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regDomain.Subdomain, reg.Subdomain)
|
||||
if !val1found {
|
||||
t.Errorf("No TXT value found for val1")
|
||||
}
|
||||
|
||||
// regDomain password already is a bcrypt hash
|
||||
if !correctPassword(reg.Password, regDomain.Password) {
|
||||
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password)
|
||||
if !val2found {
|
||||
t.Errorf("No TXT value found for val2")
|
||||
}
|
||||
|
||||
// Not found
|
||||
regNotfound, _ := DB.GetByDomain("does-not-exist")
|
||||
regNotfound, _ := DB.GetTXTForDomain("does-not-exist")
|
||||
if len(regNotfound) > 0 {
|
||||
t.Errorf("No records should be returned.")
|
||||
}
|
||||
@ -294,12 +298,4 @@ func TestUpdate(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("DB Update failed, got error: [%v]", err)
|
||||
}
|
||||
|
||||
updUser, err := DB.GetByUsername(regUser.Username)
|
||||
if err != nil {
|
||||
t.Errorf("GetByUsername threw error [%v]", err)
|
||||
}
|
||||
if updUser.Value != validTXT {
|
||||
t.Errorf("Update failed, fetched value [%s] does not match the update value [%s]", updUser.Value, validTXT)
|
||||
}
|
||||
}
|
||||
|
||||
8
dns.go
8
dns.go
@ -2,8 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@ -23,16 +23,16 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
||||
var ra []dns.RR
|
||||
rcode := dns.RcodeNameError
|
||||
subdomain := sanitizeDomainQuestion(q.Name)
|
||||
atxt, err := DB.GetByDomain(subdomain)
|
||||
atxt, err := DB.GetTXTForDomain(subdomain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
||||
return ra, dns.RcodeNameError, err
|
||||
}
|
||||
for _, v := range atxt {
|
||||
if len(v.Value) > 0 {
|
||||
if len(v) > 0 {
|
||||
r := new(dns.TXT)
|
||||
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||
r.Txt = append(r.Txt, v.Value)
|
||||
r.Txt = append(r.Txt, v)
|
||||
ra = append(ra, r)
|
||||
rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
2
main.go
2
main.go
@ -36,6 +36,8 @@ func main() {
|
||||
if err != nil {
|
||||
log.Errorf("Could not open database [%v]", err)
|
||||
os.Exit(1)
|
||||
} else {
|
||||
log.Info("Connected to database")
|
||||
}
|
||||
DB = newDB
|
||||
defer DB.Close()
|
||||
|
||||
2
types.go
2
types.go
@ -79,7 +79,7 @@ type database interface {
|
||||
Init(string, string) error
|
||||
Register(cidrslice) (ACMETxt, error)
|
||||
GetByUsername(uuid.UUID) (ACMETxt, error)
|
||||
GetByDomain(string) ([]ACMETxt, error)
|
||||
GetTXTForDomain(string) ([]string, error)
|
||||
Update(ACMETxt) error
|
||||
GetBackend() *sql.DB
|
||||
SetBackend(*sql.DB)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user