Support for multiple TXT records per subdomain (#29)
* Support for multiple TXT records per subdomain and database upgrade functionality * Linter fixes * Make sure the database upgrade routine works for PostgreSQL * Move subdomain query outside of the upgrade transaction
This commit is contained in:
parent
ba695134ce
commit
733245fb3d
@ -12,7 +12,6 @@ type ACMETxt struct {
|
|||||||
Username uuid.UUID
|
Username uuid.UUID
|
||||||
Password string
|
Password string
|
||||||
ACMETxtPost
|
ACMETxtPost
|
||||||
LastActive int64
|
|
||||||
AllowFrom cidrslice
|
AllowFrom cidrslice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
193
db.go
193
db.go
@ -4,7 +4,9 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
@ -14,16 +16,38 @@ import (
|
|||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var recordsTable = `
|
// DBVersion shows the database version this code uses. This is used for update checks.
|
||||||
|
var DBVersion = 1
|
||||||
|
|
||||||
|
var acmeTable = `
|
||||||
|
CREATE TABLE IF NOT EXISTS acmedns(
|
||||||
|
Name TEXT,
|
||||||
|
Value TEXT
|
||||||
|
);`
|
||||||
|
|
||||||
|
var userTable = `
|
||||||
CREATE TABLE IF NOT EXISTS records(
|
CREATE TABLE IF NOT EXISTS records(
|
||||||
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
||||||
Password TEXT UNIQUE NOT NULL,
|
Password TEXT UNIQUE NOT NULL,
|
||||||
Subdomain TEXT UNIQUE NOT NULL,
|
Subdomain TEXT UNIQUE NOT NULL,
|
||||||
Value TEXT,
|
|
||||||
LastActive INT,
|
|
||||||
AllowFrom TEXT
|
AllowFrom TEXT
|
||||||
);`
|
);`
|
||||||
|
|
||||||
|
var txtTable = `
|
||||||
|
CREATE TABLE IF NOT EXISTS txt(
|
||||||
|
Subdomain TEXT NOT NULL,
|
||||||
|
Value TEXT NOT NULL DEFAULT '',
|
||||||
|
LastUpdate INT
|
||||||
|
);`
|
||||||
|
|
||||||
|
var txtTablePG = `
|
||||||
|
CREATE TABLE IF NOT EXISTS txt(
|
||||||
|
rowid SERIAL,
|
||||||
|
Subdomain TEXT NOT NULL,
|
||||||
|
Value TEXT NOT NULL DEFAULT '',
|
||||||
|
LastUpdate INT
|
||||||
|
);`
|
||||||
|
|
||||||
// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?"
|
// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?"
|
||||||
func getSQLiteStmt(s string) string {
|
func getSQLiteStmt(s string) string {
|
||||||
re, _ := regexp.Compile("\\$[0-9]")
|
re, _ := regexp.Compile("\\$[0-9]")
|
||||||
@ -38,44 +62,151 @@ func (d *acmedb) Init(engine string, connection string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
d.DB = db
|
d.DB = db
|
||||||
//d.DB.SetMaxOpenConns(1)
|
// Check version first to try to catch old versions without version string
|
||||||
_, err = d.DB.Exec(recordsTable)
|
var versionString string
|
||||||
|
_ = d.DB.QueryRow("SELECT Value FROM acmedns WHERE Name='db_version'").Scan(&versionString)
|
||||||
|
if versionString == "" {
|
||||||
|
versionString = "0"
|
||||||
|
}
|
||||||
|
_, err = d.DB.Exec(acmeTable)
|
||||||
|
_, err = d.DB.Exec(userTable)
|
||||||
|
if Config.Database.Engine == "sqlite3" {
|
||||||
|
_, err = d.DB.Exec(txtTable)
|
||||||
|
} else {
|
||||||
|
_, err = d.DB.Exec(txtTablePG)
|
||||||
|
}
|
||||||
|
// If everything is fine, handle db upgrade tasks
|
||||||
|
if err == nil {
|
||||||
|
err = d.checkDBUpgrades(versionString)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
if versionString == "0" {
|
||||||
|
// No errors so we should now be in version 1
|
||||||
|
insversion := fmt.Sprintf("INSERT INTO acmedns (Name, Value) values('db_version', '%d')", DBVersion)
|
||||||
|
_, err = db.Exec(insversion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *acmedb) checkDBUpgrades(versionString string) error {
|
||||||
|
var err error
|
||||||
|
version, err := strconv.Atoi(versionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if version != DBVersion {
|
||||||
|
return d.handleDBUpgrades(version)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *acmedb) handleDBUpgrades(version int) error {
|
||||||
|
if version == 0 {
|
||||||
|
return d.handleDBUpgradeTo1()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *acmedb) handleDBUpgradeTo1() error {
|
||||||
|
var err error
|
||||||
|
var subdomains []string
|
||||||
|
rows, err := d.DB.Query("SELECT Subdomain FROM records")
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var subdomain string
|
||||||
|
err = rows.Scan(&subdomain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while reading values")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
subdomains = append(subdomains, subdomain)
|
||||||
|
}
|
||||||
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tx, err := d.DB.Begin()
|
||||||
|
// Rollback if errored, commit if not
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
}()
|
||||||
|
_, _ = tx.Exec("DELETE FROM txt")
|
||||||
|
for _, subdomain := range subdomains {
|
||||||
|
if subdomain != "" {
|
||||||
|
// Insert two rows for each subdomain to txt table
|
||||||
|
err = d.NewTXTValuesInTransaction(tx, subdomain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// SQLite doesn't support dropping columns
|
||||||
|
if Config.Database.Engine != "sqlite3" {
|
||||||
|
_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS Value")
|
||||||
|
_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS LastActive")
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("UPDATE acmedns SET Value='1' WHERE Name='db_version'")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create two rows for subdomain to the txt table
|
||||||
|
func (d *acmedb) NewTXTValuesInTransaction(tx *sql.Tx, subdomain string) error {
|
||||||
|
var err error
|
||||||
|
instr := fmt.Sprintf("INSERT INTO txt (Subdomain, LastUpdate) values('%s', 0)", subdomain)
|
||||||
|
_, err = tx.Exec(instr)
|
||||||
|
_, err = tx.Exec(instr)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
|
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
|
var err error
|
||||||
|
tx, err := d.DB.Begin()
|
||||||
|
// Rollback if errored, commit if not
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tx.Commit()
|
||||||
|
}()
|
||||||
a := newACMETxt()
|
a := newACMETxt()
|
||||||
a.AllowFrom = cidrslice(afrom.ValidEntries())
|
a.AllowFrom = cidrslice(afrom.ValidEntries())
|
||||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
||||||
timenow := time.Now().Unix()
|
|
||||||
regSQL := `
|
regSQL := `
|
||||||
INSERT INTO records(
|
INSERT INTO records(
|
||||||
Username,
|
Username,
|
||||||
Password,
|
Password,
|
||||||
Subdomain,
|
Subdomain,
|
||||||
Value,
|
|
||||||
LastActive,
|
|
||||||
AllowFrom)
|
AllowFrom)
|
||||||
values($1, $2, $3, '', $4, $5)`
|
values($1, $2, $3, $4)`
|
||||||
if Config.Database.Engine == "sqlite3" {
|
if Config.Database.Engine == "sqlite3" {
|
||||||
regSQL = getSQLiteStmt(regSQL)
|
regSQL = getSQLiteStmt(regSQL)
|
||||||
}
|
}
|
||||||
sm, err := d.DB.Prepare(regSQL)
|
sm, err := tx.Prepare(regSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
|
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
|
||||||
return a, errors.New("SQL error")
|
return a, errors.New("SQL error")
|
||||||
}
|
}
|
||||||
defer sm.Close()
|
defer sm.Close()
|
||||||
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON())
|
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, a.AllowFrom.JSON())
|
||||||
if err != nil {
|
if err == nil {
|
||||||
return a, err
|
err = d.NewTXTValuesInTransaction(tx, a.Subdomain)
|
||||||
}
|
}
|
||||||
return a, nil
|
return a, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||||
@ -83,7 +214,7 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
|||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
var results []ACMETxt
|
var results []ACMETxt
|
||||||
getSQL := `
|
getSQL := `
|
||||||
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
|
SELECT Username, Password, Subdomain, AllowFrom
|
||||||
FROM records
|
FROM records
|
||||||
WHERE Username=$1 LIMIT 1
|
WHERE Username=$1 LIMIT 1
|
||||||
`
|
`
|
||||||
@ -116,15 +247,13 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
|||||||
return ACMETxt{}, errors.New("no user")
|
return ACMETxt{}, errors.New("no user")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
|
func (d *acmedb) GetTXTForDomain(domain string) ([]string, error) {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
domain = sanitizeString(domain)
|
domain = sanitizeString(domain)
|
||||||
var a []ACMETxt
|
var txts []string
|
||||||
getSQL := `
|
getSQL := `
|
||||||
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
|
SELECT Value FROM txt WHERE Subdomain=$1 LIMIT 2
|
||||||
FROM records
|
|
||||||
WHERE Subdomain=$1 LIMIT 1
|
|
||||||
`
|
`
|
||||||
if Config.Database.Engine == "sqlite3" {
|
if Config.Database.Engine == "sqlite3" {
|
||||||
getSQL = getSQLiteStmt(getSQL)
|
getSQL = getSQLiteStmt(getSQL)
|
||||||
@ -132,33 +261,37 @@ func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
|
|||||||
|
|
||||||
sm, err := d.DB.Prepare(getSQL)
|
sm, err := d.DB.Prepare(getSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return a, err
|
return txts, err
|
||||||
}
|
}
|
||||||
defer sm.Close()
|
defer sm.Close()
|
||||||
rows, err := sm.Query(domain)
|
rows, err := sm.Query(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return a, err
|
return txts, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
txt, err := getModelFromRow(rows)
|
var rtxt string
|
||||||
|
err = rows.Scan(&rtxt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return a, err
|
return txts, err
|
||||||
}
|
}
|
||||||
a = append(a, txt)
|
txts = append(txts, rtxt)
|
||||||
}
|
}
|
||||||
return a, nil
|
return txts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *acmedb) Update(a ACMETxt) error {
|
func (d *acmedb) Update(a ACMETxt) error {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
|
var err error
|
||||||
// Data in a is already sanitized
|
// Data in a is already sanitized
|
||||||
timenow := time.Now().Unix()
|
timenow := time.Now().Unix()
|
||||||
|
|
||||||
updSQL := `
|
updSQL := `
|
||||||
UPDATE records SET Value=$1, LastActive=$2
|
UPDATE txt SET Value=$1, LastUpdate=$2
|
||||||
WHERE Username=$3 AND Subdomain=$4
|
WHERE rowid=(
|
||||||
|
SELECT rowid FROM txt WHERE Subdomain=$3 ORDER BY LastUpdate LIMIT 1)
|
||||||
`
|
`
|
||||||
if Config.Database.Engine == "sqlite3" {
|
if Config.Database.Engine == "sqlite3" {
|
||||||
updSQL = getSQLiteStmt(updSQL)
|
updSQL = getSQLiteStmt(updSQL)
|
||||||
@ -169,7 +302,7 @@ func (d *acmedb) Update(a ACMETxt) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer sm.Close()
|
defer sm.Close()
|
||||||
_, err = sm.Exec(a.Value, timenow, a.Username, a.Subdomain)
|
_, err = sm.Exec(a.Value, timenow, a.Subdomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -183,8 +316,6 @@ func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
|
|||||||
&txt.Username,
|
&txt.Username,
|
||||||
&txt.Password,
|
&txt.Password,
|
||||||
&txt.Subdomain,
|
&txt.Subdomain,
|
||||||
&txt.Value,
|
|
||||||
&txt.LastActive,
|
|
||||||
&afrom)
|
&afrom)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
|
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
|
||||||
|
|||||||
62
db_test.go
62
db_test.go
@ -118,7 +118,7 @@ func TestPrepareErrors(t *testing.T) {
|
|||||||
t.Errorf("Expected error, but didn't get one")
|
t.Errorf("Expected error, but didn't get one")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = DB.GetByDomain(reg.Subdomain)
|
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error, but didn't get one")
|
t.Errorf("Expected error, but didn't get one")
|
||||||
}
|
}
|
||||||
@ -151,7 +151,7 @@ func TestQueryExecErrors(t *testing.T) {
|
|||||||
t.Errorf("Expected error from exec, but got none")
|
t.Errorf("Expected error from exec, but got none")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = DB.GetByDomain(reg.Subdomain)
|
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
||||||
}
|
}
|
||||||
@ -195,11 +195,6 @@ func TestQueryScanErrors(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error from scan in, but got none")
|
t.Errorf("Expected error from scan in, but got none")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = DB.GetByDomain(reg.Subdomain)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBadDBValues(t *testing.T) {
|
func TestBadDBValues(t *testing.T) {
|
||||||
@ -226,46 +221,55 @@ func TestBadDBValues(t *testing.T) {
|
|||||||
t.Errorf("Expected error from scan in, but got none")
|
t.Errorf("Expected error from scan in, but got none")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = DB.GetByDomain(reg.Subdomain)
|
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetByDomain(t *testing.T) {
|
func TestGetTXTForDomain(t *testing.T) {
|
||||||
var regDomain = ACMETxt{}
|
|
||||||
|
|
||||||
// Create reg to refer to
|
// Create reg to refer to
|
||||||
reg, err := DB.Register(cidrslice{})
|
reg, err := DB.Register(cidrslice{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Registration failed, got error [%v]", err)
|
t.Errorf("Registration failed, got error [%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
regDomainSlice, err := DB.GetByDomain(reg.Subdomain)
|
txtval1 := "___validation_token_received_from_the_ca___"
|
||||||
|
txtval2 := "___validation_token_received_YEAH_the_ca___"
|
||||||
|
|
||||||
|
reg.Value = txtval1
|
||||||
|
_ = DB.Update(reg)
|
||||||
|
|
||||||
|
reg.Value = txtval2
|
||||||
|
_ = DB.Update(reg)
|
||||||
|
|
||||||
|
regDomainSlice, err := DB.GetTXTForDomain(reg.Subdomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not get test user, got error [%v]", err)
|
t.Errorf("Could not get test user, got error [%v]", err)
|
||||||
}
|
}
|
||||||
if len(regDomainSlice) == 0 {
|
if len(regDomainSlice) == 0 {
|
||||||
t.Errorf("No rows returned for GetByDomain [%s]", reg.Subdomain)
|
t.Errorf("No rows returned for GetTXTForDomain [%s]", reg.Subdomain)
|
||||||
} else {
|
|
||||||
regDomain = regDomainSlice[0]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if reg.Username != regDomain.Username {
|
var val1found = false
|
||||||
t.Errorf("GetByUsername username [%q] did not match the original [%q]", regDomain.Username, reg.Username)
|
var val2found = false
|
||||||
|
for _, v := range regDomainSlice {
|
||||||
|
if v == txtval1 {
|
||||||
|
val1found = true
|
||||||
}
|
}
|
||||||
|
if v == txtval2 {
|
||||||
if reg.Subdomain != regDomain.Subdomain {
|
val2found = true
|
||||||
t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regDomain.Subdomain, reg.Subdomain)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// regDomain password already is a bcrypt hash
|
if !val1found {
|
||||||
if !correctPassword(reg.Password, regDomain.Password) {
|
t.Errorf("No TXT value found for val1")
|
||||||
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password)
|
}
|
||||||
|
if !val2found {
|
||||||
|
t.Errorf("No TXT value found for val2")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not found
|
// Not found
|
||||||
regNotfound, _ := DB.GetByDomain("does-not-exist")
|
regNotfound, _ := DB.GetTXTForDomain("does-not-exist")
|
||||||
if len(regNotfound) > 0 {
|
if len(regNotfound) > 0 {
|
||||||
t.Errorf("No records should be returned.")
|
t.Errorf("No records should be returned.")
|
||||||
}
|
}
|
||||||
@ -294,12 +298,4 @@ func TestUpdate(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("DB Update failed, got error: [%v]", err)
|
t.Errorf("DB Update failed, got error: [%v]", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updUser, err := DB.GetByUsername(regUser.Username)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("GetByUsername threw error [%v]", err)
|
|
||||||
}
|
|
||||||
if updUser.Value != validTXT {
|
|
||||||
t.Errorf("Update failed, fetched value [%s] does not match the update value [%s]", updUser.Value, validTXT)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
8
dns.go
8
dns.go
@ -2,8 +2,8 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -23,16 +23,16 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
|||||||
var ra []dns.RR
|
var ra []dns.RR
|
||||||
rcode := dns.RcodeNameError
|
rcode := dns.RcodeNameError
|
||||||
subdomain := sanitizeDomainQuestion(q.Name)
|
subdomain := sanitizeDomainQuestion(q.Name)
|
||||||
atxt, err := DB.GetByDomain(subdomain)
|
atxt, err := DB.GetTXTForDomain(subdomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
||||||
return ra, dns.RcodeNameError, err
|
return ra, dns.RcodeNameError, err
|
||||||
}
|
}
|
||||||
for _, v := range atxt {
|
for _, v := range atxt {
|
||||||
if len(v.Value) > 0 {
|
if len(v) > 0 {
|
||||||
r := new(dns.TXT)
|
r := new(dns.TXT)
|
||||||
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||||
r.Txt = append(r.Txt, v.Value)
|
r.Txt = append(r.Txt, v)
|
||||||
ra = append(ra, r)
|
ra = append(ra, r)
|
||||||
rcode = dns.RcodeSuccess
|
rcode = dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
|
|||||||
2
main.go
2
main.go
@ -36,6 +36,8 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Could not open database [%v]", err)
|
log.Errorf("Could not open database [%v]", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
} else {
|
||||||
|
log.Info("Connected to database")
|
||||||
}
|
}
|
||||||
DB = newDB
|
DB = newDB
|
||||||
defer DB.Close()
|
defer DB.Close()
|
||||||
|
|||||||
2
types.go
2
types.go
@ -79,7 +79,7 @@ type database interface {
|
|||||||
Init(string, string) error
|
Init(string, string) error
|
||||||
Register(cidrslice) (ACMETxt, error)
|
Register(cidrslice) (ACMETxt, error)
|
||||||
GetByUsername(uuid.UUID) (ACMETxt, error)
|
GetByUsername(uuid.UUID) (ACMETxt, error)
|
||||||
GetByDomain(string) ([]ACMETxt, error)
|
GetTXTForDomain(string) ([]string, error)
|
||||||
Update(ACMETxt) error
|
Update(ACMETxt) error
|
||||||
GetBackend() *sql.DB
|
GetBackend() *sql.DB
|
||||||
SetBackend(*sql.DB)
|
SetBackend(*sql.DB)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user