diff --git a/api/main.go b/api/main.go index b522788..73955a8 100644 --- a/api/main.go +++ b/api/main.go @@ -13,7 +13,7 @@ var ( ) func init() { - modules.Register("api", prep, start, nil, "database") + modules.Register("api", prep, start, nil, "core") } func prep() error { diff --git a/config/main.go b/config/main.go index 52a2a4d..2c5003e 100644 --- a/config/main.go +++ b/config/main.go @@ -1,23 +1,40 @@ package config import ( + "errors" "os" - "path" + "path/filepath" - "github.com/safing/portbase/database" "github.com/safing/portbase/modules" + "github.com/safing/portbase/utils" + "github.com/safing/portmaster/core/structure" ) +var ( + dataRoot *utils.DirStructure +) + +// SetDataRoot sets the data root from which the updates module derives its paths. +func SetDataRoot(root *utils.DirStructure) { + if dataRoot == nil { + dataRoot = root + } +} + func init() { - modules.Register("config", prep, start, nil, "database") + modules.Register("config", prep, start, nil, "core") } func prep() error { + SetDataRoot(structure.Root()) + if dataRoot == nil { + return errors.New("data root is not set") + } return nil } func start() error { - configFilePath = path.Join(database.GetDatabaseRoot(), "config.json") + configFilePath = filepath.Join(dataRoot.Path, "config.json") err := registerAsDatabase() if err != nil && !os.IsNotExist(err) { diff --git a/database/location.go b/database/location.go index 7ca9777..636bab8 100644 --- a/database/location.go +++ b/database/location.go @@ -1,33 +1 @@ package database - -import ( - "fmt" - "path/filepath" - - "github.com/safing/portbase/utils" -) - -const ( - databasesSubDir = "databases" -) - -var ( - rootDir string -) - -// GetDatabaseRoot returns the root directory of the database. -func GetDatabaseRoot() string { - return rootDir -} - -// getLocation returns the storage location for the given name and type. -func getLocation(name, storageType string) (string, error) { - location := filepath.Join(rootDir, databasesSubDir, name, storageType) - - // check location - err := utils.EnsureDirectory(location, 0700) - if err != nil { - return "", fmt.Errorf("location (%s) invalid: %s", location, err) - } - return location, nil -} diff --git a/database/main.go b/database/main.go index f918a3b..3a51e8e 100644 --- a/database/main.go +++ b/database/main.go @@ -9,34 +9,40 @@ import ( "github.com/tevino/abool" ) +const ( + databasesSubDir = "databases" +) + var ( initialized = abool.NewBool(false) shuttingDown = abool.NewBool(false) shutdownSignal = make(chan struct{}) + + rootStructure *utils.DirStructure + databasesStructure *utils.DirStructure ) -// SetLocation sets the location of the database. This is separate from the initialization to provide the location to other modules earlier. -func SetLocation(location string) (ok bool) { - if !initialized.IsSet() && rootDir == "" { - rootDir = location - return true - } - return false -} - -// Initialize initialized the database -func Initialize() error { +// Initialize initialized the database at the specified location. Supply either a path or dir structure. +func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error { if initialized.SetToIf(false, true) { - err := utils.EnsureDirectory(rootDir, 0755) + if dirStructureRoot != nil { + rootStructure = dirStructureRoot + } else { + rootStructure = utils.NewDirStructure(dirPath, 0755) + } + + // ensure root and databases dirs + databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700) + err := databasesStructure.Ensure() if err != nil { - return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err) + return fmt.Errorf("could not create/open database directory (%s): %s", rootStructure.Path, err) } err = loadRegistry() if err != nil { - return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootDir, registryFileName), err) + return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootStructure.Path, registryFileName), err) } // start registry writer @@ -66,3 +72,14 @@ func Shutdown() (err error) { } return } + +// getLocation returns the storage location for the given name and type. +func getLocation(name, storageType string) (string, error) { + location := databasesStructure.ChildDir(name, 0700).ChildDir(storageType, 0700) + // check location + err := location.Ensure() + if err != nil { + return "", fmt.Errorf(`failed to create/check database dir "%s": %s`, location.Path, err) + } + return location.Path, nil +} diff --git a/database/registry.go b/database/registry.go index 7195bc9..3cdc11d 100644 --- a/database/registry.go +++ b/database/registry.go @@ -104,7 +104,7 @@ func loadRegistry() error { defer registryLock.Unlock() // read file - filePath := path.Join(rootDir, registryFileName) + filePath := path.Join(rootStructure.Path, registryFileName) data, err := ioutil.ReadFile(filePath) if err != nil { if os.IsNotExist(err) { @@ -139,7 +139,7 @@ func saveRegistry(lock bool) error { } // write file - filePath := path.Join(rootDir, registryFileName) + filePath := path.Join(rootStructure.Path, registryFileName) return ioutil.WriteFile(filePath, data, 0600) }