Release to master

This commit is contained in:
Daniel 2019-08-19 11:29:47 +02:00 committed by GitHub
commit d63a514ee2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
68 changed files with 2274 additions and 508 deletions

122
Gopkg.lock generated
View file

@ -9,6 +9,14 @@
pruneopts = "UT"
revision = "e2d15f34fcf99d5dbb871c820ec73f710fca9815"
[[projects]]
digest = "1:e92f5581902c345eb4ceffdcd4a854fb8f73cf436d47d837d1ec98ef1fe0a214"
name = "github.com/StackExchange/wmi"
packages = ["."]
pruneopts = "UT"
revision = "5d049714c4a64225c3c79a7cf7d02f7fb5b96338"
version = "1.0.0"
[[projects]]
digest = "1:dbd3a713434b6f32d9459b1e6786ad58cec128470b58555cdc7b3b7314a1706f"
name = "github.com/aead/serpent"
@ -34,19 +42,19 @@
version = "v1.1.1"
[[projects]]
digest = "1:5f5090f05382959db941fa45acbeb7f4c5241aa8ac0f8f4393dec696e5953f53"
digest = "1:6be8582a4f52ba2851d8a039eb9c3a3b90334b2820563d71e97de35580da128e"
name = "github.com/dgraph-io/badger"
packages = [
".",
"options",
"protos",
"pb",
"skl",
"table",
"y",
]
pruneopts = "UT"
revision = "99233d725dbdd26d156c61b2f42ae1671b794656"
version = "v1.5.4"
revision = "2fa005c9d4bf695277ab5214c1fbce3735b9562a"
version = "v1.6.0"
[[projects]]
branch = "master"
@ -57,12 +65,31 @@
revision = "6a90982ecee230ff6cba02d5bd386acc030be9d3"
[[projects]]
digest = "1:318f1c959a8a740366fce4b1e1eb2fd914036b4af58fbd0a003349b305f118ad"
digest = "1:6f9339c912bbdda81302633ad7e99a28dfa5a639c864061f1929510a9a64aa74"
name = "github.com/dustin/go-humanize"
packages = ["."]
pruneopts = "UT"
revision = "9f541cc9db5d55bce703bd99987c9d5cb8eea45e"
version = "v1.0.0"
[[projects]]
digest = "1:440028f55cb322d8cb5b9d5ebec298a00b7d74690a658fe6b1c0c0b44341bfae"
name = "github.com/go-ole/go-ole"
packages = [
".",
"oleutil",
]
pruneopts = "UT"
revision = "97b6244175ae18ea6eef668034fd6565847501c9"
version = "v1.2.4"
[[projects]]
digest = "1:573ca21d3669500ff845bdebee890eb7fc7f0f50c59f2132f2a0c6b03d85086a"
name = "github.com/golang/protobuf"
packages = ["proto"]
pruneopts = "UT"
revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30"
version = "v1.3.1"
revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7"
version = "v1.3.2"
[[projects]]
digest = "1:c3388642e07731a240e14f4bc7207df59cfcc009447c657b9de87fec072d07e3"
@ -88,9 +115,17 @@
revision = "66b9c49e59c6c48f0ffce28c2d8b8a5678502c6d"
version = "v1.4.0"
[[projects]]
digest = "1:88e0b0baeb9072f0a4afbcf12dda615fc8be001d1802357538591155998da21b"
name = "github.com/hashicorp/go-version"
packages = ["."]
pruneopts = "UT"
revision = "ac23dc3fea5d1a983c43f6a0f6e2c13f0195d8bd"
version = "v1.2.0"
[[projects]]
branch = "master"
digest = "1:ae08d850ba158ea3ba4a7bb90f8372608172d8920644e5a6693b940a1f4e5d01"
digest = "1:7e8b852581596acce37bcb939a05d7d5ff27156045b50057e659e299c16fc1ca"
name = "github.com/mmcloughlin/avo"
packages = [
"attr",
@ -108,7 +143,7 @@
"x86",
]
pruneopts = "UT"
revision = "83fbad1a6b3cba8ac7711170e57953fd12cdc40a"
revision = "bb615f61ce85790a1667efc145c66e917cce1a39"
[[projects]]
digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b"
@ -118,6 +153,22 @@
revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4"
version = "v0.8.1"
[[projects]]
branch = "develop"
digest = "1:d88649ff4a4a0746857dd9e39915aedddce2b08e442ac131a91e573cd45bde93"
name = "github.com/safing/portmaster"
packages = ["core/structure"]
pruneopts = "UT"
revision = "26c307b7a0db78d91b35ef9020706f106ebef8b6"
[[projects]]
digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925"
name = "github.com/satori/go.uuid"
packages = ["."]
pruneopts = "UT"
revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3"
version = "v1.2.0"
[[projects]]
branch = "master"
digest = "1:f1ee4af7c43f206d87f13644636e3710a05e499a084a32ec2cc7d8aa25cef1aa"
@ -134,6 +185,29 @@
pruneopts = "UT"
revision = "e9f377c596061894b7f9ee69aab61e62c3ccc13e"
[[projects]]
digest = "1:7ebe52387a847d276a9b9e7b379b2e3e4d536843ad840a9c4426a0152e6b3d54"
name = "github.com/shirou/gopsutil"
packages = [
"cpu",
"host",
"internal/common",
"mem",
"net",
"process",
]
pruneopts = "UT"
revision = "d80c43f9c984a48783daf22f4bd9278006ae483a"
version = "v2.19.7"
[[projects]]
branch = "master"
digest = "1:99c6a6dab47067c9b898e8c8b13d130c6ab4ffbcc4b7cc6236c2cd0b1e344f5b"
name = "github.com/shirou/w32"
packages = ["."]
pruneopts = "UT"
revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b"
[[projects]]
branch = "master"
digest = "1:93d6687fc19da8a35c7352d72117a6acd2072dfb7e9bfd65646227bf2a913b2a"
@ -143,12 +217,12 @@
revision = "9b9efcf221b50905aab9bbabd3daed56dc10f339"
[[projects]]
digest = "1:7df351557a6d5c30804e7d6f7ed87f2fccb0619c08fcc84869a93f22bec96c11"
digest = "1:46aea6ffe39c3d95c13640c2515236ed3c5cdffc3c78c6c0ed4edec2caf7a0dc"
name = "github.com/tidwall/gjson"
packages = ["."]
pruneopts = "UT"
revision = "eee0b6226f0d1db2675a176fdfaa8419bcad4ca8"
version = "v1.2.1"
revision = "c5e72cdf74dff23857243dd662c465b810891c21"
version = "v1.3.2"
[[projects]]
digest = "1:8453ddbed197809ee8ca28b06bd04e127bec9912deb4ba451fea7a1eca578328"
@ -159,12 +233,12 @@
version = "v1.0.1"
[[projects]]
branch = "master"
digest = "1:ddfe0a54e5f9b29536a6d7b2defa376f2cb2b6e4234d676d7ff214d5b097cb50"
name = "github.com/tidwall/pretty"
packages = ["."]
pruneopts = "UT"
revision = "1166b9ac2b65e46a43d8618d30d1554f4652d49b"
version = "v1.0.0"
[[projects]]
digest = "1:b70c951ba6fdeecfbd50dabe95aa5e1b973866ae9abbece46ad60348112214f2"
@ -175,12 +249,12 @@
version = "v1.0.4"
[[projects]]
digest = "1:5f7414cf41466d4b4dd7ec52b2cd3e481e08cfd11e7e24fef730c0e483e88bb1"
digest = "1:f2ac2c724fc8214bb7b9dd6d4f5b7a983152051f5133320f228557182263cb94"
name = "go.etcd.io/bbolt"
packages = ["."]
pruneopts = "UT"
revision = "63597a96ec0ad9e6d43c3fc81e809909e0237461"
version = "v1.3.2"
revision = "a0458a2b35708eef59eb5f620ceb3cd1c01a824d"
version = "v1.3.3"
[[projects]]
branch = "master"
@ -192,7 +266,7 @@
"sha3",
]
pruneopts = "UT"
revision = "20be4c3c3ed52bfccdb2d59a412ee1a936d175a7"
revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285"
[[projects]]
branch = "master"
@ -203,11 +277,11 @@
"trace",
]
pruneopts = "UT"
revision = "f3200d17e092c607f615320ecaad13d87ad9a2b3"
revision = "74dc4d7220e7acc4e100824340f3e66577424772"
[[projects]]
branch = "master"
digest = "1:8a3986af7a48f0991ce6168708859c56d39d2ff8b82b34d0805bbb545a9a32a6"
digest = "1:6c97b4baa9cf774b42192e23632afc7dee7b899d4a3dc98057d24708ce1f60ac"
name = "golang.org/x/sys"
packages = [
"cpu",
@ -215,11 +289,11 @@
"windows",
]
pruneopts = "UT"
revision = "46560c3f3c0a091352115a3d825af45663b983d8"
revision = "fde4db37ae7ad8191b03d30d27f258b5291ae4e3"
[[projects]]
branch = "master"
digest = "1:47717932fbd4293f80d296cd65254cd9c2ec9e51d8b4227b6254440cfa34da2a"
digest = "1:b3e64b43fe77039813b4f33100e4ef2e17960309f3c6c9ab120f5c55de747992"
name = "golang.org/x/tools"
packages = [
"go/ast/astutil",
@ -233,7 +307,7 @@
"internal/semver",
]
pruneopts = "UT"
revision = "75312fb06703a759656ae75b3d1c24b4aae95dfe"
revision = "caa95bb40b630f80d344d1f710f7e39be971d3e8"
[solve-meta]
analyzer-name = "dep"
@ -246,7 +320,11 @@
"github.com/google/renameio",
"github.com/gorilla/mux",
"github.com/gorilla/websocket",
"github.com/hashicorp/go-version",
"github.com/safing/portmaster/core/structure",
"github.com/satori/go.uuid",
"github.com/seehuhn/fortuna",
"github.com/shirou/gopsutil/host",
"github.com/tevino/abool",
"github.com/tidwall/gjson",
"github.com/tidwall/sjson",

View file

@ -84,3 +84,15 @@
[prune]
go-tests = true
unused-packages = true
[[constraint]]
name = "github.com/satori/go.uuid"
version = "1.2.0"
[[constraint]]
name = "github.com/shirou/gopsutil"
version = "2.19.6"
[[constraint]]
name = "github.com/hashicorp/go-version"
version = "1.2.0"

110
api/authentication.go Normal file
View file

@ -0,0 +1,110 @@
package api
import (
"encoding/base64"
"net/http"
"sync"
"time"
"github.com/safing/portbase/crypto/random"
"github.com/safing/portbase/log"
)
var (
validTokens map[string]time.Time
validTokensLock sync.Mutex
authFnLock sync.Mutex
authFn Authenticator
)
const (
cookieName = "T17"
// in seconds
cookieBaseTTL = 300 // 5 minutes
cookieTTL = cookieBaseTTL * time.Second
cookieRefresh = cookieBaseTTL * 0.9 * time.Second
)
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed.
type Authenticator func(s *http.Server, r *http.Request) (grantAccess bool, err error)
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed.
func SetAuthenticator(fn Authenticator) error {
authFnLock.Lock()
defer authFnLock.Unlock()
if authFn == nil {
authFn = fn
return nil
}
return ErrAuthenticationAlreadySet
}
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check existing auth cookie
c, err := r.Cookie(cookieName)
if err == nil {
// get token
validTokensLock.Lock()
validUntil, valid := validTokens[c.Value]
validTokensLock.Unlock()
// check if token is valid
if valid && time.Now().Before(validUntil) {
// maybe refresh cookie
if time.Now().After(validUntil.Add(-cookieRefresh)) {
validTokensLock.Lock()
validTokens[c.Value] = time.Now()
validTokensLock.Unlock()
}
next.ServeHTTP(w, r)
return
}
}
// get authenticator
authFnLock.Lock()
authenticator := authFn
authFnLock.Unlock()
// permit if no authenticator set
if authenticator == nil {
next.ServeHTTP(w, r)
return
}
// get auth decision
grantAccess, err := authenticator(server, r)
if err != nil {
log.Warningf("api: authenticator failed: %s", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
if !grantAccess {
log.Warningf("api: denying api access to %s", r.RemoteAddr)
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// write new cookie
token, err := random.Bytes(32) // 256 bit
if err != nil {
log.Warningf("api: failed to generate random token: %s", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
tokenString := base64.RawURLEncoding.EncodeToString(token)
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: tokenString,
HttpOnly: true,
})
// serve
log.Tracef("api: granted %s", r.RemoteAddr)
next.ServeHTTP(w, r)
})
}

View file

@ -17,40 +17,41 @@ func init() {
flag.StringVar(&listenAddressFlag, "api-address", "", "override api listen address")
}
func checkFlags() error {
func logFlagOverrides() {
if listenAddressFlag != "" {
log.Warning("api: api/listenAddress config is being overridden by -api-address flag")
log.Warning("api: api/listenAddress default config is being overridden by -api-address flag")
}
return nil
}
func getListenAddress() string {
func getDefaultListenAddress() string {
// check if overridden
if listenAddressFlag != "" {
return listenAddressFlag
}
return listenAddressConfig()
// return internal default
return defaultListenAddress
}
func registerConfig() error {
err := config.Register(&config.Option{
Name: "API Address",
Key: "api/listenAddress",
Description: "Define on what IP and port the API should listen on. Be careful, changing this may become a security issue.",
ExpertiseLevel: config.ExpertiseLevelExpert,
Description: "Define on which IP and port the API should listen on.",
ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeString,
DefaultValue: defaultListenAddress,
DefaultValue: getDefaultListenAddress(),
ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$",
RequiresRestart: true,
})
if err != nil {
return err
}
listenAddressConfig = config.GetAsString("api/listenAddress", defaultListenAddress)
listenAddressConfig = config.GetAsString("api/listenAddress", getDefaultListenAddress())
return nil
}
// SetDefaultAPIListenAddress sets the default listen address for the API.
func SetDefaultAPIListenAddress(address string) {
if defaultListenAddress == "" {
defaultListenAddress = address
}
defaultListenAddress = address
}

View file

@ -35,6 +35,10 @@ var (
dbAPISeperatorBytes = []byte(dbAPISeperator)
)
func init() {
RegisterHandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path
}
// DatabaseAPI is a database API instance.
type DatabaseAPI struct {
conn *websocket.Conn

View file

@ -1,25 +1,68 @@
package api
import (
"bufio"
"errors"
"net"
"net/http"
"github.com/safing/portbase/log"
)
// EnrichedResponseWriter is a wrapper for http.ResponseWriter for better information extraction.
type EnrichedResponseWriter struct {
http.ResponseWriter
Status int
// LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging.
type LoggingResponseWriter struct {
ResponseWriter http.ResponseWriter
Request *http.Request
Status int
}
// NewEnrichedResponseWriter wraps a http.ResponseWriter.
func NewEnrichedResponseWriter(w http.ResponseWriter) *EnrichedResponseWriter {
return &EnrichedResponseWriter{
w,
0,
// NewLoggingResponseWriter wraps a http.ResponseWriter.
func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter {
return &LoggingResponseWriter{
ResponseWriter: w,
Request: r,
}
}
// WriteHeader wraps the original WriteHeader method to extract information.
func (ew *EnrichedResponseWriter) WriteHeader(code int) {
ew.Status = code
ew.ResponseWriter.WriteHeader(code)
// Header wraps the original Header method.
func (lrw *LoggingResponseWriter) Header() http.Header {
return lrw.ResponseWriter.Header()
}
// Write wraps the original Write method.
func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) {
return lrw.ResponseWriter.Write(b)
}
// WriteHeader wraps the original WriteHeader method to extract information.
func (lrw *LoggingResponseWriter) WriteHeader(code int) {
lrw.Status = code
lrw.ResponseWriter.WriteHeader(code)
}
// Hijack wraps the original Hijack method, if available.
func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := lrw.ResponseWriter.(http.Hijacker)
if ok {
c, b, err := hijacker.Hijack()
if err != nil {
return nil, nil, err
}
log.Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI)
return c, b, nil
}
return nil, nil, errors.New("response does not implement http.Hijacker")
}
// RequestLogger is a logging middleware
func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
lrw := NewLoggingResponseWriter(w, r)
next.ServeHTTP(lrw, r)
if lrw.Status != 0 {
// request may have been hijacked
log.Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
}
})
}

View file

@ -1,22 +1,37 @@
package api
import (
"context"
"errors"
"github.com/safing/portbase/modules"
)
// API Errors
var (
ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set")
)
func init() {
modules.Register("api", prep, start, nil, "database")
modules.Register("api", prep, start, stop, "base", "database", "config")
}
func prep() error {
err := checkFlags()
if err != nil {
return err
if getDefaultListenAddress() == "" {
return errors.New("no listen address for api available")
}
return registerConfig()
}
func start() error {
logFlagOverrides()
go Serve()
return nil
}
func stop() error {
if server != nil {
return server.Shutdown(context.Background())
}
return nil
}

27
api/middleware.go Normal file
View file

@ -0,0 +1,27 @@
package api
import "net/http"
// Middleware is a function that can be added as a middleware to the API endpoint.
type Middleware func(next http.Handler) http.Handler
type mwHandler struct {
handlers []Middleware
final http.Handler
}
func (mwh *mwHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handlerLock.RLock()
defer handlerLock.RUnlock()
// final handler
handler := mwh.final
// build middleware chain
for _, mw := range mwh.handlers {
handler = mw(handler)
}
// start
handler.ServeHTTP(w, r)
}

View file

@ -2,6 +2,7 @@ package api
import (
"net/http"
"sync"
"github.com/gorilla/mux"
@ -9,35 +10,54 @@ import (
)
var (
router = mux.NewRouter()
// gorilla mux
mainMux = mux.NewRouter()
// middlewares
middlewareHandler = &mwHandler{
final: mainMux,
handlers: []Middleware{
RequestLogger,
authMiddleware,
},
}
// main server and lock
server = &http.Server{}
handlerLock sync.RWMutex
)
// RegisterHandleFunc registers an additional handle function with the API endoint.
func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
return router.HandleFunc(path, handleFunc)
// RegisterHandler registers a handler with the API endoint.
func RegisterHandler(path string, handler http.Handler) *mux.Route {
handlerLock.Lock()
defer handlerLock.Unlock()
return mainMux.Handle(path, handler)
}
// RequestLogger is a logging middleware
func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
ew := NewEnrichedResponseWriter(w)
next.ServeHTTP(ew, r)
log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI)
})
// RegisterHandleFunc registers a handle function with the API endoint.
func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
handlerLock.Lock()
defer handlerLock.Unlock()
return mainMux.HandleFunc(path, handleFunc)
}
// RegisterMiddleware registers a middle function with the API endoint.
func RegisterMiddleware(middleware Middleware) {
handlerLock.Lock()
defer handlerLock.Unlock()
middlewareHandler.handlers = append(middlewareHandler.handlers, middleware)
}
// Serve starts serving the API endpoint.
func Serve() {
router.Use(RequestLogger)
// configure server
server.Addr = listenAddressConfig()
server.Handler = middlewareHandler
mainMux := http.NewServeMux()
mainMux.Handle("/", router) // net/http pattern matching /*
mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path
address := getListenAddress()
log.Infof("api: starting to listen on %s", address)
log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux))
// start serving
log.Infof("api: starting to listen on %s", server.Addr)
// TODO: retry if failed
log.Errorf("api: failed to listen on %s: %s", server.Addr, server.ListenAndServe())
}
// GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects.

View file

@ -1 +0,0 @@
package api

View file

@ -7,5 +7,5 @@ import (
)
func init() {
api.RegisterAdditionalRoute("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/"))))
api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/"))))
}

View file

@ -1 +0,0 @@
package api

View file

@ -1,127 +1,127 @@
package config
import (
"errors"
"sort"
"strings"
"errors"
"sort"
"strings"
"github.com/safing/portbase/log"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/storage"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
"github.com/safing/portbase/log"
)
var (
dbController *database.Controller
dbController *database.Controller
)
// ConfigStorageInterface provices a storage.Interface to the configuration manager.
type ConfigStorageInterface struct {
// StorageInterface provices a storage.Interface to the configuration manager.
type StorageInterface struct {
storage.InjectBase
}
// Get returns a database record.
func (s *ConfigStorageInterface) Get(key string) (record.Record, error) {
optionsLock.Lock()
defer optionsLock.Unlock()
func (s *StorageInterface) Get(key string) (record.Record, error) {
optionsLock.Lock()
defer optionsLock.Unlock()
opt, ok := options[key]
if !ok {
return nil, storage.ErrNotFound
}
opt, ok := options[key]
if !ok {
return nil, storage.ErrNotFound
}
return opt.Export()
return opt.Export()
}
// Put stores a record in the database.
func (s *ConfigStorageInterface) Put(r record.Record) error {
if r.Meta().Deleted > 0 {
return setConfigOption(r.DatabaseKey(), nil, false)
}
func (s *StorageInterface) Put(r record.Record) error {
if r.Meta().Deleted > 0 {
return setConfigOption(r.DatabaseKey(), nil, false)
}
acc := r.GetAccessor(r)
if acc == nil {
return errors.New("invalid data")
}
acc := r.GetAccessor(r)
if acc == nil {
return errors.New("invalid data")
}
val, ok := acc.Get("Value")
if !ok || val == nil {
return setConfigOption(r.DatabaseKey(), nil, false)
}
val, ok := acc.Get("Value")
if !ok || val == nil {
return setConfigOption(r.DatabaseKey(), nil, false)
}
optionsLock.RLock()
option, ok := options[r.DatabaseKey()]
optionsLock.RLock()
option, ok := options[r.DatabaseKey()]
optionsLock.RUnlock()
if !ok {
return errors.New("config option does not exist")
}
if !ok {
return errors.New("config option does not exist")
}
var value interface{}
switch option.OptType {
case OptTypeString :
value, ok = acc.GetString("Value")
case OptTypeStringArray :
value, ok = acc.GetStringArray("Value")
case OptTypeInt :
value, ok = acc.GetInt("Value")
case OptTypeBool :
value, ok = acc.GetBool("Value")
}
if !ok {
return errors.New("received invalid value in \"Value\"")
}
var value interface{}
switch option.OptType {
case OptTypeString:
value, ok = acc.GetString("Value")
case OptTypeStringArray:
value, ok = acc.GetStringArray("Value")
case OptTypeInt:
value, ok = acc.GetInt("Value")
case OptTypeBool:
value, ok = acc.GetBool("Value")
}
if !ok {
return errors.New("received invalid value in \"Value\"")
}
err := setConfigOption(r.DatabaseKey(), value, false)
if err != nil {
return err
}
return nil
err := setConfigOption(r.DatabaseKey(), value, false)
if err != nil {
return err
}
return nil
}
// Delete deletes a record from the database.
func (s *ConfigStorageInterface) Delete(key string) error {
return setConfigOption(key, nil, false)
func (s *StorageInterface) Delete(key string) error {
return setConfigOption(key, nil, false)
}
// Query returns a an iterator for the supplied query.
func (s *ConfigStorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
optionsLock.Lock()
defer optionsLock.Unlock()
optionsLock.Lock()
defer optionsLock.Unlock()
it := iterator.New()
var opts []*Option
for _, opt := range options {
if strings.HasPrefix(opt.Key, q.DatabaseKeyPrefix()) {
opts = append(opts, opt)
}
}
it := iterator.New()
var opts []*Option
for _, opt := range options {
if strings.HasPrefix(opt.Key, q.DatabaseKeyPrefix()) {
opts = append(opts, opt)
}
}
go s.processQuery(q, it, opts)
go s.processQuery(q, it, opts)
return it, nil
return it, nil
}
func (s *ConfigStorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) {
func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) {
sort.Sort(sortableOptions(opts))
sort.Sort(sortableOptions(opts))
for _, opt := range opts {
r, err := opt.Export()
if err != nil {
it.Finish(err)
return
}
it.Next <- r
}
for _, opt := range opts {
r, err := opt.Export()
if err != nil {
it.Finish(err)
return
}
it.Next <- r
}
it.Finish(nil)
it.Finish(nil)
}
// ReadOnly returns whether the database is read only.
func (s *ConfigStorageInterface) ReadOnly() bool {
func (s *StorageInterface) ReadOnly() bool {
return false
}
@ -132,33 +132,33 @@ func registerAsDatabase() error {
StorageType: "injected",
PrimaryAPI: "",
})
if err != nil {
return err
}
if err != nil {
return err
}
controller, err := database.InjectDatabase("config", &ConfigStorageInterface{})
if err != nil {
return err
}
controller, err := database.InjectDatabase("config", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
dbController = controller
return nil
}
func pushFullUpdate() {
optionsLock.RLock()
defer optionsLock.RUnlock()
optionsLock.RLock()
defer optionsLock.RUnlock()
for _, option := range options {
pushUpdate(option)
}
for _, option := range options {
pushUpdate(option)
}
}
func pushUpdate(option *Option) {
r, err := option.Export()
if err != nil {
log.Errorf("failed to export option to push update: %s", err)
} else {
dbController.PushUpdate(r)
}
r, err := option.Export()
if err != nil {
log.Errorf("failed to export option to push update: %s", err)
} else {
dbController.PushUpdate(r)
}
}

View file

@ -1,23 +1,40 @@
package config
import (
"errors"
"os"
"path"
"path/filepath"
"github.com/safing/portbase/database"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/utils"
"github.com/safing/portmaster/core/structure"
)
var (
dataRoot *utils.DirStructure
)
// SetDataRoot sets the data root from which the updates module derives its paths.
func SetDataRoot(root *utils.DirStructure) {
if dataRoot == nil {
dataRoot = root
}
}
func init() {
modules.Register("config", prep, start, nil, "database")
modules.Register("config", prep, start, nil, "base", "database")
}
func prep() error {
SetDataRoot(structure.Root())
if dataRoot == nil {
return errors.New("data root is not set")
}
return nil
}
func start() error {
configFilePath = path.Join(database.GetDatabaseRoot(), "config.json")
configFilePath = filepath.Join(dataRoot.Path, "config.json")
err := registerAsDatabase()
if err != nil && !os.IsNotExist(err) {

View file

@ -1,9 +1,9 @@
package config
import (
"encoding/json"
"fmt"
"regexp"
"encoding/json"
"github.com/tidwall/sjson"
@ -47,6 +47,7 @@ type Option struct {
DefaultValue interface{}
ExternalOptType string
ValidationRegex string
RequiresRestart bool
compiledRegex *regexp.Regexp
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash
import "testing"

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash
import (

View file

@ -23,7 +23,7 @@ var (
)
func init() {
modules.Register("random", prep, Start, stop)
modules.Register("random", prep, Start, stop, "base")
config.Register(&config.Option{
Name: "RNG Cipher",

View file

@ -2,51 +2,46 @@ package dbmodule
import (
"errors"
"flag"
"sync"
"github.com/safing/portbase/database"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/utils"
)
var (
databaseDir string
shutdownSignal = make(chan struct{})
maintenanceWg sync.WaitGroup
databasePath string
databaseStructureRoot *utils.DirStructure
module *modules.Module
)
// SetDatabaseLocation sets the location of the database. Must be called before modules.Start and will be overridden by command line options. Intended for unit tests.
func SetDatabaseLocation(location string) {
databaseDir = location
func init() {
module = modules.Register("database", prep, start, stop, "base")
}
func init() {
flag.StringVar(&databaseDir, "db", "", "set database directory")
modules.Register("database", prep, start, stop)
// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure.
func SetDatabaseLocation(dirPath string, dirStructureRoot *utils.DirStructure) {
databasePath = dirPath
databaseStructureRoot = dirStructureRoot
}
func prep() error {
if databaseDir == "" {
return errors.New("no database location specified, set with `-db=/path/to/db`")
}
ok := database.SetLocation(databaseDir)
if !ok {
return errors.New("database location already set")
if databasePath == "" && databaseStructureRoot == nil {
return errors.New("no database location specified")
}
return nil
}
func start() error {
err := database.Initialize()
if err == nil {
startMaintainer()
err := database.Initialize(databasePath, databaseStructureRoot)
if err != nil {
return err
}
return err
startMaintainer()
return nil
}
func stop() error {
close(shutdownSignal)
maintenanceWg.Wait()
return database.Shutdown()
}

View file

@ -13,7 +13,7 @@ var (
)
func startMaintainer() {
maintenanceWg.Add(1)
module.AddWorkers(1)
go maintenanceWorker()
}
@ -37,8 +37,8 @@ func maintenanceWorker() {
if err != nil {
log.Errorf("database: thorough maintenance error: %s", err)
}
case <-shutdownSignal:
maintenanceWg.Done()
case <-module.ShuttingDown():
module.FinishWorker()
return
}
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package dbutils
type Meta struct {

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
/*
Package database provides a universal interface for interacting with the database.

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package database
import (

View file

@ -1,33 +1 @@
package database
import (
"fmt"
"path/filepath"
"github.com/safing/portbase/utils"
)
const (
databasesSubDir = "databases"
)
var (
rootDir string
)
// GetDatabaseRoot returns the root directory of the database.
func GetDatabaseRoot() string {
return rootDir
}
// getLocation returns the storage location for the given name and type.
func getLocation(name, storageType string) (string, error) {
location := filepath.Join(rootDir, databasesSubDir, name, storageType)
// check location
err := utils.EnsureDirectory(location, 0700)
if err != nil {
return "", fmt.Errorf("location (%s) invalid: %s", location, err)
}
return location, nil
}

View file

@ -9,34 +9,40 @@ import (
"github.com/tevino/abool"
)
const (
databasesSubDir = "databases"
)
var (
initialized = abool.NewBool(false)
shuttingDown = abool.NewBool(false)
shutdownSignal = make(chan struct{})
rootStructure *utils.DirStructure
databasesStructure *utils.DirStructure
)
// SetLocation sets the location of the database. This is separate from the initialization to provide the location to other modules earlier.
func SetLocation(location string) (ok bool) {
if !initialized.IsSet() && rootDir == "" {
rootDir = location
return true
}
return false
}
// Initialize initialized the database
func Initialize() error {
// Initialize initializes the database at the specified location. Supply either a path or dir structure.
func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error {
if initialized.SetToIf(false, true) {
err := utils.EnsureDirectory(rootDir, 0755)
if dirStructureRoot != nil {
rootStructure = dirStructureRoot
} else {
rootStructure = utils.NewDirStructure(dirPath, 0755)
}
// ensure root and databases dirs
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700)
err := databasesStructure.Ensure()
if err != nil {
return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err)
return fmt.Errorf("could not create/open database directory (%s): %s", rootStructure.Path, err)
}
err = loadRegistry()
if err != nil {
return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootDir, registryFileName), err)
return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootStructure.Path, registryFileName), err)
}
// start registry writer
@ -66,3 +72,14 @@ func Shutdown() (err error) {
}
return
}
// getLocation returns the storage location for the given name and type.
func getLocation(name, storageType string) (string, error) {
location := databasesStructure.ChildDir(name, 0700).ChildDir(storageType, 0700)
// check location
err := location.Ensure()
if err != nil {
return "", fmt.Errorf(`failed to create/check database dir "%s": %s`, location.Path, err)
}
return location.Path, nil
}

View file

@ -70,7 +70,7 @@ func (b *Base) SetMeta(meta *Meta) {
b.meta = meta
}
// Marshal marshals the object, without the database key or metadata
// Marshal marshals the object, without the database key or metadata. It returns nil if the record is deleted.
func (b *Base) Marshal(self Record, format uint8) ([]byte, error) {
if b.Meta() == nil {
return nil, errors.New("missing meta")
@ -96,15 +96,15 @@ func (b *Base) MarshalRecord(self Record) ([]byte, error) {
// version
c := container.New([]byte{1})
// meta
metaSection, err := b.meta.GenCodeMarshal(nil)
// meta encoding
metaSection, err := dsd.Dump(b.meta, GenCode)
if err != nil {
return nil, err
}
c.AppendAsBlock(metaSection)
// data
dataSection, err := b.Marshal(self, dsd.JSON)
dataSection, err := b.Marshal(self, JSON)
if err != nil {
return nil, err
}

View file

@ -11,5 +11,5 @@ const (
BYTES = dsd.BYTES // X
JSON = dsd.JSON // J
BSON = dsd.BSON // B
GenCode = dsd.GenCode // G (reserved)
GenCode = dsd.GenCode // G
)

View file

@ -432,7 +432,7 @@ func BenchmarkMetaUnserializeWithCodegen(b *testing.B) {
func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := dsd.Dump(testMeta, dsd.JSON)
_, err := dsd.Dump(testMeta, JSON)
if err != nil {
b.Errorf("failed to serialize with DSD/JSON: %s", err)
return
@ -444,7 +444,7 @@ func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) {
func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) {
// Setup
encodedData, err := dsd.Dump(testMeta, dsd.JSON)
encodedData, err := dsd.Dump(testMeta, JSON)
if err != nil {
b.Errorf("failed to serialize with DSD/JSON: %s", err)
return

View file

@ -37,9 +37,20 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) {
offset += n
newMeta := &Meta{}
_, err = newMeta.GenCodeUnmarshal(metaSection)
if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
if len(metaSection) == 34 && metaSection[4] == 0 {
// TODO: remove in 2020
// backward compatibility:
// format would byte shift and populate metaSection[4] with value > 0 (would naturally populate >0 at 07.02.2106 07:28:15)
// this must be gencode without format
_, err = newMeta.GenCodeUnmarshal(metaSection)
if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
}
} else {
_, err = dsd.Load(metaSection, newMeta)
if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
}
}
format, n, err := varint.Unpack8(data[offset:])
@ -86,7 +97,7 @@ func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) {
return nil, nil
}
if format != dsd.AUTO && format != w.Format {
if format != AUTO && format != w.Format {
return nil, errors.New("could not dump model, wrapped object format mismatch")
}
@ -109,14 +120,14 @@ func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) {
c := container.New([]byte{1})
// meta
metaSection, err := w.meta.GenCodeMarshal(nil)
metaSection, err := dsd.Dump(w.meta, GenCode)
if err != nil {
return nil, err
}
c.AppendAsBlock(metaSection)
// data
dataSection, err := w.Marshal(r, dsd.JSON)
dataSection, err := w.Marshal(r, JSON)
if err != nil {
return nil, err
}
@ -125,16 +136,6 @@ func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) {
return c.CompileData(), nil
}
// // Lock locks the record.
// func (w *Wrapper) Lock() {
// w.lock.Lock()
// }
//
// // Unlock unlocks the record.
// func (w *Wrapper) Unlock() {
// w.lock.Unlock()
// }
// IsWrapped returns whether the record is a Wrapper.
func (w *Wrapper) IsWrapped() bool {
return true

View file

@ -2,9 +2,10 @@ package record
import (
"bytes"
"errors"
"testing"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/container"
)
func TestWrapper(t *testing.T) {
@ -24,14 +25,14 @@ func TestWrapper(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if wrapper.Format != dsd.JSON {
if wrapper.Format != JSON {
t.Error("format mismatch")
}
if !bytes.Equal(testData, wrapper.Data) {
t.Error("data mismatch")
}
encoded, err := wrapper.Marshal(wrapper, dsd.JSON)
encoded, err := wrapper.Marshal(wrapper, JSON)
if err != nil {
t.Fatal(err)
}
@ -40,6 +41,7 @@ func TestWrapper(t *testing.T) {
}
wrapper.SetMeta(&Meta{})
wrapper.meta.Update()
raw, err := wrapper.MarshalRecord(wrapper)
if err != nil {
t.Fatal(err)
@ -53,4 +55,42 @@ func TestWrapper(t *testing.T) {
t.Error("marshal mismatch")
}
// test new format
oldRaw, err := oldWrapperMarshalRecord(wrapper, wrapper)
if err != nil {
t.Fatal(err)
}
wrapper3, err := NewRawWrapper("test", "a", oldRaw)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(testData, wrapper3.Data) {
t.Error("marshal mismatch")
}
}
func oldWrapperMarshalRecord(w *Wrapper, r Record) ([]byte, error) {
if w.Meta() == nil {
return nil, errors.New("missing meta")
}
// version
c := container.New([]byte{1})
// meta
metaSection, err := w.meta.GenCodeMarshal(nil)
if err != nil {
return nil, err
}
c.AppendAsBlock(metaSection)
// data
dataSection, err := w.Marshal(r, JSON)
if err != nil {
return nil, err
}
c.Append(dataSection)
return c.CompileData(), nil
}

View file

@ -104,7 +104,7 @@ func loadRegistry() error {
defer registryLock.Unlock()
// read file
filePath := path.Join(rootDir, registryFileName)
filePath := path.Join(rootStructure.Path, registryFileName)
data, err := ioutil.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
@ -139,7 +139,7 @@ func saveRegistry(lock bool) error {
}
// write file
filePath := path.Join(rootDir, registryFileName)
filePath := path.Join(rootStructure.Path, registryFileName)
return ioutil.WriteFile(filePath, data, 0600)
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package dsd
// dynamic structured data
@ -22,7 +20,7 @@ const (
BYTES = 88 // X
JSON = 74 // J
BSON = 66 // B
GenCode = 71 // G (reserved)
GenCode = 71 // G
)
// define errors
@ -30,6 +28,7 @@ var errNoMoreSpace = errors.New("dsd: no more space left after reading dsd type"
var errUnknownType = errors.New("dsd: tried to unpack unknown type")
var errNotImplemented = errors.New("dsd: this type is not yet implemented")
// Load loads an dsd structured data blob into the given interface.
func Load(data []byte, t interface{}) (interface{}, error) {
if len(data) < 2 {
return nil, errNoMoreSpace
@ -46,6 +45,7 @@ func Load(data []byte, t interface{}) (interface{}, error) {
return LoadAsFormat(data[read:], format, t)
}
// LoadAsFormat loads a data blob into the interface using the specified format.
func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error) {
switch format {
case STRING:
@ -55,28 +55,32 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
case JSON:
err := json.Unmarshal(data, t)
if err != nil {
return nil, err
return nil, fmt.Errorf("dsd: failed to unpack json data: %s", data)
}
return t, nil
// case BSON:
// err := bson.Unmarshal(data[read:], t)
// if err != nil {
// return nil, err
// }
// return t, nil
case GenCode:
genCodeStruct, ok := t.(GenCodeCompatible)
if !ok {
return nil, errors.New("dsd: gencode is not supported by the given data structure")
}
_, err := genCodeStruct.GenCodeUnmarshal(data)
if err != nil {
return nil, fmt.Errorf("dsd: failed to unpack gencode data: %s", err)
}
return t, nil
// case BSON:
// err := bson.Unmarshal(data[read:], t)
// if err != nil {
// return nil, err
// }
// return t, nil
// case MSGP:
// err := t.UnmarshalMsg(data[read:])
// if err != nil {
// return nil, err
// }
// return t, nil
default:
return nil, errors.New(fmt.Sprintf("dsd: tried to load unknown type %d, data: %v", format, data))
return nil, fmt.Errorf("dsd: tried to load unknown type %d, data: %v", format, data)
}
}
// Dump stores the interface as a dsd formatted data structure.
func Dump(t interface{}, format uint8) ([]byte, error) {
if format == AUTO {
switch t.(type) {
case string:
@ -107,18 +111,19 @@ func Dump(t interface{}, format uint8) ([]byte, error) {
// if err != nil {
// return nil, err
// }
// case MSGP:
// data, err := t.MarshalMsg(nil)
// if err != nil {
// return nil, err
// }
case GenCode:
genCodeStruct, ok := t.(GenCodeCompatible)
if !ok {
return nil, errors.New("dsd: gencode is not supported by the given data structure")
}
data, err = genCodeStruct.GenCodeMarshal(nil)
if err != nil {
return nil, fmt.Errorf("dsd: failed to pack gencode struct: %s", err)
}
default:
return nil, errors.New(fmt.Sprintf("dsd: tried to dump unknown type %d", format))
return nil, fmt.Errorf("dsd: tried to dump unknown type %d", format)
}
r := append(f, data...)
// log.Tracef("packing %v to %s", t, string(r))
// return nil, errors.New(fmt.Sprintf("dsd: dumped bytes are: %v", r))
return r, nil
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package dsd
import (
@ -10,6 +8,7 @@ import (
//go:generate msgp
// SimpleTestStruct is used for testing.
type SimpleTestStruct struct {
S string
B byte
@ -21,11 +20,11 @@ type ComplexTestStruct struct {
I16 int16
I32 int32
I64 int64
Ui uint
Ui8 uint8
Ui16 uint16
Ui32 uint32
Ui64 uint64
UI uint
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
S string
Sp *string
Sa []string
@ -38,6 +37,25 @@ type ComplexTestStruct struct {
Mp *map[string]string
}
type GenCodeTestStruct struct {
I8 int8
I16 int16
I32 int32
I64 int64
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
S string
Sp *string
Sa []string
Sap *[]string
B byte
Bp *byte
Ba []byte
Bap *[]byte
}
func TestConversion(t *testing.T) {
// STRING
@ -113,7 +131,26 @@ func TestConversion(t *testing.T) {
},
}
// TODO: test all formats
genCodeSubject := GenCodeTestStruct{
-2,
-3,
-4,
-5,
2,
3,
4,
5,
"a",
&bString,
[]string{"c", "d", "e"},
&[]string{"f", "g", "h"},
0x01,
&bBytes,
[]byte{0x03, 0x04, 0x05},
&[]byte{0x05, 0x06, 0x07},
}
// test all formats (complex)
formats := []uint8{JSON}
for _, format := range formats {
@ -163,20 +200,20 @@ func TestConversion(t *testing.T) {
if complexSubject.I64 != co.I64 {
t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", complexSubject.I64, co.I64)
}
if complexSubject.Ui != co.Ui {
t.Errorf("Load (complex struct): struct.Ui is not equal (%v != %v)", complexSubject.Ui, co.Ui)
if complexSubject.UI != co.UI {
t.Errorf("Load (complex struct): struct.UI is not equal (%v != %v)", complexSubject.UI, co.UI)
}
if complexSubject.Ui8 != co.Ui8 {
t.Errorf("Load (complex struct): struct.Ui8 is not equal (%v != %v)", complexSubject.Ui8, co.Ui8)
if complexSubject.UI8 != co.UI8 {
t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", complexSubject.UI8, co.UI8)
}
if complexSubject.Ui16 != co.Ui16 {
t.Errorf("Load (complex struct): struct.Ui16 is not equal (%v != %v)", complexSubject.Ui16, co.Ui16)
if complexSubject.UI16 != co.UI16 {
t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", complexSubject.UI16, co.UI16)
}
if complexSubject.Ui32 != co.Ui32 {
t.Errorf("Load (complex struct): struct.Ui32 is not equal (%v != %v)", complexSubject.Ui32, co.Ui32)
if complexSubject.UI32 != co.UI32 {
t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", complexSubject.UI32, co.UI32)
}
if complexSubject.Ui64 != co.Ui64 {
t.Errorf("Load (complex struct): struct.Ui64 is not equal (%v != %v)", complexSubject.Ui64, co.Ui64)
if complexSubject.UI64 != co.UI64 {
t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", complexSubject.UI64, co.UI64)
}
if complexSubject.S != co.S {
t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", complexSubject.S, co.S)
@ -211,4 +248,87 @@ func TestConversion(t *testing.T) {
}
// test all formats
formats = []uint8{JSON, GenCode}
for _, format := range formats {
// simple
b, err := Dump(&simpleSubject, format)
if err != nil {
t.Fatalf("Dump error (simple struct): %s", err)
}
o, err := Load(b, &SimpleTestStruct{})
if err != nil {
t.Fatalf("Load error (simple struct): %s", err)
}
if !reflect.DeepEqual(&simpleSubject, o) {
t.Errorf("Load (simple struct): subject does not match loaded object")
t.Errorf("Encoded: %v", string(b))
t.Errorf("Compared: %v == %v", &simpleSubject, o)
}
// complex
b, err = Dump(&genCodeSubject, format)
if err != nil {
t.Fatalf("Dump error (complex struct): %s", err)
}
o, err = Load(b, &GenCodeTestStruct{})
if err != nil {
t.Fatalf("Load error (complex struct): %s", err)
}
co := o.(*GenCodeTestStruct)
if genCodeSubject.I8 != co.I8 {
t.Errorf("Load (complex struct): struct.I8 is not equal (%v != %v)", genCodeSubject.I8, co.I8)
}
if genCodeSubject.I16 != co.I16 {
t.Errorf("Load (complex struct): struct.I16 is not equal (%v != %v)", genCodeSubject.I16, co.I16)
}
if genCodeSubject.I32 != co.I32 {
t.Errorf("Load (complex struct): struct.I32 is not equal (%v != %v)", genCodeSubject.I32, co.I32)
}
if genCodeSubject.I64 != co.I64 {
t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", genCodeSubject.I64, co.I64)
}
if genCodeSubject.UI8 != co.UI8 {
t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", genCodeSubject.UI8, co.UI8)
}
if genCodeSubject.UI16 != co.UI16 {
t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", genCodeSubject.UI16, co.UI16)
}
if genCodeSubject.UI32 != co.UI32 {
t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", genCodeSubject.UI32, co.UI32)
}
if genCodeSubject.UI64 != co.UI64 {
t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", genCodeSubject.UI64, co.UI64)
}
if genCodeSubject.S != co.S {
t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", genCodeSubject.S, co.S)
}
if !reflect.DeepEqual(genCodeSubject.Sp, co.Sp) {
t.Errorf("Load (complex struct): struct.Sp is not equal (%v != %v)", genCodeSubject.Sp, co.Sp)
}
if !reflect.DeepEqual(genCodeSubject.Sa, co.Sa) {
t.Errorf("Load (complex struct): struct.Sa is not equal (%v != %v)", genCodeSubject.Sa, co.Sa)
}
if !reflect.DeepEqual(genCodeSubject.Sap, co.Sap) {
t.Errorf("Load (complex struct): struct.Sap is not equal (%v != %v)", genCodeSubject.Sap, co.Sap)
}
if genCodeSubject.B != co.B {
t.Errorf("Load (complex struct): struct.B is not equal (%v != %v)", genCodeSubject.B, co.B)
}
if !reflect.DeepEqual(genCodeSubject.Bp, co.Bp) {
t.Errorf("Load (complex struct): struct.Bp is not equal (%v != %v)", genCodeSubject.Bp, co.Bp)
}
if !reflect.DeepEqual(genCodeSubject.Ba, co.Ba) {
t.Errorf("Load (complex struct): struct.Ba is not equal (%v != %v)", genCodeSubject.Ba, co.Ba)
}
if !reflect.DeepEqual(genCodeSubject.Bap, co.Bap) {
t.Errorf("Load (complex struct): struct.Bap is not equal (%v != %v)", genCodeSubject.Bap, co.Bap)
}
}
}

835
formats/dsd/gencode_test.go Normal file
View file

@ -0,0 +1,835 @@
package dsd
import (
"io"
"time"
"unsafe"
)
var (
_ = unsafe.Sizeof(0)
_ = io.ReadFull
_ = time.Now()
)
func (d *SimpleTestStruct) Size() (s uint64) {
{
l := uint64(len(d.S))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
s++
return
}
func (d *SimpleTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) {
size := d.Size()
{
if uint64(cap(buf)) >= size {
buf = buf[:size]
} else {
buf = make([]byte, size)
}
}
i := uint64(0)
{
l := uint64(len(d.S))
{
t := uint64(l)
for t >= 0x80 {
buf[i+0] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+0] = byte(t)
i++
}
copy(buf[i+0:], d.S)
i += l
}
{
buf[i+0] = d.B
}
return buf[:i+1], nil
}
func (d *SimpleTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) {
i := uint64(0)
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+0] & 0x7F)
for buf[i+0]&0x80 == 0x80 {
i++
t |= uint64(buf[i+0]&0x7F) << bs
bs += 7
}
i++
l = t
}
d.S = string(buf[i+0 : i+0+l])
i += l
}
{
d.B = buf[i+0]
}
return i + 1, nil
}
func (d *GenCodeTestStruct) Size() (s uint64) {
{
l := uint64(len(d.S))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
{
if d.Sp != nil {
{
l := uint64(len((*d.Sp)))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
s += 0
}
}
{
l := uint64(len(d.Sa))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
for k0 := range d.Sa {
{
l := uint64(len(d.Sa[k0]))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
}
}
{
if d.Sap != nil {
{
l := uint64(len((*d.Sap)))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
for k0 := range *d.Sap {
{
l := uint64(len((*d.Sap)[k0]))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
}
}
s += 0
}
}
{
if d.Bp != nil {
s++
}
}
{
l := uint64(len(d.Ba))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
{
if d.Bap != nil {
{
l := uint64(len((*d.Bap)))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
s += 0
}
}
s += 35
return
}
func (d *GenCodeTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) {
size := d.Size()
{
if uint64(cap(buf)) >= size {
buf = buf[:size]
} else {
buf = make([]byte, size)
}
}
i := uint64(0)
{
buf[0+0] = byte(d.I8 >> 0)
}
{
buf[0+1] = byte(d.I16 >> 0)
buf[1+1] = byte(d.I16 >> 8)
}
{
buf[0+3] = byte(d.I32 >> 0)
buf[1+3] = byte(d.I32 >> 8)
buf[2+3] = byte(d.I32 >> 16)
buf[3+3] = byte(d.I32 >> 24)
}
{
buf[0+7] = byte(d.I64 >> 0)
buf[1+7] = byte(d.I64 >> 8)
buf[2+7] = byte(d.I64 >> 16)
buf[3+7] = byte(d.I64 >> 24)
buf[4+7] = byte(d.I64 >> 32)
buf[5+7] = byte(d.I64 >> 40)
buf[6+7] = byte(d.I64 >> 48)
buf[7+7] = byte(d.I64 >> 56)
}
{
buf[0+15] = byte(d.UI8 >> 0)
}
{
buf[0+16] = byte(d.UI16 >> 0)
buf[1+16] = byte(d.UI16 >> 8)
}
{
buf[0+18] = byte(d.UI32 >> 0)
buf[1+18] = byte(d.UI32 >> 8)
buf[2+18] = byte(d.UI32 >> 16)
buf[3+18] = byte(d.UI32 >> 24)
}
{
buf[0+22] = byte(d.UI64 >> 0)
buf[1+22] = byte(d.UI64 >> 8)
buf[2+22] = byte(d.UI64 >> 16)
buf[3+22] = byte(d.UI64 >> 24)
buf[4+22] = byte(d.UI64 >> 32)
buf[5+22] = byte(d.UI64 >> 40)
buf[6+22] = byte(d.UI64 >> 48)
buf[7+22] = byte(d.UI64 >> 56)
}
{
l := uint64(len(d.S))
{
t := uint64(l)
for t >= 0x80 {
buf[i+30] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+30] = byte(t)
i++
}
copy(buf[i+30:], d.S)
i += l
}
{
if d.Sp == nil {
buf[i+30] = 0
} else {
buf[i+30] = 1
{
l := uint64(len((*d.Sp)))
{
t := uint64(l)
for t >= 0x80 {
buf[i+31] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+31] = byte(t)
i++
}
copy(buf[i+31:], (*d.Sp))
i += l
}
i += 0
}
}
{
l := uint64(len(d.Sa))
{
t := uint64(l)
for t >= 0x80 {
buf[i+31] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+31] = byte(t)
i++
}
for k0 := range d.Sa {
{
l := uint64(len(d.Sa[k0]))
{
t := uint64(l)
for t >= 0x80 {
buf[i+31] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+31] = byte(t)
i++
}
copy(buf[i+31:], d.Sa[k0])
i += l
}
}
}
{
if d.Sap == nil {
buf[i+31] = 0
} else {
buf[i+31] = 1
{
l := uint64(len((*d.Sap)))
{
t := uint64(l)
for t >= 0x80 {
buf[i+32] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+32] = byte(t)
i++
}
for k0 := range *d.Sap {
{
l := uint64(len((*d.Sap)[k0]))
{
t := uint64(l)
for t >= 0x80 {
buf[i+32] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+32] = byte(t)
i++
}
copy(buf[i+32:], (*d.Sap)[k0])
i += l
}
}
}
i += 0
}
}
{
buf[i+32] = d.B
}
{
if d.Bp == nil {
buf[i+33] = 0
} else {
buf[i+33] = 1
{
buf[i+34] = (*d.Bp)
}
i++
}
}
{
l := uint64(len(d.Ba))
{
t := uint64(l)
for t >= 0x80 {
buf[i+34] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+34] = byte(t)
i++
}
copy(buf[i+34:], d.Ba)
i += l
}
{
if d.Bap == nil {
buf[i+34] = 0
} else {
buf[i+34] = 1
{
l := uint64(len((*d.Bap)))
{
t := uint64(l)
for t >= 0x80 {
buf[i+35] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+35] = byte(t)
i++
}
copy(buf[i+35:], (*d.Bap))
i += l
}
i += 0
}
}
return buf[:i+35], nil
}
func (d *GenCodeTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) {
i := uint64(0)
{
d.I8 = 0 | (int8(buf[i+0+0]) << 0)
}
{
d.I16 = 0 | (int16(buf[i+0+1]) << 0) | (int16(buf[i+1+1]) << 8)
}
{
d.I32 = 0 | (int32(buf[i+0+3]) << 0) | (int32(buf[i+1+3]) << 8) | (int32(buf[i+2+3]) << 16) | (int32(buf[i+3+3]) << 24)
}
{
d.I64 = 0 | (int64(buf[i+0+7]) << 0) | (int64(buf[i+1+7]) << 8) | (int64(buf[i+2+7]) << 16) | (int64(buf[i+3+7]) << 24) | (int64(buf[i+4+7]) << 32) | (int64(buf[i+5+7]) << 40) | (int64(buf[i+6+7]) << 48) | (int64(buf[i+7+7]) << 56)
}
{
d.UI8 = 0 | (uint8(buf[i+0+15]) << 0)
}
{
d.UI16 = 0 | (uint16(buf[i+0+16]) << 0) | (uint16(buf[i+1+16]) << 8)
}
{
d.UI32 = 0 | (uint32(buf[i+0+18]) << 0) | (uint32(buf[i+1+18]) << 8) | (uint32(buf[i+2+18]) << 16) | (uint32(buf[i+3+18]) << 24)
}
{
d.UI64 = 0 | (uint64(buf[i+0+22]) << 0) | (uint64(buf[i+1+22]) << 8) | (uint64(buf[i+2+22]) << 16) | (uint64(buf[i+3+22]) << 24) | (uint64(buf[i+4+22]) << 32) | (uint64(buf[i+5+22]) << 40) | (uint64(buf[i+6+22]) << 48) | (uint64(buf[i+7+22]) << 56)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+30] & 0x7F)
for buf[i+30]&0x80 == 0x80 {
i++
t |= uint64(buf[i+30]&0x7F) << bs
bs += 7
}
i++
l = t
}
d.S = string(buf[i+30 : i+30+l])
i += l
}
{
if buf[i+30] == 1 {
if d.Sp == nil {
d.Sp = new(string)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+31] & 0x7F)
for buf[i+31]&0x80 == 0x80 {
i++
t |= uint64(buf[i+31]&0x7F) << bs
bs += 7
}
i++
l = t
}
(*d.Sp) = string(buf[i+31 : i+31+l])
i += l
}
i += 0
} else {
d.Sp = nil
}
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+31] & 0x7F)
for buf[i+31]&0x80 == 0x80 {
i++
t |= uint64(buf[i+31]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap(d.Sa)) >= l {
d.Sa = d.Sa[:l]
} else {
d.Sa = make([]string, l)
}
for k0 := range d.Sa {
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+31] & 0x7F)
for buf[i+31]&0x80 == 0x80 {
i++
t |= uint64(buf[i+31]&0x7F) << bs
bs += 7
}
i++
l = t
}
d.Sa[k0] = string(buf[i+31 : i+31+l])
i += l
}
}
}
{
if buf[i+31] == 1 {
if d.Sap == nil {
d.Sap = new([]string)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+32] & 0x7F)
for buf[i+32]&0x80 == 0x80 {
i++
t |= uint64(buf[i+32]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap((*d.Sap))) >= l {
(*d.Sap) = (*d.Sap)[:l]
} else {
(*d.Sap) = make([]string, l)
}
for k0 := range *d.Sap {
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+32] & 0x7F)
for buf[i+32]&0x80 == 0x80 {
i++
t |= uint64(buf[i+32]&0x7F) << bs
bs += 7
}
i++
l = t
}
(*d.Sap)[k0] = string(buf[i+32 : i+32+l])
i += l
}
}
}
i += 0
} else {
d.Sap = nil
}
}
{
d.B = buf[i+32]
}
{
if buf[i+33] == 1 {
if d.Bp == nil {
d.Bp = new(byte)
}
{
(*d.Bp) = buf[i+34]
}
i++
} else {
d.Bp = nil
}
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+34] & 0x7F)
for buf[i+34]&0x80 == 0x80 {
i++
t |= uint64(buf[i+34]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap(d.Ba)) >= l {
d.Ba = d.Ba[:l]
} else {
d.Ba = make([]byte, l)
}
copy(d.Ba, buf[i+34:])
i += l
}
{
if buf[i+34] == 1 {
if d.Bap == nil {
d.Bap = new([]byte)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+35] & 0x7F)
for buf[i+35]&0x80 == 0x80 {
i++
t |= uint64(buf[i+35]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap((*d.Bap))) >= l {
(*d.Bap) = (*d.Bap)[:l]
} else {
(*d.Bap) = make([]byte, l)
}
copy((*d.Bap), buf[i+35:])
i += l
}
i += 0
} else {
d.Bap = nil
}
}
return i + 35, nil
}

View file

@ -0,0 +1,9 @@
package dsd
// GenCodeCompatible is an interface to identify and use gencode compatible structs.
type GenCodeCompatible interface {
// GenCodeMarshal gencode marshalls the struct into the given byte array, or a new one if its too small.
GenCodeMarshal(buf []byte) ([]byte, error)
// GenCodeUnmarshal gencode unmarshalls the struct and returns the bytes read.
GenCodeUnmarshal(buf []byte) (uint64, error)
}

23
formats/dsd/tests.gencode Normal file
View file

@ -0,0 +1,23 @@
struct SimpleTestStruct {
S string
B byte
}
struct GenCodeTestStructure {
I8 int8
I16 int16
I32 int32
I64 int64
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
S string
Sp *string
Sa []string
Sap *[]string
B byte
Bp *byte
Ba []byte
Bap *[]byte
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package varint
import "errors"

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package varint
import (

View file

@ -3,11 +3,11 @@ package log
import "flag"
var (
logLevelFlag string
fileLogLevelsFlag string
logLevelFlag string
pkgLogLevelsFlag string
)
func init() {
flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]")
flag.StringVar(&fileLogLevelsFlag, "flog", "", "set log level of files: database=trace,firewall=debug")
flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,firewall=debug")
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log
import (
@ -11,7 +9,7 @@ import (
)
func fastcheck(level severity) bool {
if fileLevelsActive.IsSet() {
if pkgLevelsActive.IsSet() {
return true
}
if uint32(level) < atomic.LoadUint32(logLevel) {
@ -33,7 +31,7 @@ func log(level severity, msg string, trace *ContextTracer) {
}
// check if level is enabled
if !fileLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) {
if !pkgLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) {
return
}
@ -54,12 +52,12 @@ func log(level severity, msg string, trace *ContextTracer) {
}
// check if level is enabled for file or generally
if fileLevelsActive.IsSet() {
if pkgLevelsActive.IsSet() {
fileOnly := strings.Split(file, "/")
if len(fileOnly) < 2 {
return
}
sev, ok := fileLevels[fileOnly[len(fileOnly)-2]]
sev, ok := pkgLevels[fileOnly[len(fileOnly)-2]]
if ok {
if level < sev {
return

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log
import (
@ -75,9 +73,9 @@ var (
logLevelInt = uint32(3)
logLevel = &logLevelInt
fileLevelsActive = abool.NewBool(false)
fileLevels = make(map[string]severity)
fileLevelsLock sync.Mutex
pkgLevelsActive = abool.NewBool(false)
pkgLevels = make(map[string]severity)
pkgLevelsLock sync.Mutex
logsWaiting = make(chan bool, 1)
logsWaitingFlag = abool.NewBool(false)
@ -92,15 +90,15 @@ var (
testErrors = abool.NewBool(false)
)
func SetFileLevels(levels map[string]severity) {
fileLevelsLock.Lock()
fileLevels = levels
fileLevelsLock.Unlock()
fileLevelsActive.Set()
func SetPkgLevels(levels map[string]severity) {
pkgLevelsLock.Lock()
pkgLevels = levels
pkgLevelsLock.Unlock()
pkgLevelsActive.Set()
}
func UnSetFileLevels() {
fileLevelsActive.UnSet()
func UnSetPkgLevels() {
pkgLevelsActive.UnSet()
}
func SetLogLevel(level severity) {
@ -143,10 +141,10 @@ func Start() (err error) {
}
// get and set file loglevels
fileLogLevels := fileLogLevelsFlag
if len(fileLogLevels) > 0 {
newFileLevels := make(map[string]severity)
for _, pair := range strings.Split(fileLogLevels, ",") {
pkgLogLevels := pkgLogLevelsFlag
if len(pkgLogLevels) > 0 {
newPkgLevels := make(map[string]severity)
for _, pair := range strings.Split(pkgLogLevels, ",") {
splitted := strings.Split(pair, "=")
if len(splitted) != 2 {
err = fmt.Errorf("log warning: invalid file log level \"%s\", ignoring", pair)
@ -159,9 +157,9 @@ func Start() (err error) {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
break
}
newFileLevels[splitted[0]] = fileLevel
newPkgLevels[splitted[0]] = fileLevel
}
SetFileLevels(newFileLevels)
SetPkgLevels(newPkgLevels)
}
startWriter()

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log
import (

View file

@ -3,11 +3,12 @@ package modules
import "flag"
var (
helpFlag bool
// HelpFlag triggers printing flag.Usage. It's exported for custom help handling.
HelpFlag bool
)
func init() {
flag.BoolVar(&helpFlag, "help", false, "print help")
flag.BoolVar(&HelpFlag, "help", false, "print help")
}
func parseFlags() error {
@ -15,7 +16,7 @@ func parseFlags() error {
// parse flags
flag.Parse()
if helpFlag {
if HelpFlag {
flag.Usage()
return ErrCleanExit
}

View file

@ -1,12 +1,14 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package modules
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
)
@ -20,33 +22,104 @@ var (
// Module represents a module.
type Module struct {
Name string
Name string
// lifecycle mgmt
Prepped *abool.AtomicBool
Started *abool.AtomicBool
Stopped *abool.AtomicBool
inTransition *abool.AtomicBool
// lifecycle callback functions
prep func() error
start func() error
stop func() error
// shutdown mgmt
Ctx context.Context
cancelCtx func()
shutdownFlag *abool.AtomicBool
workerGroup sync.WaitGroup
workerCnt *int32
// dependency mgmt
depNames []string
depModules []*Module
depReverse []*Module
}
// AddWorkers adds workers to the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
func (m *Module) AddWorkers(n uint) {
if !m.ShutdownInProgress() {
if atomic.AddInt32(m.workerCnt, int32(n)) > 0 {
// only add to workgroup if cnt is positive (try to compensate wrong usage)
m.workerGroup.Add(int(n))
}
}
}
// FinishWorker removes a worker from the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
func (m *Module) FinishWorker() {
// check worker cnt
if atomic.AddInt32(m.workerCnt, -1) < 0 {
log.Warningf("modules: %s module tried to finish more workers than added, this may lead to undefined behavior when shutting down", m.Name)
return
}
// also mark worker done in workgroup
m.workerGroup.Done()
}
// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead.
func (m *Module) ShutdownInProgress() bool {
return m.shutdownFlag.IsSet()
}
// ShuttingDown lets you listen for the shutdown signal.
func (m *Module) ShuttingDown() <-chan struct{} {
return m.Ctx.Done()
}
func (m *Module) shutdown() error {
// signal shutdown
m.shutdownFlag.Set()
m.cancelCtx()
// wait for workers
done := make(chan struct{})
go func() {
m.workerGroup.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
return errors.New("timed out while waiting for module workers to finish")
}
// call shutdown function
return m.stop()
}
func dummyAction() error {
return nil
}
// Register registers a new module.
// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
ctx, cancelCtx := context.WithCancel(context.Background())
var workerCnt int32
newModule := &Module{
Name: name,
Prepped: abool.NewBool(false),
Started: abool.NewBool(false),
Stopped: abool.NewBool(false),
inTransition: abool.NewBool(false),
Ctx: ctx,
cancelCtx: cancelCtx,
shutdownFlag: abool.NewBool(false),
workerGroup: sync.WaitGroup{},
workerCnt: &workerCnt,
prep: prep,
start: start,
stop: stop,
@ -77,7 +150,7 @@ func initDependencies() error {
// get dependency
depModule, ok := modules[depName]
if !ok {
return fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName)
return fmt.Errorf("module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName)
}
// link together

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package modules
import (
@ -204,11 +202,11 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{})
// test help flag
helpFlag = true
HelpFlag = true
err = Start()
if err == nil {
t.Error("should fail")
}
helpFlag = false
HelpFlag = false
}

View file

@ -38,6 +38,7 @@ func Start() error {
// inter-link modules
err := initDependencies()
if err != nil {
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to initialize modules: %s\n", err)
return err
}

View file

@ -74,7 +74,7 @@ func stopModules() error {
go func() {
reports <- &report{
module: execM,
err: execM.stop(),
err: execM.shutdown(),
}
}()
}

View file

@ -2,6 +2,8 @@ package notifications
import (
"time"
"github.com/safing/portbase/log"
)
func cleaner() {
@ -10,31 +12,55 @@ func cleaner() {
case <-shutdownSignal:
shutdownWg.Done()
return
case <-time.After(1 * time.Minute):
case <-time.After(5 * time.Second):
cleanNotifications()
}
}
func cleanNotifications() {
threshold := time.Now().Add(-2 * time.Minute).Unix()
maxThreshold := time.Now().Add(-72 * time.Hour).Unix()
now := time.Now().Unix()
finishedThreshhold := time.Now().Add(-10 * time.Second).Unix()
executionTimelimit := time.Now().Add(-24 * time.Hour).Unix()
fallbackTimelimit := time.Now().Add(-72 * time.Hour).Unix()
notsLock.Lock()
defer notsLock.Unlock()
for _, n := range nots {
n.Lock()
if n.Expires != 0 && n.Expires < threshold ||
n.Executed != 0 && n.Executed < threshold ||
n.Created < maxThreshold {
switch {
case n.Executed != 0: // notification was fully handled
// wait for a short time before deleting
if n.Executed < finishedThreshhold {
go deleteNotification(n)
}
case n.Responded != 0:
// waiting for execution
if n.Responded < executionTimelimit {
go deleteNotification(n)
}
case n.Expires != 0:
// expired without response
if n.Expires < now {
go deleteNotification(n)
}
case n.Created != 0:
// fallback: delete after 3 days after creation
if n.Created < fallbackTimelimit {
go deleteNotification(n)
// delete
n.Meta().Delete()
delete(nots, n.ID)
// save (ie. propagate delete)
go n.Save()
}
default:
// invalid, impossible to determine cleanup timeframe, delete now
go deleteNotification(n)
}
n.Unlock()
}
}
func deleteNotification(n *Notification) {
err := n.Delete()
if err != nil {
log.Debugf("notifications: failed to delete %s: %s", n.ID, err)
}
}

View file

@ -43,6 +43,26 @@ type StorageInterface struct {
storage.InjectBase
}
func registerAsDatabase() error {
_, err := database.Register(&database.Database{
Name: "notifications",
Description: "Notifications",
StorageType: "injected",
PrimaryAPI: "",
})
if err != nil {
return err
}
controller, err := database.InjectDatabase("notifications", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
return nil
}
// Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) {
notsLock.RLock()
@ -78,6 +98,10 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
// send all notifications
for _, n := range nots {
if n.Meta().IsDeleted() {
continue
}
if q.MatchesKey(n.DatabaseKey()) && q.MatchesRecord(n) {
it.Next <- n
}
@ -86,26 +110,6 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
it.Finish(nil)
}
func registerAsDatabase() error {
_, err := database.Register(&database.Database{
Name: "notifications",
Description: "Notifications",
StorageType: "injected",
PrimaryAPI: "",
})
if err != nil {
return err
}
controller, err := database.InjectDatabase("notifications", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
return nil
}
// Put stores a record in the database.
func (s *StorageInterface) Put(r record.Record) error {
// record is already locked!
@ -124,36 +128,75 @@ func (s *StorageInterface) Put(r record.Record) error {
}
// continue in goroutine
go updateNotificationFromDatabasePut(n, key)
go UpdateNotification(n, key)
return nil
}
func updateNotificationFromDatabasePut(n *Notification, key string) {
// UpdateNotification updates a notification with input from a database action. Notification will not be saved/propagated if there is no valid change.
func UpdateNotification(n *Notification, key string) {
n.Lock()
defer n.Unlock()
// seperate goroutine in order to correctly lock notsLock
notsLock.RLock()
origN, ok := nots[key]
notsLock.RUnlock()
save := false
// ignore if already deleted
if ok && origN.Meta().IsDeleted() {
ok = false
}
if ok {
// existing notification, update selected action ID only
n.Lock()
defer n.Unlock()
if n.SelectedActionID != "" {
log.Tracef("notifications: user selected action for %s: %s", n.ID, n.SelectedActionID)
go origN.SelectAndExecuteAction(n.SelectedActionID)
}
// existing notification
// only update select attributes
origN.Lock()
defer origN.Unlock()
} else {
// accept new notification as is
notsLock.Lock()
nots[key] = n
notsLock.Unlock()
// new notification (from external source): old == new
origN = n
save = true
}
switch {
case n.SelectedActionID != "" && n.Responded == 0:
// select action, if not yet already handled
log.Tracef("notifications: selected action for %s: %s", n.ID, n.SelectedActionID)
origN.selectAndExecuteAction(n.SelectedActionID)
save = true
case origN.Executed == 0 && n.Executed != 0:
log.Tracef("notifications: action for %s executed externally", n.ID)
origN.Executed = n.Executed
save = true
}
if save {
// we may be locking
go origN.Save()
}
}
// Delete deletes a record from the database.
func (s *StorageInterface) Delete(key string) error {
return ErrNoDelete
// transform key
if strings.HasPrefix(key, "all/") {
key = strings.TrimPrefix(key, "all/")
} else {
return storage.ErrNotFound
}
// get notification
notsLock.Lock()
n, ok := nots[key]
notsLock.Unlock()
if !ok {
return storage.ErrNotFound
}
// delete
return n.Delete()
}
// ReadOnly returns whether the database is read only.

View file

@ -12,7 +12,7 @@ var (
)
func init() {
modules.Register("notifications", nil, start, nil, "core")
modules.Register("notifications", nil, start, nil, "base", "database")
}
func start() error {

View file

@ -5,8 +5,11 @@ import (
"sync"
"time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/log"
uuid "github.com/satori/go.uuid"
)
// Notification types
@ -20,7 +23,9 @@ const (
type Notification struct {
record.Base
ID string
ID string
GUID string
Message string
// MessageTemplate string
// MessageData []string
@ -39,6 +44,7 @@ type Notification struct {
lock sync.Mutex
actionFunction func(*Notification) // call function to process action
actionTrigger chan string // and/or send to a channel
expiredTrigger chan struct{} // closed on expire
}
// Action describes an action that can be taken for a notification.
@ -62,12 +68,6 @@ func Get(id string) *Notification {
return nil
}
// Init initializes a Notification and returns it.
func (n *Notification) Init() *Notification {
n.Created = time.Now().Unix()
return n
}
// Save saves the notification and returns it.
func (n *Notification) Save() *Notification {
notsLock.Lock()
@ -75,6 +75,13 @@ func (n *Notification) Save() *Notification {
n.Lock()
defer n.Unlock()
// initialize
if n.Created == 0 {
n.Created = time.Now().Unix()
}
if n.GUID == "" {
n.GUID = uuid.NewV4().String()
}
// check key
if n.DatabaseKey() == "" {
n.SetKey(fmt.Sprintf("notifications:all/%s", n.ID))
@ -104,11 +111,12 @@ func (n *Notification) Save() *Notification {
Executed: n.Executed,
}
duplicate.SetMeta(n.Meta().Duplicate())
duplicate.SetKey(fmt.Sprintf("%s/%s", persistentBasePath, n.ID))
key := fmt.Sprintf("%s/%s", persistentBasePath, n.ID)
duplicate.SetKey(key)
go func() {
err := dbInterface.Put(duplicate)
if err != nil {
log.Warningf("notifications: failed to persist notification %s: %s", n.Key(), err)
log.Warningf("notifications: failed to persist notification %s: %s", key, err)
}
}()
}
@ -145,42 +153,85 @@ func (n *Notification) MakeAck() *Notification {
// Response waits for the user to respond to the notification and returns the selected action.
func (n *Notification) Response() <-chan string {
n.lock.Lock()
defer n.lock.Unlock()
if n.actionTrigger == nil {
n.actionTrigger = make(chan string)
}
n.lock.Unlock()
return n.actionTrigger
}
// Cancel (prematurely) destroys a notification.
func (n *Notification) Cancel() {
// Update updates/resends a notification if it was not already responded to.
func (n *Notification) Update(expires int64) {
responded := true
n.lock.Lock()
if n.Responded == 0 {
responded = false
n.Expires = expires
}
n.lock.Unlock()
// save if not yet responded
if !responded {
n.Save()
}
}
// Delete (prematurely) cancels and deletes a notification.
func (n *Notification) Delete() error {
notsLock.Lock()
defer notsLock.Unlock()
n.Lock()
defer n.Unlock()
// delete
// mark as deleted
n.Meta().Delete()
// delete from internal storage
delete(nots, n.ID)
// save (ie. propagate delete)
go n.Save()
// close expired
if n.expiredTrigger != nil {
close(n.expiredTrigger)
n.expiredTrigger = nil
}
// push update
dbController.PushUpdate(n)
// delete from persistent storage
if n.Persistent && persistentBasePath != "" {
key := fmt.Sprintf("%s/%s", persistentBasePath, n.ID)
err := dbInterface.Delete(key)
if err != nil && err != database.ErrNotFound {
return fmt.Errorf("failed to delete persisted notification %s from database: %s", key, err)
}
}
return nil
}
// SelectAndExecuteAction sets the user response and executes/triggers the action, if possible.
func (n *Notification) SelectAndExecuteAction(id string) {
n.Lock()
defer n.Unlock()
// Expired notifies the caller when the notification has expired.
func (n *Notification) Expired() <-chan struct{} {
n.lock.Lock()
if n.expiredTrigger == nil {
n.expiredTrigger = make(chan struct{})
}
n.lock.Unlock()
// update selection
return n.expiredTrigger
}
// selectAndExecuteAction sets the user response and executes/triggers the action, if possible.
func (n *Notification) selectAndExecuteAction(id string) {
// abort if already executed
if n.Executed != 0 {
// we already executed
return
}
n.SelectedActionID = id
// set response
n.Responded = time.Now().Unix()
n.SelectedActionID = id
// execute
executed := false
@ -195,7 +246,7 @@ func (n *Notification) SelectAndExecuteAction(id string) {
select {
case n.actionTrigger <- n.SelectedActionID:
executed = true
default:
case <-time.After(100 * time.Millisecond): // mitigate race conditions
break triggerAll
}
}
@ -205,8 +256,6 @@ func (n *Notification) SelectAndExecuteAction(id string) {
if executed {
n.Executed = time.Now().Unix()
}
go n.Save()
}
// AddDataSubject adds the data subject to the notification. This is the only way how a data subject should be added - it avoids locking problems.

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager
import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager
import (

View file

@ -6,6 +6,8 @@ import (
"runtime"
)
const isWindows = runtime.GOOS == "windows"
// EnsureDirectory ensures that the given directoy exists and that is has the given permissions set.
// If path is a file, it is deleted and a directory created.
// If a directory is created, also all missing directories up to the required one are created with the given permissions.
@ -16,7 +18,7 @@ func EnsureDirectory(path string, perm os.FileMode) error {
// file exists
if f.IsDir() {
// directory exists, check permissions
if runtime.GOOS == "windows" {
if isWindows {
// TODO: set correct permission on windows
// acl.Chmod(path, perm)
} else if f.Mode().Perm() != perm {

View file

@ -20,12 +20,13 @@ func EnableColorSupport() bool {
if !colorSupportChecked {
colorSupport = enableColorSupport()
colorSupportChecked = true
}
return colorSupport
}
func enableColorSupport() bool {
if IsWindowsVersion("10.") {
if IsAtLeastWindowsNTVersionWithDefault("10", false) {
// check if windows.Stdout is file
if windows.GetFileInformationByHandle(windows.Stdout, &windows.ByHandleFileInformation{}) == nil {

View file

@ -1,56 +1,100 @@
package osdetail
import (
"os/exec"
"fmt"
"regexp"
"strings"
"sync"
)
// FIXME: use https://godoc.org/github.com/shirou/gopsutil/host#PlatformInformation instead
"github.com/hashicorp/go-version"
versionCmp "github.com/hashicorp/go-version"
"github.com/shirou/gopsutil/host"
)
var (
versionRe = regexp.MustCompile(`[0-9\.]+`)
windowsVersion string
windowsNTVersion string
windowsNTVersionForCmp *versionCmp.Version
fetching sync.Mutex
fetched bool
)
func fetchVersion() {
// WindowsNTVersion returns the current Windows version.
func WindowsNTVersion() (string, error) {
var err error
fetching.Lock()
defer fetching.Unlock()
if !fetched {
fetched = true
_, _, windowsNTVersion, err = host.PlatformInformation()
output, err := exec.Command("cmd", "ver").Output()
if err != nil {
return
return "", fmt.Errorf("failed to obtain Windows-Version: %s", err)
}
match := versionRe.Find(output)
if match == nil {
return
windowsNTVersionForCmp, err = version.NewVersion(windowsNTVersion)
if err != nil {
return "", fmt.Errorf("failed to parse Windows-Version %s: %s", windowsNTVersion, err)
}
windowsVersion = string(match)
fetched = true
}
return windowsNTVersion, err
}
// WindowsVersion returns the current Windows version.
func WindowsVersion() string {
fetching.Lock()
defer fetching.Unlock()
fetchVersion()
// IsAtLeastWindowsNTVersion returns whether the current WindowsNT version is at least the given version or newer.
func IsAtLeastWindowsNTVersion(version string) (bool, error) {
_, err := WindowsNTVersion()
if err != nil {
return false, err
}
return windowsVersion
versionForCmp, err := versionCmp.NewVersion(version)
if err != nil {
return false, err
}
return windowsNTVersionForCmp.GreaterThanOrEqual(versionForCmp), nil
}
// IsWindowsVersion returns whether the given version matches (HasPrefix) the current Windows version.
func IsWindowsVersion(version string) bool {
fetching.Lock()
defer fetching.Unlock()
fetchVersion()
// TODO: we can do better.
return strings.HasPrefix(windowsVersion, version)
// IsAtLeastWindowsNTVersionWithDefault is like IsAtLeastWindowsNTVersion(), but keeps the Error and returns the default Value in Errorcase
func IsAtLeastWindowsNTVersionWithDefault(v string, defaultValue bool) bool {
val, err := IsAtLeastWindowsNTVersion(v)
if err != nil {
return defaultValue
}
return val
}
// IsAtLeastWindowsVersion returns whether the current Windows version is at least the given version or newer.
func IsAtLeastWindowsVersion(v string) (bool, error) {
var (
NTVersion string
)
switch v {
case "7":
NTVersion = "6.1"
case "8":
NTVersion = "6.2"
case "8.1":
NTVersion = "6.3"
case "10":
NTVersion = "10"
default:
return false, fmt.Errorf("failed to compare Windows-Version: Windows %s is unknown", v)
}
return IsAtLeastWindowsNTVersion(NTVersion)
}
// IsAtLeastWindowsVersionWithDefault is like IsAtLeastWindowsVersion(), but keeps the Error and returns the default Value in Errorcase
func IsAtLeastWindowsVersionWithDefault(v string, defaultValue bool) bool {
val, err := IsAtLeastWindowsVersion(v)
if err != nil {
return defaultValue
}
return val
}

View file

@ -2,8 +2,28 @@ package osdetail
import "testing"
func TestWindowsVersion(t *testing.T) {
if WindowsVersion() == "" {
t.Fatal("could not get windows version")
func TestWindowsNTVersion(t *testing.T) {
if str, err := WindowsNTVersion(); str == "" || err != nil {
t.Fatalf("failed to obtain windows version: %s", err)
}
}
func TestIsAtLeastWindowsNTVersion(t *testing.T) {
ret, err := IsAtLeastWindowsNTVersion("6")
if err != nil {
t.Fatalf("failed to compare windows versions: %s", err)
}
if !ret {
t.Fatalf("WindowsNTVersion is less than 6 (Vista)")
}
}
func TestIsAtLeastWindowsVersion(t *testing.T) {
ret, err := IsAtLeastWindowsVersion("7")
if err != nil {
t.Fatalf("failed to compare windows versions: %s", err)
}
if !ret {
t.Fatalf("WindowsVersion is less than 7")
}
}

139
utils/structure.go Normal file
View file

@ -0,0 +1,139 @@
package utils
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
// DirStructure represents a directory structure with permissions that should be enforced.
type DirStructure struct {
sync.Mutex
Path string
Dir string
Perm os.FileMode
Parent *DirStructure
Children map[string]*DirStructure
}
// NewDirStructure returns a new DirStructure.
func NewDirStructure(path string, perm os.FileMode) *DirStructure {
return &DirStructure{
Path: path,
Perm: perm,
Children: make(map[string]*DirStructure),
}
}
// ChildDir adds a new child DirStructure and returns it. Should the child already exist, the existing child is returned and the permissions are updated.
func (ds *DirStructure) ChildDir(dirName string, perm os.FileMode) (child *DirStructure) {
ds.Lock()
defer ds.Unlock()
// if exists, update
child, ok := ds.Children[dirName]
if ok {
child.Perm = perm
return child
}
// create new
new := &DirStructure{
Path: filepath.Join(ds.Path, dirName),
Dir: dirName,
Perm: perm,
Parent: ds,
Children: make(map[string]*DirStructure),
}
ds.Children[dirName] = new
return new
}
// Ensure ensures that the specified directory structure (from the first parent on) exists.
func (ds *DirStructure) Ensure() error {
return ds.EnsureAbsPath(ds.Path)
}
// EnsureRelPath ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists.
func (ds *DirStructure) EnsureRelPath(dirPath string) error {
return ds.EnsureAbsPath(filepath.Join(ds.Path, dirPath))
}
// EnsureRelDir ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists.
func (ds *DirStructure) EnsureRelDir(dirNames ...string) error {
return ds.EnsureAbsPath(filepath.Join(append([]string{ds.Path}, dirNames...)...))
}
// EnsureAbsPath ensures that the specified directory structure (from the first parent on) and the given absolute path exists.
// If the given path is outside the DirStructure, an error will be returned.
func (ds *DirStructure) EnsureAbsPath(dirPath string) error {
// always start at the top
if ds.Parent != nil {
return ds.Parent.EnsureAbsPath(dirPath)
}
// check if root
if dirPath == ds.Path {
return ds.ensure(nil)
}
// check scope
slashedPath := ds.Path
// add slash to end
if !strings.HasSuffix(slashedPath, string(filepath.Separator)) {
slashedPath += string(filepath.Separator)
}
// check if given path is in scope
if !strings.HasPrefix(dirPath, slashedPath) {
return fmt.Errorf(`path "%s" is outside of DirStructure scope`, dirPath)
}
// get relative path
relPath, err := filepath.Rel(ds.Path, dirPath)
if err != nil {
return fmt.Errorf("failed to get relative path: %s", err)
}
// split to path elements
pathDirs := strings.Split(filepath.ToSlash(relPath), "/")
// start checking
return ds.ensure(pathDirs)
}
func (ds *DirStructure) ensure(pathDirs []string) error {
ds.Lock()
defer ds.Unlock()
// check current dir
err := EnsureDirectory(ds.Path, ds.Perm)
if err != nil {
return err
}
if len(pathDirs) == 0 {
// we reached the end!
return nil
}
child, ok := ds.Children[pathDirs[0]]
if !ok {
// we have reached the end of the defined dir structure
// ensure all remaining dirs
dirPath := ds.Path
for _, dir := range pathDirs {
dirPath = filepath.Join(dirPath, dir)
err := EnsureDirectory(dirPath, ds.Perm)
if err != nil {
return err
}
}
return nil
}
// we got a child, continue
return child.ensure(pathDirs[1:])
}

72
utils/structure_test.go Normal file
View file

@ -0,0 +1,72 @@
// +build !windows
package utils
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
)
func ExampleDirStructure() {
// output:
// / [755]
// /repo [777]
// /repo/b [755]
// /repo/b/c [750]
// /repo/b/d [755]
// /repo/b/d/e [755]
// /repo/b/d/f [755]
// /secret [700]
basePath, err := ioutil.TempDir("", "")
if err != nil {
fmt.Println(err)
return
}
ds := NewDirStructure(basePath, 0755)
secret := ds.ChildDir("secret", 0700)
repo := ds.ChildDir("repo", 0777)
_ = repo.ChildDir("a", 0700)
b := repo.ChildDir("b", 0755)
c := b.ChildDir("c", 0750)
err = ds.Ensure()
if err != nil {
fmt.Println(err)
}
err = c.Ensure()
if err != nil {
fmt.Println(err)
}
err = secret.Ensure()
if err != nil {
fmt.Println(err)
}
err = b.EnsureRelDir("d", "e")
if err != nil {
fmt.Println(err)
}
err = b.EnsureRelPath("d/f")
if err != nil {
fmt.Println(err)
}
filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
if err == nil {
dir := strings.TrimPrefix(path, basePath)
if dir == "" {
dir = "/"
}
fmt.Printf("%s [%o]\n", dir, info.Mode().Perm())
}
return nil
})
}