diff --git a/main.go b/main.go index dbea682..1c217cd 100644 --- a/main.go +++ b/main.go @@ -16,13 +16,20 @@ import ( func main() { // Read global config - if fileExists("/etc/acme-dns/config.cfg") { - Config = readConfig("/etc/acme-dns/config.cfg") + var err error + if fileIsAccessible("/etc/acme-dns/config.cfg") { log.WithFields(log.Fields{"file": "/etc/acme-dns/config.cfg"}).Info("Using config file") - - } else { + Config, err = readConfig("/etc/acme-dns/config.cfg") + } else if fileIsAccessible("./config.cfg") { log.WithFields(log.Fields{"file": "./config.cfg"}).Info("Using config file") - Config = readConfig("config.cfg") + Config, err = readConfig("./config.cfg") + } else { + log.Errorf("Configuration file not found.") + os.Exit(1) + } + if err != nil { + log.Errorf("Encountered an error while trying to read configuration file: %s", err) + os.Exit(1) } setupLogging(Config.Logconfig.Format, Config.Logconfig.Level) @@ -32,7 +39,7 @@ func main() { // Open database newDB := new(acmedb) - err := newDB.Init(Config.Database.Engine, Config.Database.Connection) + err = newDB.Init(Config.Database.Engine, Config.Database.Connection) if err != nil { log.Errorf("Could not open database [%v]", err) os.Exit(1) diff --git a/util.go b/util.go index dcd1d15..675593e 100644 --- a/util.go +++ b/util.go @@ -17,19 +17,23 @@ func jsonError(message string) []byte { return []byte(fmt.Sprintf("{\"error\": \"%s\"}", message)) } -func fileExists(fname string) bool { +func fileIsAccessible(fname string) bool { _, err := os.Stat(fname) if err != nil { return false } + f, err := os.Open(fname) + if err != nil { + return false + } + f.Close() return true } -func readConfig(fname string) DNSConfig { +func readConfig(fname string) (DNSConfig, error) { var conf DNSConfig - // Practically never errors - _, _ = toml.DecodeFile(fname, &conf) - return conf + _, err := toml.DecodeFile(fname, &conf) + return conf, err } func sanitizeString(s string) string { diff --git a/util_test.go b/util_test.go index 0f9c17e..f7630c2 100644 --- a/util_test.go +++ b/util_test.go @@ -1,10 +1,12 @@ package main import ( - log "github.com/sirupsen/logrus" "io/ioutil" "os" + "syscall" "testing" + + log "github.com/sirupsen/logrus" ) func TestSetupLogging(t *testing.T) { @@ -62,7 +64,7 @@ func TestReadConfig(t *testing.T) { if err := tmpfile.Close(); err != nil { t.Error("Could not close temporary file") } - ret := readConfig(tmpfile.Name()) + ret, _ := readConfig(tmpfile.Name()) if ret.General.Listen != test.output.General.Listen { t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen) } @@ -95,3 +97,33 @@ func TestGetIPListFromHeader(t *testing.T) { } } } + +func TestFileCheckPermissionDenied(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "acmedns") + if err != nil { + t.Error("Could not create temporary file") + } + defer os.Remove(tmpfile.Name()) + syscall.Chmod(tmpfile.Name(), 0000) + if fileIsAccessible(tmpfile.Name()) { + t.Errorf("File should not be accessible") + } + syscall.Chmod(tmpfile.Name(), 0644) +} + +func TestFileCheckNotExists(t *testing.T) { + if fileIsAccessible("/path/that/does/not/exist") { + t.Errorf("File should not be accessible") + } +} + +func TestFileCheckOK(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "acmedns") + if err != nil { + t.Error("Could not create temporary file") + } + defer os.Remove(tmpfile.Name()) + if !fileIsAccessible(tmpfile.Name()) { + t.Errorf("File should be accessible") + } +}