From 0de310503af9de1528ef28f18dc569fbcd61baef Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 4 Jul 2019 13:47:18 +0200 Subject: [PATCH] Add util function for creating/checking dirs --- database/location.go | 36 +++--------------------------------- database/main.go | 3 ++- log/input.go | 5 ++++- utils/fs.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 35 deletions(-) create mode 100644 utils/fs.go diff --git a/database/location.go b/database/location.go index 5d1a1be..7ca9777 100644 --- a/database/location.go +++ b/database/location.go @@ -1,11 +1,10 @@ package database import ( - "errors" "fmt" - "os" "path/filepath" - "runtime" + + "github.com/safing/portbase/utils" ) const ( @@ -16,35 +15,6 @@ var ( rootDir string ) -func ensureDirectory(dirPath string, permissions os.FileMode) error { - // open dir - dir, err := os.Open(dirPath) - if err != nil { - if os.IsNotExist(err) { - return os.MkdirAll(dirPath, permissions) - } - return err - } - defer dir.Close() - - fileInfo, err := dir.Stat() - if err != nil { - return err - } - if !fileInfo.IsDir() { - return errors.New("path exists and is not a directory") - } - - if runtime.GOOS == "windows" { - // TODO - // acl.Chmod(dirPath, permissions) - } else if fileInfo.Mode().Perm() != permissions { - return dir.Chmod(permissions) - } - - return nil -} - // GetDatabaseRoot returns the root directory of the database. func GetDatabaseRoot() string { return rootDir @@ -55,7 +25,7 @@ func getLocation(name, storageType string) (string, error) { location := filepath.Join(rootDir, databasesSubDir, name, storageType) // check location - err := ensureDirectory(location, 0700) + err := utils.EnsureDirectory(location, 0700) if err != nil { return "", fmt.Errorf("location (%s) invalid: %s", location, err) } diff --git a/database/main.go b/database/main.go index 988b826..f918a3b 100644 --- a/database/main.go +++ b/database/main.go @@ -5,6 +5,7 @@ import ( "fmt" "path/filepath" + "github.com/safing/portbase/utils" "github.com/tevino/abool" ) @@ -28,7 +29,7 @@ func SetLocation(location string) (ok bool) { func Initialize() error { if initialized.SetToIf(false, true) { - err := ensureDirectory(rootDir, 0755) + err := utils.EnsureDirectory(rootDir, 0755) if err != nil { return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err) } diff --git a/log/input.go b/log/input.go index beea9a7..6c64b0c 100644 --- a/log/input.go +++ b/log/input.go @@ -56,7 +56,10 @@ func log(level severity, msg string, trace *ContextTracer) { // check if level is enabled for file or generally if fileLevelsActive.IsSet() { fileOnly := strings.Split(file, "/") - sev, ok := fileLevels[fileOnly[len(fileOnly)-1]] + if len(fileOnly) < 2 { + return + } + sev, ok := fileLevels[fileOnly[len(fileOnly)-2]] if ok { if level < sev { return diff --git a/utils/fs.go b/utils/fs.go new file mode 100644 index 0000000..46f67b3 --- /dev/null +++ b/utils/fs.go @@ -0,0 +1,42 @@ +package utils + +import ( + "fmt" + "os" + "runtime" +) + +// EnsureDirectory ensures that the given directoy exists and that is has the given permissions set. +// If path is a file, it is deleted and a directory created. +// If a directory is created, also all missing directories up to the required one are created with the given permissions. +func EnsureDirectory(path string, perm os.FileMode) error { + // open path + f, err := os.Stat(path) + if err == nil { + // file exists + if f.IsDir() { + // directory exists, check permissions + if runtime.GOOS == "windows" { + // TODO: set correct permission on windows + // acl.Chmod(path, perm) + } else if f.Mode().Perm() != perm { + return os.Chmod(path, perm) + } + return nil + } + err = os.Remove(path) + if err != nil { + return fmt.Errorf("could not remove file %s to place dir: %s", path, err) + } + } + // file does not exist (or has been deleted) + if err == nil || os.IsNotExist(err) { + err = os.MkdirAll(path, perm) + if err != nil { + return fmt.Errorf("could not create dir %s: %s", path, err) + } + return nil + } + // other error opening path + return fmt.Errorf("failed to access %s: %s", path, err) +}