diff --git a/Gopkg.lock b/Gopkg.lock index ca14576..eb1d34d 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -9,6 +9,14 @@ pruneopts = "UT" revision = "e2d15f34fcf99d5dbb871c820ec73f710fca9815" +[[projects]] + digest = "1:e92f5581902c345eb4ceffdcd4a854fb8f73cf436d47d837d1ec98ef1fe0a214" + name = "github.com/StackExchange/wmi" + packages = ["."] + pruneopts = "UT" + revision = "5d049714c4a64225c3c79a7cf7d02f7fb5b96338" + version = "1.0.0" + [[projects]] digest = "1:dbd3a713434b6f32d9459b1e6786ad58cec128470b58555cdc7b3b7314a1706f" name = "github.com/aead/serpent" @@ -34,19 +42,19 @@ version = "v1.1.1" [[projects]] - digest = "1:5f5090f05382959db941fa45acbeb7f4c5241aa8ac0f8f4393dec696e5953f53" + digest = "1:6be8582a4f52ba2851d8a039eb9c3a3b90334b2820563d71e97de35580da128e" name = "github.com/dgraph-io/badger" packages = [ ".", "options", - "protos", + "pb", "skl", "table", "y", ] pruneopts = "UT" - revision = "99233d725dbdd26d156c61b2f42ae1671b794656" - version = "v1.5.4" + revision = "2fa005c9d4bf695277ab5214c1fbce3735b9562a" + version = "v1.6.0" [[projects]] branch = "master" @@ -57,12 +65,31 @@ revision = "6a90982ecee230ff6cba02d5bd386acc030be9d3" [[projects]] - digest = "1:318f1c959a8a740366fce4b1e1eb2fd914036b4af58fbd0a003349b305f118ad" + digest = "1:6f9339c912bbdda81302633ad7e99a28dfa5a639c864061f1929510a9a64aa74" + name = "github.com/dustin/go-humanize" + packages = ["."] + pruneopts = "UT" + revision = "9f541cc9db5d55bce703bd99987c9d5cb8eea45e" + version = "v1.0.0" + +[[projects]] + digest = "1:440028f55cb322d8cb5b9d5ebec298a00b7d74690a658fe6b1c0c0b44341bfae" + name = "github.com/go-ole/go-ole" + packages = [ + ".", + "oleutil", + ] + pruneopts = "UT" + revision = "97b6244175ae18ea6eef668034fd6565847501c9" + version = "v1.2.4" + +[[projects]] + digest = "1:573ca21d3669500ff845bdebee890eb7fc7f0f50c59f2132f2a0c6b03d85086a" name = "github.com/golang/protobuf" packages = ["proto"] pruneopts = "UT" - revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30" - version = "v1.3.1" + revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7" + version = "v1.3.2" [[projects]] digest = "1:c3388642e07731a240e14f4bc7207df59cfcc009447c657b9de87fec072d07e3" @@ -88,9 +115,17 @@ revision = "66b9c49e59c6c48f0ffce28c2d8b8a5678502c6d" version = "v1.4.0" +[[projects]] + digest = "1:88e0b0baeb9072f0a4afbcf12dda615fc8be001d1802357538591155998da21b" + name = "github.com/hashicorp/go-version" + packages = ["."] + pruneopts = "UT" + revision = "ac23dc3fea5d1a983c43f6a0f6e2c13f0195d8bd" + version = "v1.2.0" + [[projects]] branch = "master" - digest = "1:ae08d850ba158ea3ba4a7bb90f8372608172d8920644e5a6693b940a1f4e5d01" + digest = "1:7e8b852581596acce37bcb939a05d7d5ff27156045b50057e659e299c16fc1ca" name = "github.com/mmcloughlin/avo" packages = [ "attr", @@ -108,7 +143,7 @@ "x86", ] pruneopts = "UT" - revision = "83fbad1a6b3cba8ac7711170e57953fd12cdc40a" + revision = "bb615f61ce85790a1667efc145c66e917cce1a39" [[projects]] digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" @@ -118,6 +153,22 @@ revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" version = "v0.8.1" +[[projects]] + branch = "develop" + digest = "1:d88649ff4a4a0746857dd9e39915aedddce2b08e442ac131a91e573cd45bde93" + name = "github.com/safing/portmaster" + packages = ["core/structure"] + pruneopts = "UT" + revision = "26c307b7a0db78d91b35ef9020706f106ebef8b6" + +[[projects]] + digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925" + name = "github.com/satori/go.uuid" + packages = ["."] + pruneopts = "UT" + revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3" + version = "v1.2.0" + [[projects]] branch = "master" digest = "1:f1ee4af7c43f206d87f13644636e3710a05e499a084a32ec2cc7d8aa25cef1aa" @@ -134,6 +185,29 @@ pruneopts = "UT" revision = "e9f377c596061894b7f9ee69aab61e62c3ccc13e" +[[projects]] + digest = "1:7ebe52387a847d276a9b9e7b379b2e3e4d536843ad840a9c4426a0152e6b3d54" + name = "github.com/shirou/gopsutil" + packages = [ + "cpu", + "host", + "internal/common", + "mem", + "net", + "process", + ] + pruneopts = "UT" + revision = "d80c43f9c984a48783daf22f4bd9278006ae483a" + version = "v2.19.7" + +[[projects]] + branch = "master" + digest = "1:99c6a6dab47067c9b898e8c8b13d130c6ab4ffbcc4b7cc6236c2cd0b1e344f5b" + name = "github.com/shirou/w32" + packages = ["."] + pruneopts = "UT" + revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b" + [[projects]] branch = "master" digest = "1:93d6687fc19da8a35c7352d72117a6acd2072dfb7e9bfd65646227bf2a913b2a" @@ -143,12 +217,12 @@ revision = "9b9efcf221b50905aab9bbabd3daed56dc10f339" [[projects]] - digest = "1:7df351557a6d5c30804e7d6f7ed87f2fccb0619c08fcc84869a93f22bec96c11" + digest = "1:46aea6ffe39c3d95c13640c2515236ed3c5cdffc3c78c6c0ed4edec2caf7a0dc" name = "github.com/tidwall/gjson" packages = ["."] pruneopts = "UT" - revision = "eee0b6226f0d1db2675a176fdfaa8419bcad4ca8" - version = "v1.2.1" + revision = "c5e72cdf74dff23857243dd662c465b810891c21" + version = "v1.3.2" [[projects]] digest = "1:8453ddbed197809ee8ca28b06bd04e127bec9912deb4ba451fea7a1eca578328" @@ -159,12 +233,12 @@ version = "v1.0.1" [[projects]] - branch = "master" digest = "1:ddfe0a54e5f9b29536a6d7b2defa376f2cb2b6e4234d676d7ff214d5b097cb50" name = "github.com/tidwall/pretty" packages = ["."] pruneopts = "UT" revision = "1166b9ac2b65e46a43d8618d30d1554f4652d49b" + version = "v1.0.0" [[projects]] digest = "1:b70c951ba6fdeecfbd50dabe95aa5e1b973866ae9abbece46ad60348112214f2" @@ -175,12 +249,12 @@ version = "v1.0.4" [[projects]] - digest = "1:5f7414cf41466d4b4dd7ec52b2cd3e481e08cfd11e7e24fef730c0e483e88bb1" + digest = "1:f2ac2c724fc8214bb7b9dd6d4f5b7a983152051f5133320f228557182263cb94" name = "go.etcd.io/bbolt" packages = ["."] pruneopts = "UT" - revision = "63597a96ec0ad9e6d43c3fc81e809909e0237461" - version = "v1.3.2" + revision = "a0458a2b35708eef59eb5f620ceb3cd1c01a824d" + version = "v1.3.3" [[projects]] branch = "master" @@ -192,7 +266,7 @@ "sha3", ] pruneopts = "UT" - revision = "20be4c3c3ed52bfccdb2d59a412ee1a936d175a7" + revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285" [[projects]] branch = "master" @@ -203,11 +277,11 @@ "trace", ] pruneopts = "UT" - revision = "f3200d17e092c607f615320ecaad13d87ad9a2b3" + revision = "74dc4d7220e7acc4e100824340f3e66577424772" [[projects]] branch = "master" - digest = "1:8a3986af7a48f0991ce6168708859c56d39d2ff8b82b34d0805bbb545a9a32a6" + digest = "1:6c97b4baa9cf774b42192e23632afc7dee7b899d4a3dc98057d24708ce1f60ac" name = "golang.org/x/sys" packages = [ "cpu", @@ -215,11 +289,11 @@ "windows", ] pruneopts = "UT" - revision = "46560c3f3c0a091352115a3d825af45663b983d8" + revision = "fde4db37ae7ad8191b03d30d27f258b5291ae4e3" [[projects]] branch = "master" - digest = "1:47717932fbd4293f80d296cd65254cd9c2ec9e51d8b4227b6254440cfa34da2a" + digest = "1:b3e64b43fe77039813b4f33100e4ef2e17960309f3c6c9ab120f5c55de747992" name = "golang.org/x/tools" packages = [ "go/ast/astutil", @@ -233,7 +307,7 @@ "internal/semver", ] pruneopts = "UT" - revision = "75312fb06703a759656ae75b3d1c24b4aae95dfe" + revision = "caa95bb40b630f80d344d1f710f7e39be971d3e8" [solve-meta] analyzer-name = "dep" @@ -246,7 +320,11 @@ "github.com/google/renameio", "github.com/gorilla/mux", "github.com/gorilla/websocket", + "github.com/hashicorp/go-version", + "github.com/safing/portmaster/core/structure", + "github.com/satori/go.uuid", "github.com/seehuhn/fortuna", + "github.com/shirou/gopsutil/host", "github.com/tevino/abool", "github.com/tidwall/gjson", "github.com/tidwall/sjson", diff --git a/Gopkg.toml b/Gopkg.toml index c597e4c..0e9be2a 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -84,3 +84,15 @@ [prune] go-tests = true unused-packages = true + +[[constraint]] + name = "github.com/satori/go.uuid" + version = "1.2.0" + +[[constraint]] + name = "github.com/shirou/gopsutil" + version = "2.19.6" + +[[constraint]] + name = "github.com/hashicorp/go-version" + version = "1.2.0" diff --git a/api/authentication.go b/api/authentication.go new file mode 100644 index 0000000..0e53e4e --- /dev/null +++ b/api/authentication.go @@ -0,0 +1,110 @@ +package api + +import ( + "encoding/base64" + "net/http" + "sync" + "time" + + "github.com/safing/portbase/crypto/random" + "github.com/safing/portbase/log" +) + +var ( + validTokens map[string]time.Time + validTokensLock sync.Mutex + + authFnLock sync.Mutex + authFn Authenticator +) + +const ( + cookieName = "T17" + + // in seconds + cookieBaseTTL = 300 // 5 minutes + cookieTTL = cookieBaseTTL * time.Second + cookieRefresh = cookieBaseTTL * 0.9 * time.Second +) + +// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed. +type Authenticator func(s *http.Server, r *http.Request) (grantAccess bool, err error) + +// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed. +func SetAuthenticator(fn Authenticator) error { + authFnLock.Lock() + defer authFnLock.Unlock() + + if authFn == nil { + authFn = fn + return nil + } + + return ErrAuthenticationAlreadySet +} + +func authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // check existing auth cookie + c, err := r.Cookie(cookieName) + if err == nil { + // get token + validTokensLock.Lock() + validUntil, valid := validTokens[c.Value] + validTokensLock.Unlock() + + // check if token is valid + if valid && time.Now().Before(validUntil) { + // maybe refresh cookie + if time.Now().After(validUntil.Add(-cookieRefresh)) { + validTokensLock.Lock() + validTokens[c.Value] = time.Now() + validTokensLock.Unlock() + } + next.ServeHTTP(w, r) + return + } + } + + // get authenticator + authFnLock.Lock() + authenticator := authFn + authFnLock.Unlock() + + // permit if no authenticator set + if authenticator == nil { + next.ServeHTTP(w, r) + return + } + + // get auth decision + grantAccess, err := authenticator(server, r) + if err != nil { + log.Warningf("api: authenticator failed: %s", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + if !grantAccess { + log.Warningf("api: denying api access to %s", r.RemoteAddr) + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + // write new cookie + token, err := random.Bytes(32) // 256 bit + if err != nil { + log.Warningf("api: failed to generate random token: %s", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + tokenString := base64.RawURLEncoding.EncodeToString(token) + http.SetCookie(w, &http.Cookie{ + Name: cookieName, + Value: tokenString, + HttpOnly: true, + }) + + // serve + log.Tracef("api: granted %s", r.RemoteAddr) + next.ServeHTTP(w, r) + }) +} diff --git a/api/config.go b/api/config.go index 8a6e885..7a272c6 100644 --- a/api/config.go +++ b/api/config.go @@ -17,40 +17,41 @@ func init() { flag.StringVar(&listenAddressFlag, "api-address", "", "override api listen address") } -func checkFlags() error { +func logFlagOverrides() { if listenAddressFlag != "" { - log.Warning("api: api/listenAddress config is being overridden by -api-address flag") + log.Warning("api: api/listenAddress default config is being overridden by -api-address flag") } - return nil } -func getListenAddress() string { +func getDefaultListenAddress() string { + // check if overridden if listenAddressFlag != "" { return listenAddressFlag } - return listenAddressConfig() + // return internal default + return defaultListenAddress } func registerConfig() error { err := config.Register(&config.Option{ Name: "API Address", Key: "api/listenAddress", - Description: "Define on what IP and port the API should listen on. Be careful, changing this may become a security issue.", - ExpertiseLevel: config.ExpertiseLevelExpert, + Description: "Define on which IP and port the API should listen on.", + ExpertiseLevel: config.ExpertiseLevelDeveloper, OptType: config.OptTypeString, - DefaultValue: defaultListenAddress, + DefaultValue: getDefaultListenAddress(), ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$", + RequiresRestart: true, }) if err != nil { return err } - listenAddressConfig = config.GetAsString("api/listenAddress", defaultListenAddress) + listenAddressConfig = config.GetAsString("api/listenAddress", getDefaultListenAddress()) return nil } +// SetDefaultAPIListenAddress sets the default listen address for the API. func SetDefaultAPIListenAddress(address string) { - if defaultListenAddress == "" { - defaultListenAddress = address - } + defaultListenAddress = address } diff --git a/api/database.go b/api/database.go index 0f900ed..dac729e 100644 --- a/api/database.go +++ b/api/database.go @@ -35,6 +35,10 @@ var ( dbAPISeperatorBytes = []byte(dbAPISeperator) ) +func init() { + RegisterHandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path +} + // DatabaseAPI is a database API instance. type DatabaseAPI struct { conn *websocket.Conn diff --git a/api/enriched-response.go b/api/enriched-response.go index 58fd983..6aa5685 100644 --- a/api/enriched-response.go +++ b/api/enriched-response.go @@ -1,25 +1,68 @@ package api import ( + "bufio" + "errors" + "net" "net/http" + + "github.com/safing/portbase/log" ) -// EnrichedResponseWriter is a wrapper for http.ResponseWriter for better information extraction. -type EnrichedResponseWriter struct { - http.ResponseWriter - Status int +// LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging. +type LoggingResponseWriter struct { + ResponseWriter http.ResponseWriter + Request *http.Request + Status int } -// NewEnrichedResponseWriter wraps a http.ResponseWriter. -func NewEnrichedResponseWriter(w http.ResponseWriter) *EnrichedResponseWriter { - return &EnrichedResponseWriter{ - w, - 0, +// NewLoggingResponseWriter wraps a http.ResponseWriter. +func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter { + return &LoggingResponseWriter{ + ResponseWriter: w, + Request: r, } } -// WriteHeader wraps the original WriteHeader method to extract information. -func (ew *EnrichedResponseWriter) WriteHeader(code int) { - ew.Status = code - ew.ResponseWriter.WriteHeader(code) +// Header wraps the original Header method. +func (lrw *LoggingResponseWriter) Header() http.Header { + return lrw.ResponseWriter.Header() +} + +// Write wraps the original Write method. +func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) { + return lrw.ResponseWriter.Write(b) +} + +// WriteHeader wraps the original WriteHeader method to extract information. +func (lrw *LoggingResponseWriter) WriteHeader(code int) { + lrw.Status = code + lrw.ResponseWriter.WriteHeader(code) +} + +// Hijack wraps the original Hijack method, if available. +func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := lrw.ResponseWriter.(http.Hijacker) + if ok { + c, b, err := hijacker.Hijack() + if err != nil { + return nil, nil, err + } + log.Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI) + return c, b, nil + } + return nil, nil, errors.New("response does not implement http.Hijacker") +} + +// RequestLogger is a logging middleware +func RequestLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) + lrw := NewLoggingResponseWriter(w, r) + next.ServeHTTP(lrw, r) + if lrw.Status != 0 { + // request may have been hijacked + log.Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI) + } + }) } diff --git a/api/main.go b/api/main.go index 4a3c3d9..830fe79 100644 --- a/api/main.go +++ b/api/main.go @@ -1,22 +1,37 @@ package api import ( + "context" + "errors" + "github.com/safing/portbase/modules" ) +// API Errors +var ( + ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set") +) + func init() { - modules.Register("api", prep, start, nil, "database") + modules.Register("api", prep, start, stop, "base", "database", "config") } func prep() error { - err := checkFlags() - if err != nil { - return err + if getDefaultListenAddress() == "" { + return errors.New("no listen address for api available") } return registerConfig() } func start() error { + logFlagOverrides() go Serve() return nil } + +func stop() error { + if server != nil { + return server.Shutdown(context.Background()) + } + return nil +} diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..89dd465 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,27 @@ +package api + +import "net/http" + +// Middleware is a function that can be added as a middleware to the API endpoint. +type Middleware func(next http.Handler) http.Handler + +type mwHandler struct { + handlers []Middleware + final http.Handler +} + +func (mwh *mwHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + handlerLock.RLock() + defer handlerLock.RUnlock() + + // final handler + handler := mwh.final + + // build middleware chain + for _, mw := range mwh.handlers { + handler = mw(handler) + } + + // start + handler.ServeHTTP(w, r) +} diff --git a/api/router.go b/api/router.go index 06cbba7..bce3884 100644 --- a/api/router.go +++ b/api/router.go @@ -2,6 +2,7 @@ package api import ( "net/http" + "sync" "github.com/gorilla/mux" @@ -9,35 +10,54 @@ import ( ) var ( - router = mux.NewRouter() + // gorilla mux + mainMux = mux.NewRouter() + + // middlewares + middlewareHandler = &mwHandler{ + final: mainMux, + handlers: []Middleware{ + RequestLogger, + authMiddleware, + }, + } + + // main server and lock + server = &http.Server{} + handlerLock sync.RWMutex ) -// RegisterHandleFunc registers an additional handle function with the API endoint. -func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { - return router.HandleFunc(path, handleFunc) +// RegisterHandler registers a handler with the API endoint. +func RegisterHandler(path string, handler http.Handler) *mux.Route { + handlerLock.Lock() + defer handlerLock.Unlock() + return mainMux.Handle(path, handler) } -// RequestLogger is a logging middleware -func RequestLogger(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) - ew := NewEnrichedResponseWriter(w) - next.ServeHTTP(ew, r) - log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI) - }) +// RegisterHandleFunc registers a handle function with the API endoint. +func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { + handlerLock.Lock() + defer handlerLock.Unlock() + return mainMux.HandleFunc(path, handleFunc) +} + +// RegisterMiddleware registers a middle function with the API endoint. +func RegisterMiddleware(middleware Middleware) { + handlerLock.Lock() + defer handlerLock.Unlock() + middlewareHandler.handlers = append(middlewareHandler.handlers, middleware) } // Serve starts serving the API endpoint. func Serve() { - router.Use(RequestLogger) + // configure server + server.Addr = listenAddressConfig() + server.Handler = middlewareHandler - mainMux := http.NewServeMux() - mainMux.Handle("/", router) // net/http pattern matching /* - mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path - - address := getListenAddress() - log.Infof("api: starting to listen on %s", address) - log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux)) + // start serving + log.Infof("api: starting to listen on %s", server.Addr) + // TODO: retry if failed + log.Errorf("api: failed to listen on %s: %s", server.Addr, server.ListenAndServe()) } // GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects. diff --git a/api/security.go b/api/security.go deleted file mode 100644 index 778f64e..0000000 --- a/api/security.go +++ /dev/null @@ -1 +0,0 @@ -package api diff --git a/api/testclient/serve.go b/api/testclient/serve.go index 01c40e5..e70e636 100644 --- a/api/testclient/serve.go +++ b/api/testclient/serve.go @@ -7,5 +7,5 @@ import ( ) func init() { - api.RegisterAdditionalRoute("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/")))) + api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/")))) } diff --git a/api/websocket.go b/api/websocket.go deleted file mode 100644 index 778f64e..0000000 --- a/api/websocket.go +++ /dev/null @@ -1 +0,0 @@ -package api diff --git a/config/database.go b/config/database.go index 9c1f8fa..96186c8 100644 --- a/config/database.go +++ b/config/database.go @@ -1,127 +1,127 @@ package config import ( - "errors" - "sort" - "strings" + "errors" + "sort" + "strings" - "github.com/safing/portbase/log" "github.com/safing/portbase/database" - "github.com/safing/portbase/database/storage" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/database/query" "github.com/safing/portbase/database/iterator" + "github.com/safing/portbase/database/query" + "github.com/safing/portbase/database/record" + "github.com/safing/portbase/database/storage" + "github.com/safing/portbase/log" ) var ( - dbController *database.Controller + dbController *database.Controller ) -// ConfigStorageInterface provices a storage.Interface to the configuration manager. -type ConfigStorageInterface struct { +// StorageInterface provices a storage.Interface to the configuration manager. +type StorageInterface struct { storage.InjectBase } // Get returns a database record. -func (s *ConfigStorageInterface) Get(key string) (record.Record, error) { - optionsLock.Lock() - defer optionsLock.Unlock() +func (s *StorageInterface) Get(key string) (record.Record, error) { + optionsLock.Lock() + defer optionsLock.Unlock() - opt, ok := options[key] - if !ok { - return nil, storage.ErrNotFound - } + opt, ok := options[key] + if !ok { + return nil, storage.ErrNotFound + } - return opt.Export() + return opt.Export() } // Put stores a record in the database. -func (s *ConfigStorageInterface) Put(r record.Record) error { - if r.Meta().Deleted > 0 { - return setConfigOption(r.DatabaseKey(), nil, false) - } +func (s *StorageInterface) Put(r record.Record) error { + if r.Meta().Deleted > 0 { + return setConfigOption(r.DatabaseKey(), nil, false) + } - acc := r.GetAccessor(r) - if acc == nil { - return errors.New("invalid data") - } + acc := r.GetAccessor(r) + if acc == nil { + return errors.New("invalid data") + } - val, ok := acc.Get("Value") - if !ok || val == nil { - return setConfigOption(r.DatabaseKey(), nil, false) - } + val, ok := acc.Get("Value") + if !ok || val == nil { + return setConfigOption(r.DatabaseKey(), nil, false) + } - optionsLock.RLock() - option, ok := options[r.DatabaseKey()] + optionsLock.RLock() + option, ok := options[r.DatabaseKey()] optionsLock.RUnlock() - if !ok { - return errors.New("config option does not exist") - } + if !ok { + return errors.New("config option does not exist") + } - var value interface{} - switch option.OptType { - case OptTypeString : - value, ok = acc.GetString("Value") - case OptTypeStringArray : - value, ok = acc.GetStringArray("Value") - case OptTypeInt : - value, ok = acc.GetInt("Value") - case OptTypeBool : - value, ok = acc.GetBool("Value") - } - if !ok { - return errors.New("received invalid value in \"Value\"") - } + var value interface{} + switch option.OptType { + case OptTypeString: + value, ok = acc.GetString("Value") + case OptTypeStringArray: + value, ok = acc.GetStringArray("Value") + case OptTypeInt: + value, ok = acc.GetInt("Value") + case OptTypeBool: + value, ok = acc.GetBool("Value") + } + if !ok { + return errors.New("received invalid value in \"Value\"") + } - err := setConfigOption(r.DatabaseKey(), value, false) - if err != nil { - return err - } - return nil + err := setConfigOption(r.DatabaseKey(), value, false) + if err != nil { + return err + } + return nil } // Delete deletes a record from the database. -func (s *ConfigStorageInterface) Delete(key string) error { - return setConfigOption(key, nil, false) +func (s *StorageInterface) Delete(key string) error { + return setConfigOption(key, nil, false) } // Query returns a an iterator for the supplied query. -func (s *ConfigStorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { +func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { - optionsLock.Lock() - defer optionsLock.Unlock() + optionsLock.Lock() + defer optionsLock.Unlock() - it := iterator.New() - var opts []*Option - for _, opt := range options { - if strings.HasPrefix(opt.Key, q.DatabaseKeyPrefix()) { - opts = append(opts, opt) - } - } + it := iterator.New() + var opts []*Option + for _, opt := range options { + if strings.HasPrefix(opt.Key, q.DatabaseKeyPrefix()) { + opts = append(opts, opt) + } + } - go s.processQuery(q, it, opts) + go s.processQuery(q, it, opts) - return it, nil + return it, nil } -func (s *ConfigStorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) { +func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) { - sort.Sort(sortableOptions(opts)) + sort.Sort(sortableOptions(opts)) - for _, opt := range opts { - r, err := opt.Export() - if err != nil { - it.Finish(err) - return - } - it.Next <- r - } + for _, opt := range opts { + r, err := opt.Export() + if err != nil { + it.Finish(err) + return + } + it.Next <- r + } - it.Finish(nil) + it.Finish(nil) } // ReadOnly returns whether the database is read only. -func (s *ConfigStorageInterface) ReadOnly() bool { +func (s *StorageInterface) ReadOnly() bool { return false } @@ -132,33 +132,33 @@ func registerAsDatabase() error { StorageType: "injected", PrimaryAPI: "", }) - if err != nil { - return err - } + if err != nil { + return err + } - controller, err := database.InjectDatabase("config", &ConfigStorageInterface{}) - if err != nil { - return err - } + controller, err := database.InjectDatabase("config", &StorageInterface{}) + if err != nil { + return err + } - dbController = controller + dbController = controller return nil } func pushFullUpdate() { - optionsLock.RLock() - defer optionsLock.RUnlock() + optionsLock.RLock() + defer optionsLock.RUnlock() - for _, option := range options { - pushUpdate(option) - } + for _, option := range options { + pushUpdate(option) + } } func pushUpdate(option *Option) { - r, err := option.Export() - if err != nil { - log.Errorf("failed to export option to push update: %s", err) - } else { - dbController.PushUpdate(r) - } + r, err := option.Export() + if err != nil { + log.Errorf("failed to export option to push update: %s", err) + } else { + dbController.PushUpdate(r) + } } diff --git a/config/main.go b/config/main.go index 52a2a4d..3b0d7d5 100644 --- a/config/main.go +++ b/config/main.go @@ -1,23 +1,40 @@ package config import ( + "errors" "os" - "path" + "path/filepath" - "github.com/safing/portbase/database" "github.com/safing/portbase/modules" + "github.com/safing/portbase/utils" + "github.com/safing/portmaster/core/structure" ) +var ( + dataRoot *utils.DirStructure +) + +// SetDataRoot sets the data root from which the updates module derives its paths. +func SetDataRoot(root *utils.DirStructure) { + if dataRoot == nil { + dataRoot = root + } +} + func init() { - modules.Register("config", prep, start, nil, "database") + modules.Register("config", prep, start, nil, "base", "database") } func prep() error { + SetDataRoot(structure.Root()) + if dataRoot == nil { + return errors.New("data root is not set") + } return nil } func start() error { - configFilePath = path.Join(database.GetDatabaseRoot(), "config.json") + configFilePath = filepath.Join(dataRoot.Path, "config.json") err := registerAsDatabase() if err != nil && !os.IsNotExist(err) { diff --git a/config/option.go b/config/option.go index 92d7a76..90f0295 100644 --- a/config/option.go +++ b/config/option.go @@ -1,9 +1,9 @@ package config import ( + "encoding/json" "fmt" "regexp" - "encoding/json" "github.com/tidwall/sjson" @@ -47,6 +47,7 @@ type Option struct { DefaultValue interface{} ExternalOptType string ValidationRegex string + RequiresRestart bool compiledRegex *regexp.Regexp } diff --git a/crypto/hash/algorithm.go b/crypto/hash/algorithm.go index 42a8be4..761da94 100644 --- a/crypto/hash/algorithm.go +++ b/crypto/hash/algorithm.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package hash import ( diff --git a/crypto/hash/algorithm_test.go b/crypto/hash/algorithm_test.go index fad7ad8..45b1ea0 100644 --- a/crypto/hash/algorithm_test.go +++ b/crypto/hash/algorithm_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package hash import "testing" diff --git a/crypto/hash/hash.go b/crypto/hash/hash.go index be3a763..02ce23c 100644 --- a/crypto/hash/hash.go +++ b/crypto/hash/hash.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package hash import ( diff --git a/crypto/hash/hash_test.go b/crypto/hash/hash_test.go index 050457c..c8d8381 100644 --- a/crypto/hash/hash_test.go +++ b/crypto/hash/hash_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package hash import ( diff --git a/crypto/hash/proxies.go b/crypto/hash/proxies.go index ceedcf2..af574b5 100644 --- a/crypto/hash/proxies.go +++ b/crypto/hash/proxies.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package hash import ( diff --git a/crypto/random/rng.go b/crypto/random/rng.go index 464b363..ef461a5 100644 --- a/crypto/random/rng.go +++ b/crypto/random/rng.go @@ -23,7 +23,7 @@ var ( ) func init() { - modules.Register("random", prep, Start, stop) + modules.Register("random", prep, Start, stop, "base") config.Register(&config.Option{ Name: "RNG Cipher", diff --git a/database/dbmodule/db.go b/database/dbmodule/db.go index 765af96..780ae7e 100644 --- a/database/dbmodule/db.go +++ b/database/dbmodule/db.go @@ -2,51 +2,46 @@ package dbmodule import ( "errors" - "flag" - "sync" "github.com/safing/portbase/database" "github.com/safing/portbase/modules" + "github.com/safing/portbase/utils" ) var ( - databaseDir string - shutdownSignal = make(chan struct{}) - maintenanceWg sync.WaitGroup + databasePath string + databaseStructureRoot *utils.DirStructure + + module *modules.Module ) -// SetDatabaseLocation sets the location of the database. Must be called before modules.Start and will be overridden by command line options. Intended for unit tests. -func SetDatabaseLocation(location string) { - databaseDir = location +func init() { + module = modules.Register("database", prep, start, stop, "base") } -func init() { - flag.StringVar(&databaseDir, "db", "", "set database directory") - - modules.Register("database", prep, start, stop) +// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure. +func SetDatabaseLocation(dirPath string, dirStructureRoot *utils.DirStructure) { + databasePath = dirPath + databaseStructureRoot = dirStructureRoot } func prep() error { - if databaseDir == "" { - return errors.New("no database location specified, set with `-db=/path/to/db`") - } - ok := database.SetLocation(databaseDir) - if !ok { - return errors.New("database location already set") + if databasePath == "" && databaseStructureRoot == nil { + return errors.New("no database location specified") } return nil } func start() error { - err := database.Initialize() - if err == nil { - startMaintainer() + err := database.Initialize(databasePath, databaseStructureRoot) + if err != nil { + return err } - return err + + startMaintainer() + return nil } func stop() error { - close(shutdownSignal) - maintenanceWg.Wait() return database.Shutdown() } diff --git a/database/dbmodule/maintenance.go b/database/dbmodule/maintenance.go index 1635a20..82c6f98 100644 --- a/database/dbmodule/maintenance.go +++ b/database/dbmodule/maintenance.go @@ -13,7 +13,7 @@ var ( ) func startMaintainer() { - maintenanceWg.Add(1) + module.AddWorkers(1) go maintenanceWorker() } @@ -37,8 +37,8 @@ func maintenanceWorker() { if err != nil { log.Errorf("database: thorough maintenance error: %s", err) } - case <-shutdownSignal: - maintenanceWg.Done() + case <-module.ShuttingDown(): + module.FinishWorker() return } } diff --git a/database/dbutils/meta.go b/database/dbutils/meta.go index 8c3af80..4c0dcc8 100644 --- a/database/dbutils/meta.go +++ b/database/dbutils/meta.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package dbutils type Meta struct { diff --git a/database/doc.go b/database/doc.go index d93ff5f..3c8ff8b 100644 --- a/database/doc.go +++ b/database/doc.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - /* Package database provides a universal interface for interacting with the database. diff --git a/database/errors.go b/database/errors.go index 55d42e6..cae280d 100644 --- a/database/errors.go +++ b/database/errors.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package database import ( diff --git a/database/location.go b/database/location.go index 7ca9777..636bab8 100644 --- a/database/location.go +++ b/database/location.go @@ -1,33 +1 @@ package database - -import ( - "fmt" - "path/filepath" - - "github.com/safing/portbase/utils" -) - -const ( - databasesSubDir = "databases" -) - -var ( - rootDir string -) - -// GetDatabaseRoot returns the root directory of the database. -func GetDatabaseRoot() string { - return rootDir -} - -// getLocation returns the storage location for the given name and type. -func getLocation(name, storageType string) (string, error) { - location := filepath.Join(rootDir, databasesSubDir, name, storageType) - - // check location - err := utils.EnsureDirectory(location, 0700) - if err != nil { - return "", fmt.Errorf("location (%s) invalid: %s", location, err) - } - return location, nil -} diff --git a/database/main.go b/database/main.go index f918a3b..d49962c 100644 --- a/database/main.go +++ b/database/main.go @@ -9,34 +9,40 @@ import ( "github.com/tevino/abool" ) +const ( + databasesSubDir = "databases" +) + var ( initialized = abool.NewBool(false) shuttingDown = abool.NewBool(false) shutdownSignal = make(chan struct{}) + + rootStructure *utils.DirStructure + databasesStructure *utils.DirStructure ) -// SetLocation sets the location of the database. This is separate from the initialization to provide the location to other modules earlier. -func SetLocation(location string) (ok bool) { - if !initialized.IsSet() && rootDir == "" { - rootDir = location - return true - } - return false -} - -// Initialize initialized the database -func Initialize() error { +// Initialize initializes the database at the specified location. Supply either a path or dir structure. +func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error { if initialized.SetToIf(false, true) { - err := utils.EnsureDirectory(rootDir, 0755) + if dirStructureRoot != nil { + rootStructure = dirStructureRoot + } else { + rootStructure = utils.NewDirStructure(dirPath, 0755) + } + + // ensure root and databases dirs + databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700) + err := databasesStructure.Ensure() if err != nil { - return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err) + return fmt.Errorf("could not create/open database directory (%s): %s", rootStructure.Path, err) } err = loadRegistry() if err != nil { - return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootDir, registryFileName), err) + return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootStructure.Path, registryFileName), err) } // start registry writer @@ -66,3 +72,14 @@ func Shutdown() (err error) { } return } + +// getLocation returns the storage location for the given name and type. +func getLocation(name, storageType string) (string, error) { + location := databasesStructure.ChildDir(name, 0700).ChildDir(storageType, 0700) + // check location + err := location.Ensure() + if err != nil { + return "", fmt.Errorf(`failed to create/check database dir "%s": %s`, location.Path, err) + } + return location.Path, nil +} diff --git a/database/record/base.go b/database/record/base.go index 1f97a42..8bc1d6b 100644 --- a/database/record/base.go +++ b/database/record/base.go @@ -70,7 +70,7 @@ func (b *Base) SetMeta(meta *Meta) { b.meta = meta } -// Marshal marshals the object, without the database key or metadata +// Marshal marshals the object, without the database key or metadata. It returns nil if the record is deleted. func (b *Base) Marshal(self Record, format uint8) ([]byte, error) { if b.Meta() == nil { return nil, errors.New("missing meta") @@ -96,15 +96,15 @@ func (b *Base) MarshalRecord(self Record) ([]byte, error) { // version c := container.New([]byte{1}) - // meta - metaSection, err := b.meta.GenCodeMarshal(nil) + // meta encoding + metaSection, err := dsd.Dump(b.meta, GenCode) if err != nil { return nil, err } c.AppendAsBlock(metaSection) // data - dataSection, err := b.Marshal(self, dsd.JSON) + dataSection, err := b.Marshal(self, JSON) if err != nil { return nil, err } diff --git a/database/record/formats.go b/database/record/formats.go index b09c0e6..e04d0e4 100644 --- a/database/record/formats.go +++ b/database/record/formats.go @@ -11,5 +11,5 @@ const ( BYTES = dsd.BYTES // X JSON = dsd.JSON // J BSON = dsd.BSON // B - GenCode = dsd.GenCode // G (reserved) + GenCode = dsd.GenCode // G ) diff --git a/database/record/meta-bench_test.go b/database/record/meta-bench_test.go index a531693..10b9bac 100644 --- a/database/record/meta-bench_test.go +++ b/database/record/meta-bench_test.go @@ -432,7 +432,7 @@ func BenchmarkMetaUnserializeWithCodegen(b *testing.B) { func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) { for i := 0; i < b.N; i++ { - _, err := dsd.Dump(testMeta, dsd.JSON) + _, err := dsd.Dump(testMeta, JSON) if err != nil { b.Errorf("failed to serialize with DSD/JSON: %s", err) return @@ -444,7 +444,7 @@ func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) { func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) { // Setup - encodedData, err := dsd.Dump(testMeta, dsd.JSON) + encodedData, err := dsd.Dump(testMeta, JSON) if err != nil { b.Errorf("failed to serialize with DSD/JSON: %s", err) return diff --git a/database/record/wrapper.go b/database/record/wrapper.go index 66768ef..a456edd 100644 --- a/database/record/wrapper.go +++ b/database/record/wrapper.go @@ -37,9 +37,20 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) { offset += n newMeta := &Meta{} - _, err = newMeta.GenCodeUnmarshal(metaSection) - if err != nil { - return nil, fmt.Errorf("could not unmarshal meta section: %s", err) + if len(metaSection) == 34 && metaSection[4] == 0 { + // TODO: remove in 2020 + // backward compatibility: + // format would byte shift and populate metaSection[4] with value > 0 (would naturally populate >0 at 07.02.2106 07:28:15) + // this must be gencode without format + _, err = newMeta.GenCodeUnmarshal(metaSection) + if err != nil { + return nil, fmt.Errorf("could not unmarshal meta section: %s", err) + } + } else { + _, err = dsd.Load(metaSection, newMeta) + if err != nil { + return nil, fmt.Errorf("could not unmarshal meta section: %s", err) + } } format, n, err := varint.Unpack8(data[offset:]) @@ -86,7 +97,7 @@ func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) { return nil, nil } - if format != dsd.AUTO && format != w.Format { + if format != AUTO && format != w.Format { return nil, errors.New("could not dump model, wrapped object format mismatch") } @@ -109,14 +120,14 @@ func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) { c := container.New([]byte{1}) // meta - metaSection, err := w.meta.GenCodeMarshal(nil) + metaSection, err := dsd.Dump(w.meta, GenCode) if err != nil { return nil, err } c.AppendAsBlock(metaSection) // data - dataSection, err := w.Marshal(r, dsd.JSON) + dataSection, err := w.Marshal(r, JSON) if err != nil { return nil, err } @@ -125,16 +136,6 @@ func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) { return c.CompileData(), nil } -// // Lock locks the record. -// func (w *Wrapper) Lock() { -// w.lock.Lock() -// } -// -// // Unlock unlocks the record. -// func (w *Wrapper) Unlock() { -// w.lock.Unlock() -// } - // IsWrapped returns whether the record is a Wrapper. func (w *Wrapper) IsWrapped() bool { return true diff --git a/database/record/wrapper_test.go b/database/record/wrapper_test.go index 83d72ae..460e6e9 100644 --- a/database/record/wrapper_test.go +++ b/database/record/wrapper_test.go @@ -2,9 +2,10 @@ package record import ( "bytes" + "errors" "testing" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portbase/container" ) func TestWrapper(t *testing.T) { @@ -24,14 +25,14 @@ func TestWrapper(t *testing.T) { if err != nil { t.Fatal(err) } - if wrapper.Format != dsd.JSON { + if wrapper.Format != JSON { t.Error("format mismatch") } if !bytes.Equal(testData, wrapper.Data) { t.Error("data mismatch") } - encoded, err := wrapper.Marshal(wrapper, dsd.JSON) + encoded, err := wrapper.Marshal(wrapper, JSON) if err != nil { t.Fatal(err) } @@ -40,6 +41,7 @@ func TestWrapper(t *testing.T) { } wrapper.SetMeta(&Meta{}) + wrapper.meta.Update() raw, err := wrapper.MarshalRecord(wrapper) if err != nil { t.Fatal(err) @@ -53,4 +55,42 @@ func TestWrapper(t *testing.T) { t.Error("marshal mismatch") } + // test new format + oldRaw, err := oldWrapperMarshalRecord(wrapper, wrapper) + if err != nil { + t.Fatal(err) + } + + wrapper3, err := NewRawWrapper("test", "a", oldRaw) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(testData, wrapper3.Data) { + t.Error("marshal mismatch") + } +} + +func oldWrapperMarshalRecord(w *Wrapper, r Record) ([]byte, error) { + if w.Meta() == nil { + return nil, errors.New("missing meta") + } + + // version + c := container.New([]byte{1}) + + // meta + metaSection, err := w.meta.GenCodeMarshal(nil) + if err != nil { + return nil, err + } + c.AppendAsBlock(metaSection) + + // data + dataSection, err := w.Marshal(r, JSON) + if err != nil { + return nil, err + } + c.Append(dataSection) + + return c.CompileData(), nil } diff --git a/database/registry.go b/database/registry.go index 7195bc9..3cdc11d 100644 --- a/database/registry.go +++ b/database/registry.go @@ -104,7 +104,7 @@ func loadRegistry() error { defer registryLock.Unlock() // read file - filePath := path.Join(rootDir, registryFileName) + filePath := path.Join(rootStructure.Path, registryFileName) data, err := ioutil.ReadFile(filePath) if err != nil { if os.IsNotExist(err) { @@ -139,7 +139,7 @@ func saveRegistry(lock bool) error { } // write file - filePath := path.Join(rootDir, registryFileName) + filePath := path.Join(rootStructure.Path, registryFileName) return ioutil.WriteFile(filePath, data, 0600) } diff --git a/formats/dsd/dsd.go b/formats/dsd/dsd.go index 2513323..a5e7434 100644 --- a/formats/dsd/dsd.go +++ b/formats/dsd/dsd.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package dsd // dynamic structured data @@ -22,7 +20,7 @@ const ( BYTES = 88 // X JSON = 74 // J BSON = 66 // B - GenCode = 71 // G (reserved) + GenCode = 71 // G ) // define errors @@ -30,6 +28,7 @@ var errNoMoreSpace = errors.New("dsd: no more space left after reading dsd type" var errUnknownType = errors.New("dsd: tried to unpack unknown type") var errNotImplemented = errors.New("dsd: this type is not yet implemented") +// Load loads an dsd structured data blob into the given interface. func Load(data []byte, t interface{}) (interface{}, error) { if len(data) < 2 { return nil, errNoMoreSpace @@ -46,6 +45,7 @@ func Load(data []byte, t interface{}) (interface{}, error) { return LoadAsFormat(data[read:], format, t) } +// LoadAsFormat loads a data blob into the interface using the specified format. func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error) { switch format { case STRING: @@ -55,28 +55,32 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error) case JSON: err := json.Unmarshal(data, t) if err != nil { - return nil, err + return nil, fmt.Errorf("dsd: failed to unpack json data: %s", data) + } + return t, nil + // case BSON: + // err := bson.Unmarshal(data[read:], t) + // if err != nil { + // return nil, err + // } + // return t, nil + case GenCode: + genCodeStruct, ok := t.(GenCodeCompatible) + if !ok { + return nil, errors.New("dsd: gencode is not supported by the given data structure") + } + _, err := genCodeStruct.GenCodeUnmarshal(data) + if err != nil { + return nil, fmt.Errorf("dsd: failed to unpack gencode data: %s", err) } return t, nil - // case BSON: - // err := bson.Unmarshal(data[read:], t) - // if err != nil { - // return nil, err - // } - // return t, nil - // case MSGP: - // err := t.UnmarshalMsg(data[read:]) - // if err != nil { - // return nil, err - // } - // return t, nil default: - return nil, errors.New(fmt.Sprintf("dsd: tried to load unknown type %d, data: %v", format, data)) + return nil, fmt.Errorf("dsd: tried to load unknown type %d, data: %v", format, data) } } +// Dump stores the interface as a dsd formatted data structure. func Dump(t interface{}, format uint8) ([]byte, error) { - if format == AUTO { switch t.(type) { case string: @@ -107,18 +111,19 @@ func Dump(t interface{}, format uint8) ([]byte, error) { // if err != nil { // return nil, err // } - // case MSGP: - // data, err := t.MarshalMsg(nil) - // if err != nil { - // return nil, err - // } + case GenCode: + genCodeStruct, ok := t.(GenCodeCompatible) + if !ok { + return nil, errors.New("dsd: gencode is not supported by the given data structure") + } + data, err = genCodeStruct.GenCodeMarshal(nil) + if err != nil { + return nil, fmt.Errorf("dsd: failed to pack gencode struct: %s", err) + } default: - return nil, errors.New(fmt.Sprintf("dsd: tried to dump unknown type %d", format)) + return nil, fmt.Errorf("dsd: tried to dump unknown type %d", format) } r := append(f, data...) - // log.Tracef("packing %v to %s", t, string(r)) - // return nil, errors.New(fmt.Sprintf("dsd: dumped bytes are: %v", r)) return r, nil - } diff --git a/formats/dsd/dsd_test.go b/formats/dsd/dsd_test.go index 24c9140..c660009 100644 --- a/formats/dsd/dsd_test.go +++ b/formats/dsd/dsd_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package dsd import ( @@ -10,6 +8,7 @@ import ( //go:generate msgp +// SimpleTestStruct is used for testing. type SimpleTestStruct struct { S string B byte @@ -21,11 +20,11 @@ type ComplexTestStruct struct { I16 int16 I32 int32 I64 int64 - Ui uint - Ui8 uint8 - Ui16 uint16 - Ui32 uint32 - Ui64 uint64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 S string Sp *string Sa []string @@ -38,6 +37,25 @@ type ComplexTestStruct struct { Mp *map[string]string } +type GenCodeTestStruct struct { + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + S string + Sp *string + Sa []string + Sap *[]string + B byte + Bp *byte + Ba []byte + Bap *[]byte +} + func TestConversion(t *testing.T) { // STRING @@ -113,7 +131,26 @@ func TestConversion(t *testing.T) { }, } - // TODO: test all formats + genCodeSubject := GenCodeTestStruct{ + -2, + -3, + -4, + -5, + 2, + 3, + 4, + 5, + "a", + &bString, + []string{"c", "d", "e"}, + &[]string{"f", "g", "h"}, + 0x01, + &bBytes, + []byte{0x03, 0x04, 0x05}, + &[]byte{0x05, 0x06, 0x07}, + } + + // test all formats (complex) formats := []uint8{JSON} for _, format := range formats { @@ -163,20 +200,20 @@ func TestConversion(t *testing.T) { if complexSubject.I64 != co.I64 { t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", complexSubject.I64, co.I64) } - if complexSubject.Ui != co.Ui { - t.Errorf("Load (complex struct): struct.Ui is not equal (%v != %v)", complexSubject.Ui, co.Ui) + if complexSubject.UI != co.UI { + t.Errorf("Load (complex struct): struct.UI is not equal (%v != %v)", complexSubject.UI, co.UI) } - if complexSubject.Ui8 != co.Ui8 { - t.Errorf("Load (complex struct): struct.Ui8 is not equal (%v != %v)", complexSubject.Ui8, co.Ui8) + if complexSubject.UI8 != co.UI8 { + t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", complexSubject.UI8, co.UI8) } - if complexSubject.Ui16 != co.Ui16 { - t.Errorf("Load (complex struct): struct.Ui16 is not equal (%v != %v)", complexSubject.Ui16, co.Ui16) + if complexSubject.UI16 != co.UI16 { + t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", complexSubject.UI16, co.UI16) } - if complexSubject.Ui32 != co.Ui32 { - t.Errorf("Load (complex struct): struct.Ui32 is not equal (%v != %v)", complexSubject.Ui32, co.Ui32) + if complexSubject.UI32 != co.UI32 { + t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", complexSubject.UI32, co.UI32) } - if complexSubject.Ui64 != co.Ui64 { - t.Errorf("Load (complex struct): struct.Ui64 is not equal (%v != %v)", complexSubject.Ui64, co.Ui64) + if complexSubject.UI64 != co.UI64 { + t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", complexSubject.UI64, co.UI64) } if complexSubject.S != co.S { t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", complexSubject.S, co.S) @@ -211,4 +248,87 @@ func TestConversion(t *testing.T) { } + // test all formats + formats = []uint8{JSON, GenCode} + + for _, format := range formats { + // simple + b, err := Dump(&simpleSubject, format) + if err != nil { + t.Fatalf("Dump error (simple struct): %s", err) + } + + o, err := Load(b, &SimpleTestStruct{}) + if err != nil { + t.Fatalf("Load error (simple struct): %s", err) + } + + if !reflect.DeepEqual(&simpleSubject, o) { + t.Errorf("Load (simple struct): subject does not match loaded object") + t.Errorf("Encoded: %v", string(b)) + t.Errorf("Compared: %v == %v", &simpleSubject, o) + } + + // complex + b, err = Dump(&genCodeSubject, format) + if err != nil { + t.Fatalf("Dump error (complex struct): %s", err) + } + + o, err = Load(b, &GenCodeTestStruct{}) + if err != nil { + t.Fatalf("Load error (complex struct): %s", err) + } + + co := o.(*GenCodeTestStruct) + + if genCodeSubject.I8 != co.I8 { + t.Errorf("Load (complex struct): struct.I8 is not equal (%v != %v)", genCodeSubject.I8, co.I8) + } + if genCodeSubject.I16 != co.I16 { + t.Errorf("Load (complex struct): struct.I16 is not equal (%v != %v)", genCodeSubject.I16, co.I16) + } + if genCodeSubject.I32 != co.I32 { + t.Errorf("Load (complex struct): struct.I32 is not equal (%v != %v)", genCodeSubject.I32, co.I32) + } + if genCodeSubject.I64 != co.I64 { + t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", genCodeSubject.I64, co.I64) + } + if genCodeSubject.UI8 != co.UI8 { + t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", genCodeSubject.UI8, co.UI8) + } + if genCodeSubject.UI16 != co.UI16 { + t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", genCodeSubject.UI16, co.UI16) + } + if genCodeSubject.UI32 != co.UI32 { + t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", genCodeSubject.UI32, co.UI32) + } + if genCodeSubject.UI64 != co.UI64 { + t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", genCodeSubject.UI64, co.UI64) + } + if genCodeSubject.S != co.S { + t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", genCodeSubject.S, co.S) + } + if !reflect.DeepEqual(genCodeSubject.Sp, co.Sp) { + t.Errorf("Load (complex struct): struct.Sp is not equal (%v != %v)", genCodeSubject.Sp, co.Sp) + } + if !reflect.DeepEqual(genCodeSubject.Sa, co.Sa) { + t.Errorf("Load (complex struct): struct.Sa is not equal (%v != %v)", genCodeSubject.Sa, co.Sa) + } + if !reflect.DeepEqual(genCodeSubject.Sap, co.Sap) { + t.Errorf("Load (complex struct): struct.Sap is not equal (%v != %v)", genCodeSubject.Sap, co.Sap) + } + if genCodeSubject.B != co.B { + t.Errorf("Load (complex struct): struct.B is not equal (%v != %v)", genCodeSubject.B, co.B) + } + if !reflect.DeepEqual(genCodeSubject.Bp, co.Bp) { + t.Errorf("Load (complex struct): struct.Bp is not equal (%v != %v)", genCodeSubject.Bp, co.Bp) + } + if !reflect.DeepEqual(genCodeSubject.Ba, co.Ba) { + t.Errorf("Load (complex struct): struct.Ba is not equal (%v != %v)", genCodeSubject.Ba, co.Ba) + } + if !reflect.DeepEqual(genCodeSubject.Bap, co.Bap) { + t.Errorf("Load (complex struct): struct.Bap is not equal (%v != %v)", genCodeSubject.Bap, co.Bap) + } + } } diff --git a/formats/dsd/gencode_test.go b/formats/dsd/gencode_test.go new file mode 100644 index 0000000..d024af0 --- /dev/null +++ b/formats/dsd/gencode_test.go @@ -0,0 +1,835 @@ +package dsd + +import ( + "io" + "time" + "unsafe" +) + +var ( + _ = unsafe.Sizeof(0) + _ = io.ReadFull + _ = time.Now() +) + +func (d *SimpleTestStruct) Size() (s uint64) { + + { + l := uint64(len(d.S)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + s++ + return +} + +func (d *SimpleTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) { + size := d.Size() + { + if uint64(cap(buf)) >= size { + buf = buf[:size] + } else { + buf = make([]byte, size) + } + } + i := uint64(0) + + { + l := uint64(len(d.S)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+0] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+0] = byte(t) + i++ + + } + copy(buf[i+0:], d.S) + i += l + } + { + buf[i+0] = d.B + } + return buf[:i+1], nil +} + +func (d *SimpleTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) { + i := uint64(0) + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+0] & 0x7F) + for buf[i+0]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+0]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + d.S = string(buf[i+0 : i+0+l]) + i += l + } + { + d.B = buf[i+0] + } + return i + 1, nil +} + +func (d *GenCodeTestStruct) Size() (s uint64) { + + { + l := uint64(len(d.S)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + { + if d.Sp != nil { + + { + l := uint64(len((*d.Sp))) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + s += 0 + } + } + { + l := uint64(len(d.Sa)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + + for k0 := range d.Sa { + + { + l := uint64(len(d.Sa[k0])) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + + } + + } + { + if d.Sap != nil { + + { + l := uint64(len((*d.Sap))) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + + for k0 := range *d.Sap { + + { + l := uint64(len((*d.Sap)[k0])) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + + } + + } + s += 0 + } + } + { + if d.Bp != nil { + + s++ + } + } + { + l := uint64(len(d.Ba)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + { + if d.Bap != nil { + + { + l := uint64(len((*d.Bap))) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + s += 0 + } + } + s += 35 + return +} + +func (d *GenCodeTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) { + size := d.Size() + { + if uint64(cap(buf)) >= size { + buf = buf[:size] + } else { + buf = make([]byte, size) + } + } + i := uint64(0) + + { + + buf[0+0] = byte(d.I8 >> 0) + + } + { + + buf[0+1] = byte(d.I16 >> 0) + + buf[1+1] = byte(d.I16 >> 8) + + } + { + + buf[0+3] = byte(d.I32 >> 0) + + buf[1+3] = byte(d.I32 >> 8) + + buf[2+3] = byte(d.I32 >> 16) + + buf[3+3] = byte(d.I32 >> 24) + + } + { + + buf[0+7] = byte(d.I64 >> 0) + + buf[1+7] = byte(d.I64 >> 8) + + buf[2+7] = byte(d.I64 >> 16) + + buf[3+7] = byte(d.I64 >> 24) + + buf[4+7] = byte(d.I64 >> 32) + + buf[5+7] = byte(d.I64 >> 40) + + buf[6+7] = byte(d.I64 >> 48) + + buf[7+7] = byte(d.I64 >> 56) + + } + { + + buf[0+15] = byte(d.UI8 >> 0) + + } + { + + buf[0+16] = byte(d.UI16 >> 0) + + buf[1+16] = byte(d.UI16 >> 8) + + } + { + + buf[0+18] = byte(d.UI32 >> 0) + + buf[1+18] = byte(d.UI32 >> 8) + + buf[2+18] = byte(d.UI32 >> 16) + + buf[3+18] = byte(d.UI32 >> 24) + + } + { + + buf[0+22] = byte(d.UI64 >> 0) + + buf[1+22] = byte(d.UI64 >> 8) + + buf[2+22] = byte(d.UI64 >> 16) + + buf[3+22] = byte(d.UI64 >> 24) + + buf[4+22] = byte(d.UI64 >> 32) + + buf[5+22] = byte(d.UI64 >> 40) + + buf[6+22] = byte(d.UI64 >> 48) + + buf[7+22] = byte(d.UI64 >> 56) + + } + { + l := uint64(len(d.S)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+30] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+30] = byte(t) + i++ + + } + copy(buf[i+30:], d.S) + i += l + } + { + if d.Sp == nil { + buf[i+30] = 0 + } else { + buf[i+30] = 1 + + { + l := uint64(len((*d.Sp))) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+31] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+31] = byte(t) + i++ + + } + copy(buf[i+31:], (*d.Sp)) + i += l + } + i += 0 + } + } + { + l := uint64(len(d.Sa)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+31] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+31] = byte(t) + i++ + + } + for k0 := range d.Sa { + + { + l := uint64(len(d.Sa[k0])) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+31] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+31] = byte(t) + i++ + + } + copy(buf[i+31:], d.Sa[k0]) + i += l + } + + } + } + { + if d.Sap == nil { + buf[i+31] = 0 + } else { + buf[i+31] = 1 + + { + l := uint64(len((*d.Sap))) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+32] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+32] = byte(t) + i++ + + } + for k0 := range *d.Sap { + + { + l := uint64(len((*d.Sap)[k0])) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+32] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+32] = byte(t) + i++ + + } + copy(buf[i+32:], (*d.Sap)[k0]) + i += l + } + + } + } + i += 0 + } + } + { + buf[i+32] = d.B + } + { + if d.Bp == nil { + buf[i+33] = 0 + } else { + buf[i+33] = 1 + + { + buf[i+34] = (*d.Bp) + } + i++ + } + } + { + l := uint64(len(d.Ba)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+34] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+34] = byte(t) + i++ + + } + copy(buf[i+34:], d.Ba) + i += l + } + { + if d.Bap == nil { + buf[i+34] = 0 + } else { + buf[i+34] = 1 + + { + l := uint64(len((*d.Bap))) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+35] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+35] = byte(t) + i++ + + } + copy(buf[i+35:], (*d.Bap)) + i += l + } + i += 0 + } + } + return buf[:i+35], nil +} + +func (d *GenCodeTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) { + i := uint64(0) + + { + + d.I8 = 0 | (int8(buf[i+0+0]) << 0) + + } + { + + d.I16 = 0 | (int16(buf[i+0+1]) << 0) | (int16(buf[i+1+1]) << 8) + + } + { + + d.I32 = 0 | (int32(buf[i+0+3]) << 0) | (int32(buf[i+1+3]) << 8) | (int32(buf[i+2+3]) << 16) | (int32(buf[i+3+3]) << 24) + + } + { + + d.I64 = 0 | (int64(buf[i+0+7]) << 0) | (int64(buf[i+1+7]) << 8) | (int64(buf[i+2+7]) << 16) | (int64(buf[i+3+7]) << 24) | (int64(buf[i+4+7]) << 32) | (int64(buf[i+5+7]) << 40) | (int64(buf[i+6+7]) << 48) | (int64(buf[i+7+7]) << 56) + + } + { + + d.UI8 = 0 | (uint8(buf[i+0+15]) << 0) + + } + { + + d.UI16 = 0 | (uint16(buf[i+0+16]) << 0) | (uint16(buf[i+1+16]) << 8) + + } + { + + d.UI32 = 0 | (uint32(buf[i+0+18]) << 0) | (uint32(buf[i+1+18]) << 8) | (uint32(buf[i+2+18]) << 16) | (uint32(buf[i+3+18]) << 24) + + } + { + + d.UI64 = 0 | (uint64(buf[i+0+22]) << 0) | (uint64(buf[i+1+22]) << 8) | (uint64(buf[i+2+22]) << 16) | (uint64(buf[i+3+22]) << 24) | (uint64(buf[i+4+22]) << 32) | (uint64(buf[i+5+22]) << 40) | (uint64(buf[i+6+22]) << 48) | (uint64(buf[i+7+22]) << 56) + + } + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+30] & 0x7F) + for buf[i+30]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+30]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + d.S = string(buf[i+30 : i+30+l]) + i += l + } + { + if buf[i+30] == 1 { + if d.Sp == nil { + d.Sp = new(string) + } + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+31] & 0x7F) + for buf[i+31]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+31]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + (*d.Sp) = string(buf[i+31 : i+31+l]) + i += l + } + i += 0 + } else { + d.Sp = nil + } + } + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+31] & 0x7F) + for buf[i+31]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+31]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap(d.Sa)) >= l { + d.Sa = d.Sa[:l] + } else { + d.Sa = make([]string, l) + } + for k0 := range d.Sa { + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+31] & 0x7F) + for buf[i+31]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+31]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + d.Sa[k0] = string(buf[i+31 : i+31+l]) + i += l + } + + } + } + { + if buf[i+31] == 1 { + if d.Sap == nil { + d.Sap = new([]string) + } + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+32] & 0x7F) + for buf[i+32]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+32]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap((*d.Sap))) >= l { + (*d.Sap) = (*d.Sap)[:l] + } else { + (*d.Sap) = make([]string, l) + } + for k0 := range *d.Sap { + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+32] & 0x7F) + for buf[i+32]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+32]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + (*d.Sap)[k0] = string(buf[i+32 : i+32+l]) + i += l + } + + } + } + i += 0 + } else { + d.Sap = nil + } + } + { + d.B = buf[i+32] + } + { + if buf[i+33] == 1 { + if d.Bp == nil { + d.Bp = new(byte) + } + + { + (*d.Bp) = buf[i+34] + } + i++ + } else { + d.Bp = nil + } + } + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+34] & 0x7F) + for buf[i+34]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+34]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap(d.Ba)) >= l { + d.Ba = d.Ba[:l] + } else { + d.Ba = make([]byte, l) + } + copy(d.Ba, buf[i+34:]) + i += l + } + { + if buf[i+34] == 1 { + if d.Bap == nil { + d.Bap = new([]byte) + } + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+35] & 0x7F) + for buf[i+35]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+35]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap((*d.Bap))) >= l { + (*d.Bap) = (*d.Bap)[:l] + } else { + (*d.Bap) = make([]byte, l) + } + copy((*d.Bap), buf[i+35:]) + i += l + } + i += 0 + } else { + d.Bap = nil + } + } + return i + 35, nil +} diff --git a/formats/dsd/interfaces.go b/formats/dsd/interfaces.go new file mode 100644 index 0000000..cae6052 --- /dev/null +++ b/formats/dsd/interfaces.go @@ -0,0 +1,9 @@ +package dsd + +// GenCodeCompatible is an interface to identify and use gencode compatible structs. +type GenCodeCompatible interface { + // GenCodeMarshal gencode marshalls the struct into the given byte array, or a new one if its too small. + GenCodeMarshal(buf []byte) ([]byte, error) + // GenCodeUnmarshal gencode unmarshalls the struct and returns the bytes read. + GenCodeUnmarshal(buf []byte) (uint64, error) +} diff --git a/formats/dsd/tests.gencode b/formats/dsd/tests.gencode new file mode 100644 index 0000000..bc29f5d --- /dev/null +++ b/formats/dsd/tests.gencode @@ -0,0 +1,23 @@ +struct SimpleTestStruct { + S string + B byte +} + +struct GenCodeTestStructure { + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + S string + Sp *string + Sa []string + Sap *[]string + B byte + Bp *byte + Ba []byte + Bap *[]byte +} diff --git a/formats/varint/varint.go b/formats/varint/varint.go index 711e7e0..d0f6129 100644 --- a/formats/varint/varint.go +++ b/formats/varint/varint.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package varint import "errors" diff --git a/formats/varint/varint_test.go b/formats/varint/varint_test.go index 3cf81e1..0de1741 100644 --- a/formats/varint/varint_test.go +++ b/formats/varint/varint_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package varint import ( diff --git a/log/flags.go b/log/flags.go index b00a94b..cc86e6b 100644 --- a/log/flags.go +++ b/log/flags.go @@ -3,11 +3,11 @@ package log import "flag" var ( - logLevelFlag string - fileLogLevelsFlag string + logLevelFlag string + pkgLogLevelsFlag string ) func init() { flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]") - flag.StringVar(&fileLogLevelsFlag, "flog", "", "set log level of files: database=trace,firewall=debug") + flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,firewall=debug") } diff --git a/log/formatting.go b/log/formatting.go index 3815579..731c5d9 100644 --- a/log/formatting.go +++ b/log/formatting.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package log import ( diff --git a/log/input.go b/log/input.go index 6c64b0c..8b271a6 100644 --- a/log/input.go +++ b/log/input.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package log import ( @@ -11,7 +9,7 @@ import ( ) func fastcheck(level severity) bool { - if fileLevelsActive.IsSet() { + if pkgLevelsActive.IsSet() { return true } if uint32(level) < atomic.LoadUint32(logLevel) { @@ -33,7 +31,7 @@ func log(level severity, msg string, trace *ContextTracer) { } // check if level is enabled - if !fileLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) { + if !pkgLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) { return } @@ -54,12 +52,12 @@ func log(level severity, msg string, trace *ContextTracer) { } // check if level is enabled for file or generally - if fileLevelsActive.IsSet() { + if pkgLevelsActive.IsSet() { fileOnly := strings.Split(file, "/") if len(fileOnly) < 2 { return } - sev, ok := fileLevels[fileOnly[len(fileOnly)-2]] + sev, ok := pkgLevels[fileOnly[len(fileOnly)-2]] if ok { if level < sev { return diff --git a/log/logging.go b/log/logging.go index 366ef1b..419ed04 100644 --- a/log/logging.go +++ b/log/logging.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package log import ( @@ -75,9 +73,9 @@ var ( logLevelInt = uint32(3) logLevel = &logLevelInt - fileLevelsActive = abool.NewBool(false) - fileLevels = make(map[string]severity) - fileLevelsLock sync.Mutex + pkgLevelsActive = abool.NewBool(false) + pkgLevels = make(map[string]severity) + pkgLevelsLock sync.Mutex logsWaiting = make(chan bool, 1) logsWaitingFlag = abool.NewBool(false) @@ -92,15 +90,15 @@ var ( testErrors = abool.NewBool(false) ) -func SetFileLevels(levels map[string]severity) { - fileLevelsLock.Lock() - fileLevels = levels - fileLevelsLock.Unlock() - fileLevelsActive.Set() +func SetPkgLevels(levels map[string]severity) { + pkgLevelsLock.Lock() + pkgLevels = levels + pkgLevelsLock.Unlock() + pkgLevelsActive.Set() } -func UnSetFileLevels() { - fileLevelsActive.UnSet() +func UnSetPkgLevels() { + pkgLevelsActive.UnSet() } func SetLogLevel(level severity) { @@ -143,10 +141,10 @@ func Start() (err error) { } // get and set file loglevels - fileLogLevels := fileLogLevelsFlag - if len(fileLogLevels) > 0 { - newFileLevels := make(map[string]severity) - for _, pair := range strings.Split(fileLogLevels, ",") { + pkgLogLevels := pkgLogLevelsFlag + if len(pkgLogLevels) > 0 { + newPkgLevels := make(map[string]severity) + for _, pair := range strings.Split(pkgLogLevels, ",") { splitted := strings.Split(pair, "=") if len(splitted) != 2 { err = fmt.Errorf("log warning: invalid file log level \"%s\", ignoring", pair) @@ -159,9 +157,9 @@ func Start() (err error) { fmt.Fprintf(os.Stderr, "%s\n", err.Error()) break } - newFileLevels[splitted[0]] = fileLevel + newPkgLevels[splitted[0]] = fileLevel } - SetFileLevels(newFileLevels) + SetPkgLevels(newPkgLevels) } startWriter() diff --git a/log/logging_test.go b/log/logging_test.go index fdbcc9f..aea84f9 100644 --- a/log/logging_test.go +++ b/log/logging_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package log import ( diff --git a/log/output.go b/log/output.go index abce8b0..a545230 100644 --- a/log/output.go +++ b/log/output.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package log import ( diff --git a/modules/flags.go b/modules/flags.go index 45328d9..d9ad1e8 100644 --- a/modules/flags.go +++ b/modules/flags.go @@ -3,11 +3,12 @@ package modules import "flag" var ( - helpFlag bool + // HelpFlag triggers printing flag.Usage. It's exported for custom help handling. + HelpFlag bool ) func init() { - flag.BoolVar(&helpFlag, "help", false, "print help") + flag.BoolVar(&HelpFlag, "help", false, "print help") } func parseFlags() error { @@ -15,7 +16,7 @@ func parseFlags() error { // parse flags flag.Parse() - if helpFlag { + if HelpFlag { flag.Usage() return ErrCleanExit } diff --git a/modules/modules.go b/modules/modules.go index b9c6435..5d55784 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -1,12 +1,14 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package modules import ( + "context" "errors" "fmt" "sync" + "sync/atomic" + "time" + "github.com/safing/portbase/log" "github.com/tevino/abool" ) @@ -20,33 +22,104 @@ var ( // Module represents a module. type Module struct { - Name string + Name string + + // lifecycle mgmt Prepped *abool.AtomicBool Started *abool.AtomicBool Stopped *abool.AtomicBool inTransition *abool.AtomicBool + // lifecycle callback functions prep func() error start func() error stop func() error + // shutdown mgmt + Ctx context.Context + cancelCtx func() + shutdownFlag *abool.AtomicBool + workerGroup sync.WaitGroup + workerCnt *int32 + + // dependency mgmt depNames []string depModules []*Module depReverse []*Module } +// AddWorkers adds workers to the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup. +func (m *Module) AddWorkers(n uint) { + if !m.ShutdownInProgress() { + if atomic.AddInt32(m.workerCnt, int32(n)) > 0 { + // only add to workgroup if cnt is positive (try to compensate wrong usage) + m.workerGroup.Add(int(n)) + } + } +} + +// FinishWorker removes a worker from the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup. +func (m *Module) FinishWorker() { + // check worker cnt + if atomic.AddInt32(m.workerCnt, -1) < 0 { + log.Warningf("modules: %s module tried to finish more workers than added, this may lead to undefined behavior when shutting down", m.Name) + return + } + // also mark worker done in workgroup + m.workerGroup.Done() +} + +// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead. +func (m *Module) ShutdownInProgress() bool { + return m.shutdownFlag.IsSet() +} + +// ShuttingDown lets you listen for the shutdown signal. +func (m *Module) ShuttingDown() <-chan struct{} { + return m.Ctx.Done() +} + +func (m *Module) shutdown() error { + // signal shutdown + m.shutdownFlag.Set() + m.cancelCtx() + + // wait for workers + done := make(chan struct{}) + go func() { + m.workerGroup.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(3 * time.Second): + return errors.New("timed out while waiting for module workers to finish") + } + + // call shutdown function + return m.stop() +} + func dummyAction() error { return nil } -// Register registers a new module. +// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished. func Register(name string, prep, start, stop func() error, dependencies ...string) *Module { + ctx, cancelCtx := context.WithCancel(context.Background()) + var workerCnt int32 + newModule := &Module{ Name: name, Prepped: abool.NewBool(false), Started: abool.NewBool(false), Stopped: abool.NewBool(false), inTransition: abool.NewBool(false), + Ctx: ctx, + cancelCtx: cancelCtx, + shutdownFlag: abool.NewBool(false), + workerGroup: sync.WaitGroup{}, + workerCnt: &workerCnt, prep: prep, start: start, stop: stop, @@ -77,7 +150,7 @@ func initDependencies() error { // get dependency depModule, ok := modules[depName] if !ok { - return fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName) + return fmt.Errorf("module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName) } // link together diff --git a/modules/modules_test.go b/modules/modules_test.go index f18c506..5a85378 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package modules import ( @@ -204,11 +202,11 @@ func TestErrors(t *testing.T) { startCompleteSignal = make(chan struct{}) // test help flag - helpFlag = true + HelpFlag = true err = Start() if err == nil { t.Error("should fail") } - helpFlag = false + HelpFlag = false } diff --git a/modules/start.go b/modules/start.go index 8d6ae88..8ac5acf 100644 --- a/modules/start.go +++ b/modules/start.go @@ -38,6 +38,7 @@ func Start() error { // inter-link modules err := initDependencies() if err != nil { + fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to initialize modules: %s\n", err) return err } diff --git a/modules/stop.go b/modules/stop.go index 9bb90b0..ad1a79e 100644 --- a/modules/stop.go +++ b/modules/stop.go @@ -74,7 +74,7 @@ func stopModules() error { go func() { reports <- &report{ module: execM, - err: execM.stop(), + err: execM.shutdown(), } }() } diff --git a/notifications/cleaner.go b/notifications/cleaner.go index 2befef1..efd0408 100644 --- a/notifications/cleaner.go +++ b/notifications/cleaner.go @@ -2,6 +2,8 @@ package notifications import ( "time" + + "github.com/safing/portbase/log" ) func cleaner() { @@ -10,31 +12,55 @@ func cleaner() { case <-shutdownSignal: shutdownWg.Done() return - case <-time.After(1 * time.Minute): + case <-time.After(5 * time.Second): cleanNotifications() } } func cleanNotifications() { - threshold := time.Now().Add(-2 * time.Minute).Unix() - maxThreshold := time.Now().Add(-72 * time.Hour).Unix() + now := time.Now().Unix() + finishedThreshhold := time.Now().Add(-10 * time.Second).Unix() + executionTimelimit := time.Now().Add(-24 * time.Hour).Unix() + fallbackTimelimit := time.Now().Add(-72 * time.Hour).Unix() notsLock.Lock() defer notsLock.Unlock() for _, n := range nots { n.Lock() - if n.Expires != 0 && n.Expires < threshold || - n.Executed != 0 && n.Executed < threshold || - n.Created < maxThreshold { + switch { + case n.Executed != 0: // notification was fully handled + // wait for a short time before deleting + if n.Executed < finishedThreshhold { + go deleteNotification(n) + } + case n.Responded != 0: + // waiting for execution + if n.Responded < executionTimelimit { + go deleteNotification(n) + } + case n.Expires != 0: + // expired without response + if n.Expires < now { + go deleteNotification(n) + } + case n.Created != 0: + // fallback: delete after 3 days after creation + if n.Created < fallbackTimelimit { + go deleteNotification(n) - // delete - n.Meta().Delete() - delete(nots, n.ID) - - // save (ie. propagate delete) - go n.Save() + } + default: + // invalid, impossible to determine cleanup timeframe, delete now + go deleteNotification(n) } n.Unlock() } } + +func deleteNotification(n *Notification) { + err := n.Delete() + if err != nil { + log.Debugf("notifications: failed to delete %s: %s", n.ID, err) + } +} diff --git a/notifications/database.go b/notifications/database.go index 9779aae..2cfd29f 100644 --- a/notifications/database.go +++ b/notifications/database.go @@ -43,6 +43,26 @@ type StorageInterface struct { storage.InjectBase } +func registerAsDatabase() error { + _, err := database.Register(&database.Database{ + Name: "notifications", + Description: "Notifications", + StorageType: "injected", + PrimaryAPI: "", + }) + if err != nil { + return err + } + + controller, err := database.InjectDatabase("notifications", &StorageInterface{}) + if err != nil { + return err + } + + dbController = controller + return nil +} + // Get returns a database record. func (s *StorageInterface) Get(key string) (record.Record, error) { notsLock.RLock() @@ -78,6 +98,10 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { // send all notifications for _, n := range nots { + if n.Meta().IsDeleted() { + continue + } + if q.MatchesKey(n.DatabaseKey()) && q.MatchesRecord(n) { it.Next <- n } @@ -86,26 +110,6 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { it.Finish(nil) } -func registerAsDatabase() error { - _, err := database.Register(&database.Database{ - Name: "notifications", - Description: "Notifications", - StorageType: "injected", - PrimaryAPI: "", - }) - if err != nil { - return err - } - - controller, err := database.InjectDatabase("notifications", &StorageInterface{}) - if err != nil { - return err - } - - dbController = controller - return nil -} - // Put stores a record in the database. func (s *StorageInterface) Put(r record.Record) error { // record is already locked! @@ -124,36 +128,75 @@ func (s *StorageInterface) Put(r record.Record) error { } // continue in goroutine - go updateNotificationFromDatabasePut(n, key) + go UpdateNotification(n, key) return nil } -func updateNotificationFromDatabasePut(n *Notification, key string) { +// UpdateNotification updates a notification with input from a database action. Notification will not be saved/propagated if there is no valid change. +func UpdateNotification(n *Notification, key string) { + n.Lock() + defer n.Unlock() + // seperate goroutine in order to correctly lock notsLock notsLock.RLock() origN, ok := nots[key] notsLock.RUnlock() + save := false + + // ignore if already deleted + if ok && origN.Meta().IsDeleted() { + ok = false + } + if ok { - // existing notification, update selected action ID only - n.Lock() - defer n.Unlock() - if n.SelectedActionID != "" { - log.Tracef("notifications: user selected action for %s: %s", n.ID, n.SelectedActionID) - go origN.SelectAndExecuteAction(n.SelectedActionID) - } + // existing notification + // only update select attributes + origN.Lock() + defer origN.Unlock() } else { - // accept new notification as is - notsLock.Lock() - nots[key] = n - notsLock.Unlock() + // new notification (from external source): old == new + origN = n + save = true + } + + switch { + case n.SelectedActionID != "" && n.Responded == 0: + // select action, if not yet already handled + log.Tracef("notifications: selected action for %s: %s", n.ID, n.SelectedActionID) + origN.selectAndExecuteAction(n.SelectedActionID) + save = true + case origN.Executed == 0 && n.Executed != 0: + log.Tracef("notifications: action for %s executed externally", n.ID) + origN.Executed = n.Executed + save = true + } + + if save { + // we may be locking + go origN.Save() } } // Delete deletes a record from the database. func (s *StorageInterface) Delete(key string) error { - return ErrNoDelete + // transform key + if strings.HasPrefix(key, "all/") { + key = strings.TrimPrefix(key, "all/") + } else { + return storage.ErrNotFound + } + + // get notification + notsLock.Lock() + n, ok := nots[key] + notsLock.Unlock() + if !ok { + return storage.ErrNotFound + } + // delete + return n.Delete() } // ReadOnly returns whether the database is read only. diff --git a/notifications/module.go b/notifications/module.go index 21e5997..63902e2 100644 --- a/notifications/module.go +++ b/notifications/module.go @@ -12,7 +12,7 @@ var ( ) func init() { - modules.Register("notifications", nil, start, nil, "core") + modules.Register("notifications", nil, start, nil, "base", "database") } func start() error { diff --git a/notifications/notification.go b/notifications/notification.go index 59d28b7..bb5fe58 100644 --- a/notifications/notification.go +++ b/notifications/notification.go @@ -5,8 +5,11 @@ import ( "sync" "time" + "github.com/safing/portbase/database" "github.com/safing/portbase/database/record" "github.com/safing/portbase/log" + + uuid "github.com/satori/go.uuid" ) // Notification types @@ -20,7 +23,9 @@ const ( type Notification struct { record.Base - ID string + ID string + GUID string + Message string // MessageTemplate string // MessageData []string @@ -39,6 +44,7 @@ type Notification struct { lock sync.Mutex actionFunction func(*Notification) // call function to process action actionTrigger chan string // and/or send to a channel + expiredTrigger chan struct{} // closed on expire } // Action describes an action that can be taken for a notification. @@ -62,12 +68,6 @@ func Get(id string) *Notification { return nil } -// Init initializes a Notification and returns it. -func (n *Notification) Init() *Notification { - n.Created = time.Now().Unix() - return n -} - // Save saves the notification and returns it. func (n *Notification) Save() *Notification { notsLock.Lock() @@ -75,6 +75,13 @@ func (n *Notification) Save() *Notification { n.Lock() defer n.Unlock() + // initialize + if n.Created == 0 { + n.Created = time.Now().Unix() + } + if n.GUID == "" { + n.GUID = uuid.NewV4().String() + } // check key if n.DatabaseKey() == "" { n.SetKey(fmt.Sprintf("notifications:all/%s", n.ID)) @@ -104,11 +111,12 @@ func (n *Notification) Save() *Notification { Executed: n.Executed, } duplicate.SetMeta(n.Meta().Duplicate()) - duplicate.SetKey(fmt.Sprintf("%s/%s", persistentBasePath, n.ID)) + key := fmt.Sprintf("%s/%s", persistentBasePath, n.ID) + duplicate.SetKey(key) go func() { err := dbInterface.Put(duplicate) if err != nil { - log.Warningf("notifications: failed to persist notification %s: %s", n.Key(), err) + log.Warningf("notifications: failed to persist notification %s: %s", key, err) } }() } @@ -145,42 +153,85 @@ func (n *Notification) MakeAck() *Notification { // Response waits for the user to respond to the notification and returns the selected action. func (n *Notification) Response() <-chan string { n.lock.Lock() - defer n.lock.Unlock() - if n.actionTrigger == nil { n.actionTrigger = make(chan string) } + n.lock.Unlock() return n.actionTrigger } -// Cancel (prematurely) destroys a notification. -func (n *Notification) Cancel() { +// Update updates/resends a notification if it was not already responded to. +func (n *Notification) Update(expires int64) { + responded := true + n.lock.Lock() + if n.Responded == 0 { + responded = false + n.Expires = expires + } + n.lock.Unlock() + + // save if not yet responded + if !responded { + n.Save() + } +} + +// Delete (prematurely) cancels and deletes a notification. +func (n *Notification) Delete() error { notsLock.Lock() defer notsLock.Unlock() n.Lock() defer n.Unlock() - // delete + // mark as deleted n.Meta().Delete() + + // delete from internal storage delete(nots, n.ID) - // save (ie. propagate delete) - go n.Save() + // close expired + if n.expiredTrigger != nil { + close(n.expiredTrigger) + n.expiredTrigger = nil + } + + // push update + dbController.PushUpdate(n) + + // delete from persistent storage + if n.Persistent && persistentBasePath != "" { + key := fmt.Sprintf("%s/%s", persistentBasePath, n.ID) + err := dbInterface.Delete(key) + if err != nil && err != database.ErrNotFound { + return fmt.Errorf("failed to delete persisted notification %s from database: %s", key, err) + } + } + + return nil } -// SelectAndExecuteAction sets the user response and executes/triggers the action, if possible. -func (n *Notification) SelectAndExecuteAction(id string) { - n.Lock() - defer n.Unlock() +// Expired notifies the caller when the notification has expired. +func (n *Notification) Expired() <-chan struct{} { + n.lock.Lock() + if n.expiredTrigger == nil { + n.expiredTrigger = make(chan struct{}) + } + n.lock.Unlock() - // update selection + return n.expiredTrigger +} + +// selectAndExecuteAction sets the user response and executes/triggers the action, if possible. +func (n *Notification) selectAndExecuteAction(id string) { + // abort if already executed if n.Executed != 0 { - // we already executed return } - n.SelectedActionID = id + + // set response n.Responded = time.Now().Unix() + n.SelectedActionID = id // execute executed := false @@ -195,7 +246,7 @@ func (n *Notification) SelectAndExecuteAction(id string) { select { case n.actionTrigger <- n.SelectedActionID: executed = true - default: + case <-time.After(100 * time.Millisecond): // mitigate race conditions break triggerAll } } @@ -205,8 +256,6 @@ func (n *Notification) SelectAndExecuteAction(id string) { if executed { n.Executed = time.Now().Unix() } - - go n.Save() } // AddDataSubject adds the data subject to the notification. This is the only way how a data subject should be added - it avoids locking problems. diff --git a/taskmanager/microtasks.go b/taskmanager/microtasks.go index e7232e3..ad758ff 100644 --- a/taskmanager/microtasks.go +++ b/taskmanager/microtasks.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package taskmanager import ( diff --git a/taskmanager/microtasks_test.go b/taskmanager/microtasks_test.go index eeda147..eb2fd74 100644 --- a/taskmanager/microtasks_test.go +++ b/taskmanager/microtasks_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package taskmanager import ( diff --git a/taskmanager/queuedtasks.go b/taskmanager/queuedtasks.go index 4112c65..7925701 100644 --- a/taskmanager/queuedtasks.go +++ b/taskmanager/queuedtasks.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package taskmanager import ( diff --git a/taskmanager/queuedtasks_test.go b/taskmanager/queuedtasks_test.go index 72b31fb..0ed33ae 100644 --- a/taskmanager/queuedtasks_test.go +++ b/taskmanager/queuedtasks_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package taskmanager import ( diff --git a/taskmanager/scheduledtasks.go b/taskmanager/scheduledtasks.go index 86b3c2d..32c4780 100644 --- a/taskmanager/scheduledtasks.go +++ b/taskmanager/scheduledtasks.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package taskmanager import ( diff --git a/taskmanager/scheduledtasks_test.go b/taskmanager/scheduledtasks_test.go index 52eb91e..18467d1 100644 --- a/taskmanager/scheduledtasks_test.go +++ b/taskmanager/scheduledtasks_test.go @@ -1,5 +1,3 @@ -// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file. - package taskmanager import ( diff --git a/utils/fs.go b/utils/fs.go index 46f67b3..89ee87d 100644 --- a/utils/fs.go +++ b/utils/fs.go @@ -6,6 +6,8 @@ import ( "runtime" ) +const isWindows = runtime.GOOS == "windows" + // EnsureDirectory ensures that the given directoy exists and that is has the given permissions set. // If path is a file, it is deleted and a directory created. // If a directory is created, also all missing directories up to the required one are created with the given permissions. @@ -16,7 +18,7 @@ func EnsureDirectory(path string, perm os.FileMode) error { // file exists if f.IsDir() { // directory exists, check permissions - if runtime.GOOS == "windows" { + if isWindows { // TODO: set correct permission on windows // acl.Chmod(path, perm) } else if f.Mode().Perm() != perm { diff --git a/utils/osdetail/colors_windows.go b/utils/osdetail/colors_windows.go index 214b755..9a1ad7d 100644 --- a/utils/osdetail/colors_windows.go +++ b/utils/osdetail/colors_windows.go @@ -20,12 +20,13 @@ func EnableColorSupport() bool { if !colorSupportChecked { colorSupport = enableColorSupport() + colorSupportChecked = true } return colorSupport } func enableColorSupport() bool { - if IsWindowsVersion("10.") { + if IsAtLeastWindowsNTVersionWithDefault("10", false) { // check if windows.Stdout is file if windows.GetFileInformationByHandle(windows.Stdout, &windows.ByHandleFileInformation{}) == nil { diff --git a/utils/osdetail/version_windows.go b/utils/osdetail/version_windows.go index 74d502d..43bb0dc 100644 --- a/utils/osdetail/version_windows.go +++ b/utils/osdetail/version_windows.go @@ -1,56 +1,100 @@ package osdetail import ( - "os/exec" + "fmt" "regexp" - "strings" "sync" -) -// FIXME: use https://godoc.org/github.com/shirou/gopsutil/host#PlatformInformation instead + "github.com/hashicorp/go-version" + versionCmp "github.com/hashicorp/go-version" + "github.com/shirou/gopsutil/host" +) var ( versionRe = regexp.MustCompile(`[0-9\.]+`) - windowsVersion string + windowsNTVersion string + windowsNTVersionForCmp *versionCmp.Version fetching sync.Mutex fetched bool ) -func fetchVersion() { +// WindowsNTVersion returns the current Windows version. +func WindowsNTVersion() (string, error) { + var err error + fetching.Lock() + defer fetching.Unlock() + if !fetched { - fetched = true + _, _, windowsNTVersion, err = host.PlatformInformation() - output, err := exec.Command("cmd", "ver").Output() if err != nil { - return + return "", fmt.Errorf("failed to obtain Windows-Version: %s", err) } - match := versionRe.Find(output) - if match == nil { - return + windowsNTVersionForCmp, err = version.NewVersion(windowsNTVersion) + + if err != nil { + return "", fmt.Errorf("failed to parse Windows-Version %s: %s", windowsNTVersion, err) } - windowsVersion = string(match) + fetched = true } + + return windowsNTVersion, err } -// WindowsVersion returns the current Windows version. -func WindowsVersion() string { - fetching.Lock() - defer fetching.Unlock() - fetchVersion() +// IsAtLeastWindowsNTVersion returns whether the current WindowsNT version is at least the given version or newer. +func IsAtLeastWindowsNTVersion(version string) (bool, error) { + _, err := WindowsNTVersion() + if err != nil { + return false, err + } - return windowsVersion + versionForCmp, err := versionCmp.NewVersion(version) + if err != nil { + return false, err + } + + return windowsNTVersionForCmp.GreaterThanOrEqual(versionForCmp), nil } -// IsWindowsVersion returns whether the given version matches (HasPrefix) the current Windows version. -func IsWindowsVersion(version string) bool { - fetching.Lock() - defer fetching.Unlock() - fetchVersion() - - // TODO: we can do better. - return strings.HasPrefix(windowsVersion, version) +// IsAtLeastWindowsNTVersionWithDefault is like IsAtLeastWindowsNTVersion(), but keeps the Error and returns the default Value in Errorcase +func IsAtLeastWindowsNTVersionWithDefault(v string, defaultValue bool) bool { + val, err := IsAtLeastWindowsNTVersion(v) + if err != nil { + return defaultValue + } + return val +} + +// IsAtLeastWindowsVersion returns whether the current Windows version is at least the given version or newer. +func IsAtLeastWindowsVersion(v string) (bool, error) { + var ( + NTVersion string + ) + switch v { + case "7": + NTVersion = "6.1" + case "8": + NTVersion = "6.2" + case "8.1": + NTVersion = "6.3" + case "10": + NTVersion = "10" + default: + return false, fmt.Errorf("failed to compare Windows-Version: Windows %s is unknown", v) + } + + return IsAtLeastWindowsNTVersion(NTVersion) +} + +// IsAtLeastWindowsVersionWithDefault is like IsAtLeastWindowsVersion(), but keeps the Error and returns the default Value in Errorcase +func IsAtLeastWindowsVersionWithDefault(v string, defaultValue bool) bool { + val, err := IsAtLeastWindowsVersion(v) + if err != nil { + return defaultValue + } + return val } diff --git a/utils/osdetail/version_windows_test.go b/utils/osdetail/version_windows_test.go index 251fb22..c9cb409 100644 --- a/utils/osdetail/version_windows_test.go +++ b/utils/osdetail/version_windows_test.go @@ -2,8 +2,28 @@ package osdetail import "testing" -func TestWindowsVersion(t *testing.T) { - if WindowsVersion() == "" { - t.Fatal("could not get windows version") +func TestWindowsNTVersion(t *testing.T) { + if str, err := WindowsNTVersion(); str == "" || err != nil { + t.Fatalf("failed to obtain windows version: %s", err) + } +} + +func TestIsAtLeastWindowsNTVersion(t *testing.T) { + ret, err := IsAtLeastWindowsNTVersion("6") + if err != nil { + t.Fatalf("failed to compare windows versions: %s", err) + } + if !ret { + t.Fatalf("WindowsNTVersion is less than 6 (Vista)") + } +} + +func TestIsAtLeastWindowsVersion(t *testing.T) { + ret, err := IsAtLeastWindowsVersion("7") + if err != nil { + t.Fatalf("failed to compare windows versions: %s", err) + } + if !ret { + t.Fatalf("WindowsVersion is less than 7") } } diff --git a/utils/structure.go b/utils/structure.go new file mode 100644 index 0000000..5a8b331 --- /dev/null +++ b/utils/structure.go @@ -0,0 +1,139 @@ +package utils + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" +) + +// DirStructure represents a directory structure with permissions that should be enforced. +type DirStructure struct { + sync.Mutex + + Path string + Dir string + Perm os.FileMode + Parent *DirStructure + Children map[string]*DirStructure +} + +// NewDirStructure returns a new DirStructure. +func NewDirStructure(path string, perm os.FileMode) *DirStructure { + return &DirStructure{ + Path: path, + Perm: perm, + Children: make(map[string]*DirStructure), + } +} + +// ChildDir adds a new child DirStructure and returns it. Should the child already exist, the existing child is returned and the permissions are updated. +func (ds *DirStructure) ChildDir(dirName string, perm os.FileMode) (child *DirStructure) { + ds.Lock() + defer ds.Unlock() + + // if exists, update + child, ok := ds.Children[dirName] + if ok { + child.Perm = perm + return child + } + + // create new + new := &DirStructure{ + Path: filepath.Join(ds.Path, dirName), + Dir: dirName, + Perm: perm, + Parent: ds, + Children: make(map[string]*DirStructure), + } + ds.Children[dirName] = new + return new +} + +// Ensure ensures that the specified directory structure (from the first parent on) exists. +func (ds *DirStructure) Ensure() error { + return ds.EnsureAbsPath(ds.Path) +} + +// EnsureRelPath ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists. +func (ds *DirStructure) EnsureRelPath(dirPath string) error { + return ds.EnsureAbsPath(filepath.Join(ds.Path, dirPath)) +} + +// EnsureRelDir ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists. +func (ds *DirStructure) EnsureRelDir(dirNames ...string) error { + return ds.EnsureAbsPath(filepath.Join(append([]string{ds.Path}, dirNames...)...)) +} + +// EnsureAbsPath ensures that the specified directory structure (from the first parent on) and the given absolute path exists. +// If the given path is outside the DirStructure, an error will be returned. +func (ds *DirStructure) EnsureAbsPath(dirPath string) error { + // always start at the top + if ds.Parent != nil { + return ds.Parent.EnsureAbsPath(dirPath) + } + + // check if root + if dirPath == ds.Path { + return ds.ensure(nil) + } + + // check scope + slashedPath := ds.Path + // add slash to end + if !strings.HasSuffix(slashedPath, string(filepath.Separator)) { + slashedPath += string(filepath.Separator) + } + // check if given path is in scope + if !strings.HasPrefix(dirPath, slashedPath) { + return fmt.Errorf(`path "%s" is outside of DirStructure scope`, dirPath) + } + + // get relative path + relPath, err := filepath.Rel(ds.Path, dirPath) + if err != nil { + return fmt.Errorf("failed to get relative path: %s", err) + } + + // split to path elements + pathDirs := strings.Split(filepath.ToSlash(relPath), "/") + + // start checking + return ds.ensure(pathDirs) +} + +func (ds *DirStructure) ensure(pathDirs []string) error { + ds.Lock() + defer ds.Unlock() + + // check current dir + err := EnsureDirectory(ds.Path, ds.Perm) + if err != nil { + return err + } + + if len(pathDirs) == 0 { + // we reached the end! + return nil + } + + child, ok := ds.Children[pathDirs[0]] + if !ok { + // we have reached the end of the defined dir structure + // ensure all remaining dirs + dirPath := ds.Path + for _, dir := range pathDirs { + dirPath = filepath.Join(dirPath, dir) + err := EnsureDirectory(dirPath, ds.Perm) + if err != nil { + return err + } + } + return nil + } + + // we got a child, continue + return child.ensure(pathDirs[1:]) +} diff --git a/utils/structure_test.go b/utils/structure_test.go new file mode 100644 index 0000000..d0c3e5a --- /dev/null +++ b/utils/structure_test.go @@ -0,0 +1,72 @@ +// +build !windows + +package utils + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" +) + +func ExampleDirStructure() { + // output: + // / [755] + // /repo [777] + // /repo/b [755] + // /repo/b/c [750] + // /repo/b/d [755] + // /repo/b/d/e [755] + // /repo/b/d/f [755] + // /secret [700] + + basePath, err := ioutil.TempDir("", "") + if err != nil { + fmt.Println(err) + return + } + + ds := NewDirStructure(basePath, 0755) + secret := ds.ChildDir("secret", 0700) + repo := ds.ChildDir("repo", 0777) + _ = repo.ChildDir("a", 0700) + b := repo.ChildDir("b", 0755) + c := b.ChildDir("c", 0750) + + err = ds.Ensure() + if err != nil { + fmt.Println(err) + } + + err = c.Ensure() + if err != nil { + fmt.Println(err) + } + + err = secret.Ensure() + if err != nil { + fmt.Println(err) + } + + err = b.EnsureRelDir("d", "e") + if err != nil { + fmt.Println(err) + } + + err = b.EnsureRelPath("d/f") + if err != nil { + fmt.Println(err) + } + + filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error { + if err == nil { + dir := strings.TrimPrefix(path, basePath) + if dir == "" { + dir = "/" + } + fmt.Printf("%s [%o]\n", dir, info.Mode().Perm()) + } + return nil + }) +}