Release to master

This commit is contained in:
Daniel 2019-08-19 11:29:47 +02:00 committed by GitHub
commit d63a514ee2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
68 changed files with 2274 additions and 508 deletions

122
Gopkg.lock generated
View file

@ -9,6 +9,14 @@
pruneopts = "UT" pruneopts = "UT"
revision = "e2d15f34fcf99d5dbb871c820ec73f710fca9815" revision = "e2d15f34fcf99d5dbb871c820ec73f710fca9815"
[[projects]]
digest = "1:e92f5581902c345eb4ceffdcd4a854fb8f73cf436d47d837d1ec98ef1fe0a214"
name = "github.com/StackExchange/wmi"
packages = ["."]
pruneopts = "UT"
revision = "5d049714c4a64225c3c79a7cf7d02f7fb5b96338"
version = "1.0.0"
[[projects]] [[projects]]
digest = "1:dbd3a713434b6f32d9459b1e6786ad58cec128470b58555cdc7b3b7314a1706f" digest = "1:dbd3a713434b6f32d9459b1e6786ad58cec128470b58555cdc7b3b7314a1706f"
name = "github.com/aead/serpent" name = "github.com/aead/serpent"
@ -34,19 +42,19 @@
version = "v1.1.1" version = "v1.1.1"
[[projects]] [[projects]]
digest = "1:5f5090f05382959db941fa45acbeb7f4c5241aa8ac0f8f4393dec696e5953f53" digest = "1:6be8582a4f52ba2851d8a039eb9c3a3b90334b2820563d71e97de35580da128e"
name = "github.com/dgraph-io/badger" name = "github.com/dgraph-io/badger"
packages = [ packages = [
".", ".",
"options", "options",
"protos", "pb",
"skl", "skl",
"table", "table",
"y", "y",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "99233d725dbdd26d156c61b2f42ae1671b794656" revision = "2fa005c9d4bf695277ab5214c1fbce3735b9562a"
version = "v1.5.4" version = "v1.6.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -57,12 +65,31 @@
revision = "6a90982ecee230ff6cba02d5bd386acc030be9d3" revision = "6a90982ecee230ff6cba02d5bd386acc030be9d3"
[[projects]] [[projects]]
digest = "1:318f1c959a8a740366fce4b1e1eb2fd914036b4af58fbd0a003349b305f118ad" digest = "1:6f9339c912bbdda81302633ad7e99a28dfa5a639c864061f1929510a9a64aa74"
name = "github.com/dustin/go-humanize"
packages = ["."]
pruneopts = "UT"
revision = "9f541cc9db5d55bce703bd99987c9d5cb8eea45e"
version = "v1.0.0"
[[projects]]
digest = "1:440028f55cb322d8cb5b9d5ebec298a00b7d74690a658fe6b1c0c0b44341bfae"
name = "github.com/go-ole/go-ole"
packages = [
".",
"oleutil",
]
pruneopts = "UT"
revision = "97b6244175ae18ea6eef668034fd6565847501c9"
version = "v1.2.4"
[[projects]]
digest = "1:573ca21d3669500ff845bdebee890eb7fc7f0f50c59f2132f2a0c6b03d85086a"
name = "github.com/golang/protobuf" name = "github.com/golang/protobuf"
packages = ["proto"] packages = ["proto"]
pruneopts = "UT" pruneopts = "UT"
revision = "b5d812f8a3706043e23a9cd5babf2e5423744d30" revision = "6c65a5562fc06764971b7c5d05c76c75e84bdbf7"
version = "v1.3.1" version = "v1.3.2"
[[projects]] [[projects]]
digest = "1:c3388642e07731a240e14f4bc7207df59cfcc009447c657b9de87fec072d07e3" digest = "1:c3388642e07731a240e14f4bc7207df59cfcc009447c657b9de87fec072d07e3"
@ -88,9 +115,17 @@
revision = "66b9c49e59c6c48f0ffce28c2d8b8a5678502c6d" revision = "66b9c49e59c6c48f0ffce28c2d8b8a5678502c6d"
version = "v1.4.0" version = "v1.4.0"
[[projects]]
digest = "1:88e0b0baeb9072f0a4afbcf12dda615fc8be001d1802357538591155998da21b"
name = "github.com/hashicorp/go-version"
packages = ["."]
pruneopts = "UT"
revision = "ac23dc3fea5d1a983c43f6a0f6e2c13f0195d8bd"
version = "v1.2.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:ae08d850ba158ea3ba4a7bb90f8372608172d8920644e5a6693b940a1f4e5d01" digest = "1:7e8b852581596acce37bcb939a05d7d5ff27156045b50057e659e299c16fc1ca"
name = "github.com/mmcloughlin/avo" name = "github.com/mmcloughlin/avo"
packages = [ packages = [
"attr", "attr",
@ -108,7 +143,7 @@
"x86", "x86",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "83fbad1a6b3cba8ac7711170e57953fd12cdc40a" revision = "bb615f61ce85790a1667efc145c66e917cce1a39"
[[projects]] [[projects]]
digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b" digest = "1:cf31692c14422fa27c83a05292eb5cbe0fb2775972e8f1f8446a71549bd8980b"
@ -118,6 +153,22 @@
revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4" revision = "ba968bfe8b2f7e042a574c888954fccecfa385b4"
version = "v0.8.1" version = "v0.8.1"
[[projects]]
branch = "develop"
digest = "1:d88649ff4a4a0746857dd9e39915aedddce2b08e442ac131a91e573cd45bde93"
name = "github.com/safing/portmaster"
packages = ["core/structure"]
pruneopts = "UT"
revision = "26c307b7a0db78d91b35ef9020706f106ebef8b6"
[[projects]]
digest = "1:274f67cb6fed9588ea2521ecdac05a6d62a8c51c074c1fccc6a49a40ba80e925"
name = "github.com/satori/go.uuid"
packages = ["."]
pruneopts = "UT"
revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3"
version = "v1.2.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:f1ee4af7c43f206d87f13644636e3710a05e499a084a32ec2cc7d8aa25cef1aa" digest = "1:f1ee4af7c43f206d87f13644636e3710a05e499a084a32ec2cc7d8aa25cef1aa"
@ -134,6 +185,29 @@
pruneopts = "UT" pruneopts = "UT"
revision = "e9f377c596061894b7f9ee69aab61e62c3ccc13e" revision = "e9f377c596061894b7f9ee69aab61e62c3ccc13e"
[[projects]]
digest = "1:7ebe52387a847d276a9b9e7b379b2e3e4d536843ad840a9c4426a0152e6b3d54"
name = "github.com/shirou/gopsutil"
packages = [
"cpu",
"host",
"internal/common",
"mem",
"net",
"process",
]
pruneopts = "UT"
revision = "d80c43f9c984a48783daf22f4bd9278006ae483a"
version = "v2.19.7"
[[projects]]
branch = "master"
digest = "1:99c6a6dab47067c9b898e8c8b13d130c6ab4ffbcc4b7cc6236c2cd0b1e344f5b"
name = "github.com/shirou/w32"
packages = ["."]
pruneopts = "UT"
revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:93d6687fc19da8a35c7352d72117a6acd2072dfb7e9bfd65646227bf2a913b2a" digest = "1:93d6687fc19da8a35c7352d72117a6acd2072dfb7e9bfd65646227bf2a913b2a"
@ -143,12 +217,12 @@
revision = "9b9efcf221b50905aab9bbabd3daed56dc10f339" revision = "9b9efcf221b50905aab9bbabd3daed56dc10f339"
[[projects]] [[projects]]
digest = "1:7df351557a6d5c30804e7d6f7ed87f2fccb0619c08fcc84869a93f22bec96c11" digest = "1:46aea6ffe39c3d95c13640c2515236ed3c5cdffc3c78c6c0ed4edec2caf7a0dc"
name = "github.com/tidwall/gjson" name = "github.com/tidwall/gjson"
packages = ["."] packages = ["."]
pruneopts = "UT" pruneopts = "UT"
revision = "eee0b6226f0d1db2675a176fdfaa8419bcad4ca8" revision = "c5e72cdf74dff23857243dd662c465b810891c21"
version = "v1.2.1" version = "v1.3.2"
[[projects]] [[projects]]
digest = "1:8453ddbed197809ee8ca28b06bd04e127bec9912deb4ba451fea7a1eca578328" digest = "1:8453ddbed197809ee8ca28b06bd04e127bec9912deb4ba451fea7a1eca578328"
@ -159,12 +233,12 @@
version = "v1.0.1" version = "v1.0.1"
[[projects]] [[projects]]
branch = "master"
digest = "1:ddfe0a54e5f9b29536a6d7b2defa376f2cb2b6e4234d676d7ff214d5b097cb50" digest = "1:ddfe0a54e5f9b29536a6d7b2defa376f2cb2b6e4234d676d7ff214d5b097cb50"
name = "github.com/tidwall/pretty" name = "github.com/tidwall/pretty"
packages = ["."] packages = ["."]
pruneopts = "UT" pruneopts = "UT"
revision = "1166b9ac2b65e46a43d8618d30d1554f4652d49b" revision = "1166b9ac2b65e46a43d8618d30d1554f4652d49b"
version = "v1.0.0"
[[projects]] [[projects]]
digest = "1:b70c951ba6fdeecfbd50dabe95aa5e1b973866ae9abbece46ad60348112214f2" digest = "1:b70c951ba6fdeecfbd50dabe95aa5e1b973866ae9abbece46ad60348112214f2"
@ -175,12 +249,12 @@
version = "v1.0.4" version = "v1.0.4"
[[projects]] [[projects]]
digest = "1:5f7414cf41466d4b4dd7ec52b2cd3e481e08cfd11e7e24fef730c0e483e88bb1" digest = "1:f2ac2c724fc8214bb7b9dd6d4f5b7a983152051f5133320f228557182263cb94"
name = "go.etcd.io/bbolt" name = "go.etcd.io/bbolt"
packages = ["."] packages = ["."]
pruneopts = "UT" pruneopts = "UT"
revision = "63597a96ec0ad9e6d43c3fc81e809909e0237461" revision = "a0458a2b35708eef59eb5f620ceb3cd1c01a824d"
version = "v1.3.2" version = "v1.3.3"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -192,7 +266,7 @@
"sha3", "sha3",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "20be4c3c3ed52bfccdb2d59a412ee1a936d175a7" revision = "4def268fd1a49955bfb3dda92fe3db4f924f2285"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -203,11 +277,11 @@
"trace", "trace",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "f3200d17e092c607f615320ecaad13d87ad9a2b3" revision = "74dc4d7220e7acc4e100824340f3e66577424772"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:8a3986af7a48f0991ce6168708859c56d39d2ff8b82b34d0805bbb545a9a32a6" digest = "1:6c97b4baa9cf774b42192e23632afc7dee7b899d4a3dc98057d24708ce1f60ac"
name = "golang.org/x/sys" name = "golang.org/x/sys"
packages = [ packages = [
"cpu", "cpu",
@ -215,11 +289,11 @@
"windows", "windows",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "46560c3f3c0a091352115a3d825af45663b983d8" revision = "fde4db37ae7ad8191b03d30d27f258b5291ae4e3"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:47717932fbd4293f80d296cd65254cd9c2ec9e51d8b4227b6254440cfa34da2a" digest = "1:b3e64b43fe77039813b4f33100e4ef2e17960309f3c6c9ab120f5c55de747992"
name = "golang.org/x/tools" name = "golang.org/x/tools"
packages = [ packages = [
"go/ast/astutil", "go/ast/astutil",
@ -233,7 +307,7 @@
"internal/semver", "internal/semver",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "75312fb06703a759656ae75b3d1c24b4aae95dfe" revision = "caa95bb40b630f80d344d1f710f7e39be971d3e8"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
@ -246,7 +320,11 @@
"github.com/google/renameio", "github.com/google/renameio",
"github.com/gorilla/mux", "github.com/gorilla/mux",
"github.com/gorilla/websocket", "github.com/gorilla/websocket",
"github.com/hashicorp/go-version",
"github.com/safing/portmaster/core/structure",
"github.com/satori/go.uuid",
"github.com/seehuhn/fortuna", "github.com/seehuhn/fortuna",
"github.com/shirou/gopsutil/host",
"github.com/tevino/abool", "github.com/tevino/abool",
"github.com/tidwall/gjson", "github.com/tidwall/gjson",
"github.com/tidwall/sjson", "github.com/tidwall/sjson",

View file

@ -84,3 +84,15 @@
[prune] [prune]
go-tests = true go-tests = true
unused-packages = true unused-packages = true
[[constraint]]
name = "github.com/satori/go.uuid"
version = "1.2.0"
[[constraint]]
name = "github.com/shirou/gopsutil"
version = "2.19.6"
[[constraint]]
name = "github.com/hashicorp/go-version"
version = "1.2.0"

110
api/authentication.go Normal file
View file

@ -0,0 +1,110 @@
package api
import (
"encoding/base64"
"net/http"
"sync"
"time"
"github.com/safing/portbase/crypto/random"
"github.com/safing/portbase/log"
)
var (
validTokens map[string]time.Time
validTokensLock sync.Mutex
authFnLock sync.Mutex
authFn Authenticator
)
const (
cookieName = "T17"
// in seconds
cookieBaseTTL = 300 // 5 minutes
cookieTTL = cookieBaseTTL * time.Second
cookieRefresh = cookieBaseTTL * 0.9 * time.Second
)
// Authenticator is a function that can be set as the authenticator for the API endpoint. If none is set, all requests will be allowed.
type Authenticator func(s *http.Server, r *http.Request) (grantAccess bool, err error)
// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be allowed.
func SetAuthenticator(fn Authenticator) error {
authFnLock.Lock()
defer authFnLock.Unlock()
if authFn == nil {
authFn = fn
return nil
}
return ErrAuthenticationAlreadySet
}
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check existing auth cookie
c, err := r.Cookie(cookieName)
if err == nil {
// get token
validTokensLock.Lock()
validUntil, valid := validTokens[c.Value]
validTokensLock.Unlock()
// check if token is valid
if valid && time.Now().Before(validUntil) {
// maybe refresh cookie
if time.Now().After(validUntil.Add(-cookieRefresh)) {
validTokensLock.Lock()
validTokens[c.Value] = time.Now()
validTokensLock.Unlock()
}
next.ServeHTTP(w, r)
return
}
}
// get authenticator
authFnLock.Lock()
authenticator := authFn
authFnLock.Unlock()
// permit if no authenticator set
if authenticator == nil {
next.ServeHTTP(w, r)
return
}
// get auth decision
grantAccess, err := authenticator(server, r)
if err != nil {
log.Warningf("api: authenticator failed: %s", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
if !grantAccess {
log.Warningf("api: denying api access to %s", r.RemoteAddr)
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// write new cookie
token, err := random.Bytes(32) // 256 bit
if err != nil {
log.Warningf("api: failed to generate random token: %s", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
tokenString := base64.RawURLEncoding.EncodeToString(token)
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: tokenString,
HttpOnly: true,
})
// serve
log.Tracef("api: granted %s", r.RemoteAddr)
next.ServeHTTP(w, r)
})
}

View file

@ -17,40 +17,41 @@ func init() {
flag.StringVar(&listenAddressFlag, "api-address", "", "override api listen address") flag.StringVar(&listenAddressFlag, "api-address", "", "override api listen address")
} }
func checkFlags() error { func logFlagOverrides() {
if listenAddressFlag != "" { if listenAddressFlag != "" {
log.Warning("api: api/listenAddress config is being overridden by -api-address flag") log.Warning("api: api/listenAddress default config is being overridden by -api-address flag")
} }
return nil
} }
func getListenAddress() string { func getDefaultListenAddress() string {
// check if overridden
if listenAddressFlag != "" { if listenAddressFlag != "" {
return listenAddressFlag return listenAddressFlag
} }
return listenAddressConfig() // return internal default
return defaultListenAddress
} }
func registerConfig() error { func registerConfig() error {
err := config.Register(&config.Option{ err := config.Register(&config.Option{
Name: "API Address", Name: "API Address",
Key: "api/listenAddress", Key: "api/listenAddress",
Description: "Define on what IP and port the API should listen on. Be careful, changing this may become a security issue.", Description: "Define on which IP and port the API should listen on.",
ExpertiseLevel: config.ExpertiseLevelExpert, ExpertiseLevel: config.ExpertiseLevelDeveloper,
OptType: config.OptTypeString, OptType: config.OptTypeString,
DefaultValue: defaultListenAddress, DefaultValue: getDefaultListenAddress(),
ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$", ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$",
RequiresRestart: true,
}) })
if err != nil { if err != nil {
return err return err
} }
listenAddressConfig = config.GetAsString("api/listenAddress", defaultListenAddress) listenAddressConfig = config.GetAsString("api/listenAddress", getDefaultListenAddress())
return nil return nil
} }
// SetDefaultAPIListenAddress sets the default listen address for the API.
func SetDefaultAPIListenAddress(address string) { func SetDefaultAPIListenAddress(address string) {
if defaultListenAddress == "" {
defaultListenAddress = address defaultListenAddress = address
}
} }

View file

@ -35,6 +35,10 @@ var (
dbAPISeperatorBytes = []byte(dbAPISeperator) dbAPISeperatorBytes = []byte(dbAPISeperator)
) )
func init() {
RegisterHandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path
}
// DatabaseAPI is a database API instance. // DatabaseAPI is a database API instance.
type DatabaseAPI struct { type DatabaseAPI struct {
conn *websocket.Conn conn *websocket.Conn

View file

@ -1,25 +1,68 @@
package api package api
import ( import (
"bufio"
"errors"
"net"
"net/http" "net/http"
"github.com/safing/portbase/log"
) )
// EnrichedResponseWriter is a wrapper for http.ResponseWriter for better information extraction. // LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging.
type EnrichedResponseWriter struct { type LoggingResponseWriter struct {
http.ResponseWriter ResponseWriter http.ResponseWriter
Request *http.Request
Status int Status int
} }
// NewEnrichedResponseWriter wraps a http.ResponseWriter. // NewLoggingResponseWriter wraps a http.ResponseWriter.
func NewEnrichedResponseWriter(w http.ResponseWriter) *EnrichedResponseWriter { func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter {
return &EnrichedResponseWriter{ return &LoggingResponseWriter{
w, ResponseWriter: w,
0, Request: r,
} }
} }
// WriteHeader wraps the original WriteHeader method to extract information. // Header wraps the original Header method.
func (ew *EnrichedResponseWriter) WriteHeader(code int) { func (lrw *LoggingResponseWriter) Header() http.Header {
ew.Status = code return lrw.ResponseWriter.Header()
ew.ResponseWriter.WriteHeader(code) }
// Write wraps the original Write method.
func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) {
return lrw.ResponseWriter.Write(b)
}
// WriteHeader wraps the original WriteHeader method to extract information.
func (lrw *LoggingResponseWriter) WriteHeader(code int) {
lrw.Status = code
lrw.ResponseWriter.WriteHeader(code)
}
// Hijack wraps the original Hijack method, if available.
func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := lrw.ResponseWriter.(http.Hijacker)
if ok {
c, b, err := hijacker.Hijack()
if err != nil {
return nil, nil, err
}
log.Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI)
return c, b, nil
}
return nil, nil, errors.New("response does not implement http.Hijacker")
}
// RequestLogger is a logging middleware
func RequestLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI)
lrw := NewLoggingResponseWriter(w, r)
next.ServeHTTP(lrw, r)
if lrw.Status != 0 {
// request may have been hijacked
log.Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI)
}
})
} }

View file

@ -1,22 +1,37 @@
package api package api
import ( import (
"context"
"errors"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
) )
// API Errors
var (
ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set")
)
func init() { func init() {
modules.Register("api", prep, start, nil, "database") modules.Register("api", prep, start, stop, "base", "database", "config")
} }
func prep() error { func prep() error {
err := checkFlags() if getDefaultListenAddress() == "" {
if err != nil { return errors.New("no listen address for api available")
return err
} }
return registerConfig() return registerConfig()
} }
func start() error { func start() error {
logFlagOverrides()
go Serve() go Serve()
return nil return nil
} }
func stop() error {
if server != nil {
return server.Shutdown(context.Background())
}
return nil
}

27
api/middleware.go Normal file
View file

@ -0,0 +1,27 @@
package api
import "net/http"
// Middleware is a function that can be added as a middleware to the API endpoint.
type Middleware func(next http.Handler) http.Handler
type mwHandler struct {
handlers []Middleware
final http.Handler
}
func (mwh *mwHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handlerLock.RLock()
defer handlerLock.RUnlock()
// final handler
handler := mwh.final
// build middleware chain
for _, mw := range mwh.handlers {
handler = mw(handler)
}
// start
handler.ServeHTTP(w, r)
}

View file

@ -2,6 +2,7 @@ package api
import ( import (
"net/http" "net/http"
"sync"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -9,35 +10,54 @@ import (
) )
var ( var (
router = mux.NewRouter() // gorilla mux
mainMux = mux.NewRouter()
// middlewares
middlewareHandler = &mwHandler{
final: mainMux,
handlers: []Middleware{
RequestLogger,
authMiddleware,
},
}
// main server and lock
server = &http.Server{}
handlerLock sync.RWMutex
) )
// RegisterHandleFunc registers an additional handle function with the API endoint. // RegisterHandler registers a handler with the API endoint.
func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { func RegisterHandler(path string, handler http.Handler) *mux.Route {
return router.HandleFunc(path, handleFunc) handlerLock.Lock()
defer handlerLock.Unlock()
return mainMux.Handle(path, handler)
} }
// RequestLogger is a logging middleware // RegisterHandleFunc registers a handle function with the API endoint.
func RequestLogger(next http.Handler) http.Handler { func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handlerLock.Lock()
log.Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) defer handlerLock.Unlock()
ew := NewEnrichedResponseWriter(w) return mainMux.HandleFunc(path, handleFunc)
next.ServeHTTP(ew, r) }
log.Infof("api request: %s %d %s", r.RemoteAddr, ew.Status, r.RequestURI)
}) // RegisterMiddleware registers a middle function with the API endoint.
func RegisterMiddleware(middleware Middleware) {
handlerLock.Lock()
defer handlerLock.Unlock()
middlewareHandler.handlers = append(middlewareHandler.handlers, middleware)
} }
// Serve starts serving the API endpoint. // Serve starts serving the API endpoint.
func Serve() { func Serve() {
router.Use(RequestLogger) // configure server
server.Addr = listenAddressConfig()
server.Handler = middlewareHandler
mainMux := http.NewServeMux() // start serving
mainMux.Handle("/", router) // net/http pattern matching /* log.Infof("api: starting to listen on %s", server.Addr)
mainMux.HandleFunc("/api/database/v1", startDatabaseAPI) // net/http pattern matching only this exact path // TODO: retry if failed
log.Errorf("api: failed to listen on %s: %s", server.Addr, server.ListenAndServe())
address := getListenAddress()
log.Infof("api: starting to listen on %s", address)
log.Errorf("api: failed to listen on %s: %s", address, http.ListenAndServe(address, mainMux))
} }
// GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects. // GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects.

View file

@ -1 +0,0 @@
package api

View file

@ -7,5 +7,5 @@ import (
) )
func init() { func init() {
api.RegisterAdditionalRoute("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/")))) api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/"))))
} }

View file

@ -1 +0,0 @@
package api

View file

@ -5,25 +5,25 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/safing/portbase/log"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/database/storage"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/iterator" "github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
"github.com/safing/portbase/log"
) )
var ( var (
dbController *database.Controller dbController *database.Controller
) )
// ConfigStorageInterface provices a storage.Interface to the configuration manager. // StorageInterface provices a storage.Interface to the configuration manager.
type ConfigStorageInterface struct { type StorageInterface struct {
storage.InjectBase storage.InjectBase
} }
// Get returns a database record. // Get returns a database record.
func (s *ConfigStorageInterface) Get(key string) (record.Record, error) { func (s *StorageInterface) Get(key string) (record.Record, error) {
optionsLock.Lock() optionsLock.Lock()
defer optionsLock.Unlock() defer optionsLock.Unlock()
@ -36,7 +36,7 @@ func (s *ConfigStorageInterface) Get(key string) (record.Record, error) {
} }
// Put stores a record in the database. // Put stores a record in the database.
func (s *ConfigStorageInterface) Put(r record.Record) error { func (s *StorageInterface) Put(r record.Record) error {
if r.Meta().Deleted > 0 { if r.Meta().Deleted > 0 {
return setConfigOption(r.DatabaseKey(), nil, false) return setConfigOption(r.DatabaseKey(), nil, false)
} }
@ -60,13 +60,13 @@ func (s *ConfigStorageInterface) Put(r record.Record) error {
var value interface{} var value interface{}
switch option.OptType { switch option.OptType {
case OptTypeString : case OptTypeString:
value, ok = acc.GetString("Value") value, ok = acc.GetString("Value")
case OptTypeStringArray : case OptTypeStringArray:
value, ok = acc.GetStringArray("Value") value, ok = acc.GetStringArray("Value")
case OptTypeInt : case OptTypeInt:
value, ok = acc.GetInt("Value") value, ok = acc.GetInt("Value")
case OptTypeBool : case OptTypeBool:
value, ok = acc.GetBool("Value") value, ok = acc.GetBool("Value")
} }
if !ok { if !ok {
@ -81,12 +81,12 @@ func (s *ConfigStorageInterface) Put(r record.Record) error {
} }
// Delete deletes a record from the database. // Delete deletes a record from the database.
func (s *ConfigStorageInterface) Delete(key string) error { func (s *StorageInterface) Delete(key string) error {
return setConfigOption(key, nil, false) return setConfigOption(key, nil, false)
} }
// Query returns a an iterator for the supplied query. // Query returns a an iterator for the supplied query.
func (s *ConfigStorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
optionsLock.Lock() optionsLock.Lock()
defer optionsLock.Unlock() defer optionsLock.Unlock()
@ -104,7 +104,7 @@ func (s *ConfigStorageInterface) Query(q *query.Query, local, internal bool) (*i
return it, nil return it, nil
} }
func (s *ConfigStorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) { func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator, opts []*Option) {
sort.Sort(sortableOptions(opts)) sort.Sort(sortableOptions(opts))
@ -121,7 +121,7 @@ func (s *ConfigStorageInterface) processQuery(q *query.Query, it *iterator.Itera
} }
// ReadOnly returns whether the database is read only. // ReadOnly returns whether the database is read only.
func (s *ConfigStorageInterface) ReadOnly() bool { func (s *StorageInterface) ReadOnly() bool {
return false return false
} }
@ -136,7 +136,7 @@ func registerAsDatabase() error {
return err return err
} }
controller, err := database.InjectDatabase("config", &ConfigStorageInterface{}) controller, err := database.InjectDatabase("config", &StorageInterface{})
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,23 +1,40 @@
package config package config
import ( import (
"errors"
"os" "os"
"path" "path/filepath"
"github.com/safing/portbase/database"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/utils"
"github.com/safing/portmaster/core/structure"
) )
var (
dataRoot *utils.DirStructure
)
// SetDataRoot sets the data root from which the updates module derives its paths.
func SetDataRoot(root *utils.DirStructure) {
if dataRoot == nil {
dataRoot = root
}
}
func init() { func init() {
modules.Register("config", prep, start, nil, "database") modules.Register("config", prep, start, nil, "base", "database")
} }
func prep() error { func prep() error {
SetDataRoot(structure.Root())
if dataRoot == nil {
return errors.New("data root is not set")
}
return nil return nil
} }
func start() error { func start() error {
configFilePath = path.Join(database.GetDatabaseRoot(), "config.json") configFilePath = filepath.Join(dataRoot.Path, "config.json")
err := registerAsDatabase() err := registerAsDatabase()
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {

View file

@ -1,9 +1,9 @@
package config package config
import ( import (
"encoding/json"
"fmt" "fmt"
"regexp" "regexp"
"encoding/json"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@ -47,6 +47,7 @@ type Option struct {
DefaultValue interface{} DefaultValue interface{}
ExternalOptType string ExternalOptType string
ValidationRegex string ValidationRegex string
RequiresRestart bool
compiledRegex *regexp.Regexp compiledRegex *regexp.Regexp
} }

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash package hash
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash package hash
import "testing" import "testing"

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash package hash
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash package hash
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package hash package hash
import ( import (

View file

@ -23,7 +23,7 @@ var (
) )
func init() { func init() {
modules.Register("random", prep, Start, stop) modules.Register("random", prep, Start, stop, "base")
config.Register(&config.Option{ config.Register(&config.Option{
Name: "RNG Cipher", Name: "RNG Cipher",

View file

@ -2,51 +2,46 @@ package dbmodule
import ( import (
"errors" "errors"
"flag"
"sync"
"github.com/safing/portbase/database" "github.com/safing/portbase/database"
"github.com/safing/portbase/modules" "github.com/safing/portbase/modules"
"github.com/safing/portbase/utils"
) )
var ( var (
databaseDir string databasePath string
shutdownSignal = make(chan struct{}) databaseStructureRoot *utils.DirStructure
maintenanceWg sync.WaitGroup
module *modules.Module
) )
// SetDatabaseLocation sets the location of the database. Must be called before modules.Start and will be overridden by command line options. Intended for unit tests. func init() {
func SetDatabaseLocation(location string) { module = modules.Register("database", prep, start, stop, "base")
databaseDir = location
} }
func init() { // SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure.
flag.StringVar(&databaseDir, "db", "", "set database directory") func SetDatabaseLocation(dirPath string, dirStructureRoot *utils.DirStructure) {
databasePath = dirPath
modules.Register("database", prep, start, stop) databaseStructureRoot = dirStructureRoot
} }
func prep() error { func prep() error {
if databaseDir == "" { if databasePath == "" && databaseStructureRoot == nil {
return errors.New("no database location specified, set with `-db=/path/to/db`") return errors.New("no database location specified")
}
ok := database.SetLocation(databaseDir)
if !ok {
return errors.New("database location already set")
} }
return nil return nil
} }
func start() error { func start() error {
err := database.Initialize() err := database.Initialize(databasePath, databaseStructureRoot)
if err == nil { if err != nil {
startMaintainer()
}
return err return err
}
startMaintainer()
return nil
} }
func stop() error { func stop() error {
close(shutdownSignal)
maintenanceWg.Wait()
return database.Shutdown() return database.Shutdown()
} }

View file

@ -13,7 +13,7 @@ var (
) )
func startMaintainer() { func startMaintainer() {
maintenanceWg.Add(1) module.AddWorkers(1)
go maintenanceWorker() go maintenanceWorker()
} }
@ -37,8 +37,8 @@ func maintenanceWorker() {
if err != nil { if err != nil {
log.Errorf("database: thorough maintenance error: %s", err) log.Errorf("database: thorough maintenance error: %s", err)
} }
case <-shutdownSignal: case <-module.ShuttingDown():
maintenanceWg.Done() module.FinishWorker()
return return
} }
} }

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package dbutils package dbutils
type Meta struct { type Meta struct {

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
/* /*
Package database provides a universal interface for interacting with the database. Package database provides a universal interface for interacting with the database.

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package database package database
import ( import (

View file

@ -1,33 +1 @@
package database package database
import (
"fmt"
"path/filepath"
"github.com/safing/portbase/utils"
)
const (
databasesSubDir = "databases"
)
var (
rootDir string
)
// GetDatabaseRoot returns the root directory of the database.
func GetDatabaseRoot() string {
return rootDir
}
// getLocation returns the storage location for the given name and type.
func getLocation(name, storageType string) (string, error) {
location := filepath.Join(rootDir, databasesSubDir, name, storageType)
// check location
err := utils.EnsureDirectory(location, 0700)
if err != nil {
return "", fmt.Errorf("location (%s) invalid: %s", location, err)
}
return location, nil
}

View file

@ -9,34 +9,40 @@ import (
"github.com/tevino/abool" "github.com/tevino/abool"
) )
const (
databasesSubDir = "databases"
)
var ( var (
initialized = abool.NewBool(false) initialized = abool.NewBool(false)
shuttingDown = abool.NewBool(false) shuttingDown = abool.NewBool(false)
shutdownSignal = make(chan struct{}) shutdownSignal = make(chan struct{})
rootStructure *utils.DirStructure
databasesStructure *utils.DirStructure
) )
// SetLocation sets the location of the database. This is separate from the initialization to provide the location to other modules earlier. // Initialize initializes the database at the specified location. Supply either a path or dir structure.
func SetLocation(location string) (ok bool) { func Initialize(dirPath string, dirStructureRoot *utils.DirStructure) error {
if !initialized.IsSet() && rootDir == "" {
rootDir = location
return true
}
return false
}
// Initialize initialized the database
func Initialize() error {
if initialized.SetToIf(false, true) { if initialized.SetToIf(false, true) {
err := utils.EnsureDirectory(rootDir, 0755) if dirStructureRoot != nil {
rootStructure = dirStructureRoot
} else {
rootStructure = utils.NewDirStructure(dirPath, 0755)
}
// ensure root and databases dirs
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700)
err := databasesStructure.Ensure()
if err != nil { if err != nil {
return fmt.Errorf("could not create/open database directory (%s): %s", rootDir, err) return fmt.Errorf("could not create/open database directory (%s): %s", rootStructure.Path, err)
} }
err = loadRegistry() err = loadRegistry()
if err != nil { if err != nil {
return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootDir, registryFileName), err) return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootStructure.Path, registryFileName), err)
} }
// start registry writer // start registry writer
@ -66,3 +72,14 @@ func Shutdown() (err error) {
} }
return return
} }
// getLocation returns the storage location for the given name and type.
func getLocation(name, storageType string) (string, error) {
location := databasesStructure.ChildDir(name, 0700).ChildDir(storageType, 0700)
// check location
err := location.Ensure()
if err != nil {
return "", fmt.Errorf(`failed to create/check database dir "%s": %s`, location.Path, err)
}
return location.Path, nil
}

View file

@ -70,7 +70,7 @@ func (b *Base) SetMeta(meta *Meta) {
b.meta = meta b.meta = meta
} }
// Marshal marshals the object, without the database key or metadata // Marshal marshals the object, without the database key or metadata. It returns nil if the record is deleted.
func (b *Base) Marshal(self Record, format uint8) ([]byte, error) { func (b *Base) Marshal(self Record, format uint8) ([]byte, error) {
if b.Meta() == nil { if b.Meta() == nil {
return nil, errors.New("missing meta") return nil, errors.New("missing meta")
@ -96,15 +96,15 @@ func (b *Base) MarshalRecord(self Record) ([]byte, error) {
// version // version
c := container.New([]byte{1}) c := container.New([]byte{1})
// meta // meta encoding
metaSection, err := b.meta.GenCodeMarshal(nil) metaSection, err := dsd.Dump(b.meta, GenCode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.AppendAsBlock(metaSection) c.AppendAsBlock(metaSection)
// data // data
dataSection, err := b.Marshal(self, dsd.JSON) dataSection, err := b.Marshal(self, JSON)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -11,5 +11,5 @@ const (
BYTES = dsd.BYTES // X BYTES = dsd.BYTES // X
JSON = dsd.JSON // J JSON = dsd.JSON // J
BSON = dsd.BSON // B BSON = dsd.BSON // B
GenCode = dsd.GenCode // G (reserved) GenCode = dsd.GenCode // G
) )

View file

@ -432,7 +432,7 @@ func BenchmarkMetaUnserializeWithCodegen(b *testing.B) {
func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) { func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := dsd.Dump(testMeta, dsd.JSON) _, err := dsd.Dump(testMeta, JSON)
if err != nil { if err != nil {
b.Errorf("failed to serialize with DSD/JSON: %s", err) b.Errorf("failed to serialize with DSD/JSON: %s", err)
return return
@ -444,7 +444,7 @@ func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) {
func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) { func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) {
// Setup // Setup
encodedData, err := dsd.Dump(testMeta, dsd.JSON) encodedData, err := dsd.Dump(testMeta, JSON)
if err != nil { if err != nil {
b.Errorf("failed to serialize with DSD/JSON: %s", err) b.Errorf("failed to serialize with DSD/JSON: %s", err)
return return

View file

@ -37,10 +37,21 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) {
offset += n offset += n
newMeta := &Meta{} newMeta := &Meta{}
if len(metaSection) == 34 && metaSection[4] == 0 {
// TODO: remove in 2020
// backward compatibility:
// format would byte shift and populate metaSection[4] with value > 0 (would naturally populate >0 at 07.02.2106 07:28:15)
// this must be gencode without format
_, err = newMeta.GenCodeUnmarshal(metaSection) _, err = newMeta.GenCodeUnmarshal(metaSection)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %s", err) return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
} }
} else {
_, err = dsd.Load(metaSection, newMeta)
if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
}
}
format, n, err := varint.Unpack8(data[offset:]) format, n, err := varint.Unpack8(data[offset:])
if err != nil { if err != nil {
@ -86,7 +97,7 @@ func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) {
return nil, nil return nil, nil
} }
if format != dsd.AUTO && format != w.Format { if format != AUTO && format != w.Format {
return nil, errors.New("could not dump model, wrapped object format mismatch") return nil, errors.New("could not dump model, wrapped object format mismatch")
} }
@ -109,14 +120,14 @@ func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) {
c := container.New([]byte{1}) c := container.New([]byte{1})
// meta // meta
metaSection, err := w.meta.GenCodeMarshal(nil) metaSection, err := dsd.Dump(w.meta, GenCode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.AppendAsBlock(metaSection) c.AppendAsBlock(metaSection)
// data // data
dataSection, err := w.Marshal(r, dsd.JSON) dataSection, err := w.Marshal(r, JSON)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,16 +136,6 @@ func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) {
return c.CompileData(), nil return c.CompileData(), nil
} }
// // Lock locks the record.
// func (w *Wrapper) Lock() {
// w.lock.Lock()
// }
//
// // Unlock unlocks the record.
// func (w *Wrapper) Unlock() {
// w.lock.Unlock()
// }
// IsWrapped returns whether the record is a Wrapper. // IsWrapped returns whether the record is a Wrapper.
func (w *Wrapper) IsWrapped() bool { func (w *Wrapper) IsWrapped() bool {
return true return true

View file

@ -2,9 +2,10 @@ package record
import ( import (
"bytes" "bytes"
"errors"
"testing" "testing"
"github.com/safing/portbase/formats/dsd" "github.com/safing/portbase/container"
) )
func TestWrapper(t *testing.T) { func TestWrapper(t *testing.T) {
@ -24,14 +25,14 @@ func TestWrapper(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if wrapper.Format != dsd.JSON { if wrapper.Format != JSON {
t.Error("format mismatch") t.Error("format mismatch")
} }
if !bytes.Equal(testData, wrapper.Data) { if !bytes.Equal(testData, wrapper.Data) {
t.Error("data mismatch") t.Error("data mismatch")
} }
encoded, err := wrapper.Marshal(wrapper, dsd.JSON) encoded, err := wrapper.Marshal(wrapper, JSON)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -40,6 +41,7 @@ func TestWrapper(t *testing.T) {
} }
wrapper.SetMeta(&Meta{}) wrapper.SetMeta(&Meta{})
wrapper.meta.Update()
raw, err := wrapper.MarshalRecord(wrapper) raw, err := wrapper.MarshalRecord(wrapper)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -53,4 +55,42 @@ func TestWrapper(t *testing.T) {
t.Error("marshal mismatch") t.Error("marshal mismatch")
} }
// test new format
oldRaw, err := oldWrapperMarshalRecord(wrapper, wrapper)
if err != nil {
t.Fatal(err)
}
wrapper3, err := NewRawWrapper("test", "a", oldRaw)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(testData, wrapper3.Data) {
t.Error("marshal mismatch")
}
}
func oldWrapperMarshalRecord(w *Wrapper, r Record) ([]byte, error) {
if w.Meta() == nil {
return nil, errors.New("missing meta")
}
// version
c := container.New([]byte{1})
// meta
metaSection, err := w.meta.GenCodeMarshal(nil)
if err != nil {
return nil, err
}
c.AppendAsBlock(metaSection)
// data
dataSection, err := w.Marshal(r, JSON)
if err != nil {
return nil, err
}
c.Append(dataSection)
return c.CompileData(), nil
} }

View file

@ -104,7 +104,7 @@ func loadRegistry() error {
defer registryLock.Unlock() defer registryLock.Unlock()
// read file // read file
filePath := path.Join(rootDir, registryFileName) filePath := path.Join(rootStructure.Path, registryFileName)
data, err := ioutil.ReadFile(filePath) data, err := ioutil.ReadFile(filePath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -139,7 +139,7 @@ func saveRegistry(lock bool) error {
} }
// write file // write file
filePath := path.Join(rootDir, registryFileName) filePath := path.Join(rootStructure.Path, registryFileName)
return ioutil.WriteFile(filePath, data, 0600) return ioutil.WriteFile(filePath, data, 0600)
} }

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package dsd package dsd
// dynamic structured data // dynamic structured data
@ -22,7 +20,7 @@ const (
BYTES = 88 // X BYTES = 88 // X
JSON = 74 // J JSON = 74 // J
BSON = 66 // B BSON = 66 // B
GenCode = 71 // G (reserved) GenCode = 71 // G
) )
// define errors // define errors
@ -30,6 +28,7 @@ var errNoMoreSpace = errors.New("dsd: no more space left after reading dsd type"
var errUnknownType = errors.New("dsd: tried to unpack unknown type") var errUnknownType = errors.New("dsd: tried to unpack unknown type")
var errNotImplemented = errors.New("dsd: this type is not yet implemented") var errNotImplemented = errors.New("dsd: this type is not yet implemented")
// Load loads an dsd structured data blob into the given interface.
func Load(data []byte, t interface{}) (interface{}, error) { func Load(data []byte, t interface{}) (interface{}, error) {
if len(data) < 2 { if len(data) < 2 {
return nil, errNoMoreSpace return nil, errNoMoreSpace
@ -46,6 +45,7 @@ func Load(data []byte, t interface{}) (interface{}, error) {
return LoadAsFormat(data[read:], format, t) return LoadAsFormat(data[read:], format, t)
} }
// LoadAsFormat loads a data blob into the interface using the specified format.
func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error) { func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error) {
switch format { switch format {
case STRING: case STRING:
@ -55,7 +55,7 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
case JSON: case JSON:
err := json.Unmarshal(data, t) err := json.Unmarshal(data, t)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("dsd: failed to unpack json data: %s", data)
} }
return t, nil return t, nil
// case BSON: // case BSON:
@ -64,19 +64,23 @@ func LoadAsFormat(data []byte, format uint8, t interface{}) (interface{}, error)
// return nil, err // return nil, err
// } // }
// return t, nil // return t, nil
// case MSGP: case GenCode:
// err := t.UnmarshalMsg(data[read:]) genCodeStruct, ok := t.(GenCodeCompatible)
// if err != nil { if !ok {
// return nil, err return nil, errors.New("dsd: gencode is not supported by the given data structure")
// } }
// return t, nil _, err := genCodeStruct.GenCodeUnmarshal(data)
if err != nil {
return nil, fmt.Errorf("dsd: failed to unpack gencode data: %s", err)
}
return t, nil
default: default:
return nil, errors.New(fmt.Sprintf("dsd: tried to load unknown type %d, data: %v", format, data)) return nil, fmt.Errorf("dsd: tried to load unknown type %d, data: %v", format, data)
} }
} }
// Dump stores the interface as a dsd formatted data structure.
func Dump(t interface{}, format uint8) ([]byte, error) { func Dump(t interface{}, format uint8) ([]byte, error) {
if format == AUTO { if format == AUTO {
switch t.(type) { switch t.(type) {
case string: case string:
@ -107,18 +111,19 @@ func Dump(t interface{}, format uint8) ([]byte, error) {
// if err != nil { // if err != nil {
// return nil, err // return nil, err
// } // }
// case MSGP: case GenCode:
// data, err := t.MarshalMsg(nil) genCodeStruct, ok := t.(GenCodeCompatible)
// if err != nil { if !ok {
// return nil, err return nil, errors.New("dsd: gencode is not supported by the given data structure")
// } }
data, err = genCodeStruct.GenCodeMarshal(nil)
if err != nil {
return nil, fmt.Errorf("dsd: failed to pack gencode struct: %s", err)
}
default: default:
return nil, errors.New(fmt.Sprintf("dsd: tried to dump unknown type %d", format)) return nil, fmt.Errorf("dsd: tried to dump unknown type %d", format)
} }
r := append(f, data...) r := append(f, data...)
// log.Tracef("packing %v to %s", t, string(r))
// return nil, errors.New(fmt.Sprintf("dsd: dumped bytes are: %v", r))
return r, nil return r, nil
} }

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package dsd package dsd
import ( import (
@ -10,6 +8,7 @@ import (
//go:generate msgp //go:generate msgp
// SimpleTestStruct is used for testing.
type SimpleTestStruct struct { type SimpleTestStruct struct {
S string S string
B byte B byte
@ -21,11 +20,11 @@ type ComplexTestStruct struct {
I16 int16 I16 int16
I32 int32 I32 int32
I64 int64 I64 int64
Ui uint UI uint
Ui8 uint8 UI8 uint8
Ui16 uint16 UI16 uint16
Ui32 uint32 UI32 uint32
Ui64 uint64 UI64 uint64
S string S string
Sp *string Sp *string
Sa []string Sa []string
@ -38,6 +37,25 @@ type ComplexTestStruct struct {
Mp *map[string]string Mp *map[string]string
} }
type GenCodeTestStruct struct {
I8 int8
I16 int16
I32 int32
I64 int64
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
S string
Sp *string
Sa []string
Sap *[]string
B byte
Bp *byte
Ba []byte
Bap *[]byte
}
func TestConversion(t *testing.T) { func TestConversion(t *testing.T) {
// STRING // STRING
@ -113,7 +131,26 @@ func TestConversion(t *testing.T) {
}, },
} }
// TODO: test all formats genCodeSubject := GenCodeTestStruct{
-2,
-3,
-4,
-5,
2,
3,
4,
5,
"a",
&bString,
[]string{"c", "d", "e"},
&[]string{"f", "g", "h"},
0x01,
&bBytes,
[]byte{0x03, 0x04, 0x05},
&[]byte{0x05, 0x06, 0x07},
}
// test all formats (complex)
formats := []uint8{JSON} formats := []uint8{JSON}
for _, format := range formats { for _, format := range formats {
@ -163,20 +200,20 @@ func TestConversion(t *testing.T) {
if complexSubject.I64 != co.I64 { if complexSubject.I64 != co.I64 {
t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", complexSubject.I64, co.I64) t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", complexSubject.I64, co.I64)
} }
if complexSubject.Ui != co.Ui { if complexSubject.UI != co.UI {
t.Errorf("Load (complex struct): struct.Ui is not equal (%v != %v)", complexSubject.Ui, co.Ui) t.Errorf("Load (complex struct): struct.UI is not equal (%v != %v)", complexSubject.UI, co.UI)
} }
if complexSubject.Ui8 != co.Ui8 { if complexSubject.UI8 != co.UI8 {
t.Errorf("Load (complex struct): struct.Ui8 is not equal (%v != %v)", complexSubject.Ui8, co.Ui8) t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", complexSubject.UI8, co.UI8)
} }
if complexSubject.Ui16 != co.Ui16 { if complexSubject.UI16 != co.UI16 {
t.Errorf("Load (complex struct): struct.Ui16 is not equal (%v != %v)", complexSubject.Ui16, co.Ui16) t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", complexSubject.UI16, co.UI16)
} }
if complexSubject.Ui32 != co.Ui32 { if complexSubject.UI32 != co.UI32 {
t.Errorf("Load (complex struct): struct.Ui32 is not equal (%v != %v)", complexSubject.Ui32, co.Ui32) t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", complexSubject.UI32, co.UI32)
} }
if complexSubject.Ui64 != co.Ui64 { if complexSubject.UI64 != co.UI64 {
t.Errorf("Load (complex struct): struct.Ui64 is not equal (%v != %v)", complexSubject.Ui64, co.Ui64) t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", complexSubject.UI64, co.UI64)
} }
if complexSubject.S != co.S { if complexSubject.S != co.S {
t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", complexSubject.S, co.S) t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", complexSubject.S, co.S)
@ -211,4 +248,87 @@ func TestConversion(t *testing.T) {
} }
// test all formats
formats = []uint8{JSON, GenCode}
for _, format := range formats {
// simple
b, err := Dump(&simpleSubject, format)
if err != nil {
t.Fatalf("Dump error (simple struct): %s", err)
}
o, err := Load(b, &SimpleTestStruct{})
if err != nil {
t.Fatalf("Load error (simple struct): %s", err)
}
if !reflect.DeepEqual(&simpleSubject, o) {
t.Errorf("Load (simple struct): subject does not match loaded object")
t.Errorf("Encoded: %v", string(b))
t.Errorf("Compared: %v == %v", &simpleSubject, o)
}
// complex
b, err = Dump(&genCodeSubject, format)
if err != nil {
t.Fatalf("Dump error (complex struct): %s", err)
}
o, err = Load(b, &GenCodeTestStruct{})
if err != nil {
t.Fatalf("Load error (complex struct): %s", err)
}
co := o.(*GenCodeTestStruct)
if genCodeSubject.I8 != co.I8 {
t.Errorf("Load (complex struct): struct.I8 is not equal (%v != %v)", genCodeSubject.I8, co.I8)
}
if genCodeSubject.I16 != co.I16 {
t.Errorf("Load (complex struct): struct.I16 is not equal (%v != %v)", genCodeSubject.I16, co.I16)
}
if genCodeSubject.I32 != co.I32 {
t.Errorf("Load (complex struct): struct.I32 is not equal (%v != %v)", genCodeSubject.I32, co.I32)
}
if genCodeSubject.I64 != co.I64 {
t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", genCodeSubject.I64, co.I64)
}
if genCodeSubject.UI8 != co.UI8 {
t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", genCodeSubject.UI8, co.UI8)
}
if genCodeSubject.UI16 != co.UI16 {
t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", genCodeSubject.UI16, co.UI16)
}
if genCodeSubject.UI32 != co.UI32 {
t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", genCodeSubject.UI32, co.UI32)
}
if genCodeSubject.UI64 != co.UI64 {
t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", genCodeSubject.UI64, co.UI64)
}
if genCodeSubject.S != co.S {
t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", genCodeSubject.S, co.S)
}
if !reflect.DeepEqual(genCodeSubject.Sp, co.Sp) {
t.Errorf("Load (complex struct): struct.Sp is not equal (%v != %v)", genCodeSubject.Sp, co.Sp)
}
if !reflect.DeepEqual(genCodeSubject.Sa, co.Sa) {
t.Errorf("Load (complex struct): struct.Sa is not equal (%v != %v)", genCodeSubject.Sa, co.Sa)
}
if !reflect.DeepEqual(genCodeSubject.Sap, co.Sap) {
t.Errorf("Load (complex struct): struct.Sap is not equal (%v != %v)", genCodeSubject.Sap, co.Sap)
}
if genCodeSubject.B != co.B {
t.Errorf("Load (complex struct): struct.B is not equal (%v != %v)", genCodeSubject.B, co.B)
}
if !reflect.DeepEqual(genCodeSubject.Bp, co.Bp) {
t.Errorf("Load (complex struct): struct.Bp is not equal (%v != %v)", genCodeSubject.Bp, co.Bp)
}
if !reflect.DeepEqual(genCodeSubject.Ba, co.Ba) {
t.Errorf("Load (complex struct): struct.Ba is not equal (%v != %v)", genCodeSubject.Ba, co.Ba)
}
if !reflect.DeepEqual(genCodeSubject.Bap, co.Bap) {
t.Errorf("Load (complex struct): struct.Bap is not equal (%v != %v)", genCodeSubject.Bap, co.Bap)
}
}
} }

835
formats/dsd/gencode_test.go Normal file
View file

@ -0,0 +1,835 @@
package dsd
import (
"io"
"time"
"unsafe"
)
var (
_ = unsafe.Sizeof(0)
_ = io.ReadFull
_ = time.Now()
)
func (d *SimpleTestStruct) Size() (s uint64) {
{
l := uint64(len(d.S))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
s++
return
}
func (d *SimpleTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) {
size := d.Size()
{
if uint64(cap(buf)) >= size {
buf = buf[:size]
} else {
buf = make([]byte, size)
}
}
i := uint64(0)
{
l := uint64(len(d.S))
{
t := uint64(l)
for t >= 0x80 {
buf[i+0] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+0] = byte(t)
i++
}
copy(buf[i+0:], d.S)
i += l
}
{
buf[i+0] = d.B
}
return buf[:i+1], nil
}
func (d *SimpleTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) {
i := uint64(0)
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+0] & 0x7F)
for buf[i+0]&0x80 == 0x80 {
i++
t |= uint64(buf[i+0]&0x7F) << bs
bs += 7
}
i++
l = t
}
d.S = string(buf[i+0 : i+0+l])
i += l
}
{
d.B = buf[i+0]
}
return i + 1, nil
}
func (d *GenCodeTestStruct) Size() (s uint64) {
{
l := uint64(len(d.S))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
{
if d.Sp != nil {
{
l := uint64(len((*d.Sp)))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
s += 0
}
}
{
l := uint64(len(d.Sa))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
for k0 := range d.Sa {
{
l := uint64(len(d.Sa[k0]))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
}
}
{
if d.Sap != nil {
{
l := uint64(len((*d.Sap)))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
for k0 := range *d.Sap {
{
l := uint64(len((*d.Sap)[k0]))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
}
}
s += 0
}
}
{
if d.Bp != nil {
s++
}
}
{
l := uint64(len(d.Ba))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
{
if d.Bap != nil {
{
l := uint64(len((*d.Bap)))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
s += 0
}
}
s += 35
return
}
func (d *GenCodeTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) {
size := d.Size()
{
if uint64(cap(buf)) >= size {
buf = buf[:size]
} else {
buf = make([]byte, size)
}
}
i := uint64(0)
{
buf[0+0] = byte(d.I8 >> 0)
}
{
buf[0+1] = byte(d.I16 >> 0)
buf[1+1] = byte(d.I16 >> 8)
}
{
buf[0+3] = byte(d.I32 >> 0)
buf[1+3] = byte(d.I32 >> 8)
buf[2+3] = byte(d.I32 >> 16)
buf[3+3] = byte(d.I32 >> 24)
}
{
buf[0+7] = byte(d.I64 >> 0)
buf[1+7] = byte(d.I64 >> 8)
buf[2+7] = byte(d.I64 >> 16)
buf[3+7] = byte(d.I64 >> 24)
buf[4+7] = byte(d.I64 >> 32)
buf[5+7] = byte(d.I64 >> 40)
buf[6+7] = byte(d.I64 >> 48)
buf[7+7] = byte(d.I64 >> 56)
}
{
buf[0+15] = byte(d.UI8 >> 0)
}
{
buf[0+16] = byte(d.UI16 >> 0)
buf[1+16] = byte(d.UI16 >> 8)
}
{
buf[0+18] = byte(d.UI32 >> 0)
buf[1+18] = byte(d.UI32 >> 8)
buf[2+18] = byte(d.UI32 >> 16)
buf[3+18] = byte(d.UI32 >> 24)
}
{
buf[0+22] = byte(d.UI64 >> 0)
buf[1+22] = byte(d.UI64 >> 8)
buf[2+22] = byte(d.UI64 >> 16)
buf[3+22] = byte(d.UI64 >> 24)
buf[4+22] = byte(d.UI64 >> 32)
buf[5+22] = byte(d.UI64 >> 40)
buf[6+22] = byte(d.UI64 >> 48)
buf[7+22] = byte(d.UI64 >> 56)
}
{
l := uint64(len(d.S))
{
t := uint64(l)
for t >= 0x80 {
buf[i+30] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+30] = byte(t)
i++
}
copy(buf[i+30:], d.S)
i += l
}
{
if d.Sp == nil {
buf[i+30] = 0
} else {
buf[i+30] = 1
{
l := uint64(len((*d.Sp)))
{
t := uint64(l)
for t >= 0x80 {
buf[i+31] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+31] = byte(t)
i++
}
copy(buf[i+31:], (*d.Sp))
i += l
}
i += 0
}
}
{
l := uint64(len(d.Sa))
{
t := uint64(l)
for t >= 0x80 {
buf[i+31] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+31] = byte(t)
i++
}
for k0 := range d.Sa {
{
l := uint64(len(d.Sa[k0]))
{
t := uint64(l)
for t >= 0x80 {
buf[i+31] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+31] = byte(t)
i++
}
copy(buf[i+31:], d.Sa[k0])
i += l
}
}
}
{
if d.Sap == nil {
buf[i+31] = 0
} else {
buf[i+31] = 1
{
l := uint64(len((*d.Sap)))
{
t := uint64(l)
for t >= 0x80 {
buf[i+32] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+32] = byte(t)
i++
}
for k0 := range *d.Sap {
{
l := uint64(len((*d.Sap)[k0]))
{
t := uint64(l)
for t >= 0x80 {
buf[i+32] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+32] = byte(t)
i++
}
copy(buf[i+32:], (*d.Sap)[k0])
i += l
}
}
}
i += 0
}
}
{
buf[i+32] = d.B
}
{
if d.Bp == nil {
buf[i+33] = 0
} else {
buf[i+33] = 1
{
buf[i+34] = (*d.Bp)
}
i++
}
}
{
l := uint64(len(d.Ba))
{
t := uint64(l)
for t >= 0x80 {
buf[i+34] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+34] = byte(t)
i++
}
copy(buf[i+34:], d.Ba)
i += l
}
{
if d.Bap == nil {
buf[i+34] = 0
} else {
buf[i+34] = 1
{
l := uint64(len((*d.Bap)))
{
t := uint64(l)
for t >= 0x80 {
buf[i+35] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+35] = byte(t)
i++
}
copy(buf[i+35:], (*d.Bap))
i += l
}
i += 0
}
}
return buf[:i+35], nil
}
func (d *GenCodeTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) {
i := uint64(0)
{
d.I8 = 0 | (int8(buf[i+0+0]) << 0)
}
{
d.I16 = 0 | (int16(buf[i+0+1]) << 0) | (int16(buf[i+1+1]) << 8)
}
{
d.I32 = 0 | (int32(buf[i+0+3]) << 0) | (int32(buf[i+1+3]) << 8) | (int32(buf[i+2+3]) << 16) | (int32(buf[i+3+3]) << 24)
}
{
d.I64 = 0 | (int64(buf[i+0+7]) << 0) | (int64(buf[i+1+7]) << 8) | (int64(buf[i+2+7]) << 16) | (int64(buf[i+3+7]) << 24) | (int64(buf[i+4+7]) << 32) | (int64(buf[i+5+7]) << 40) | (int64(buf[i+6+7]) << 48) | (int64(buf[i+7+7]) << 56)
}
{
d.UI8 = 0 | (uint8(buf[i+0+15]) << 0)
}
{
d.UI16 = 0 | (uint16(buf[i+0+16]) << 0) | (uint16(buf[i+1+16]) << 8)
}
{
d.UI32 = 0 | (uint32(buf[i+0+18]) << 0) | (uint32(buf[i+1+18]) << 8) | (uint32(buf[i+2+18]) << 16) | (uint32(buf[i+3+18]) << 24)
}
{
d.UI64 = 0 | (uint64(buf[i+0+22]) << 0) | (uint64(buf[i+1+22]) << 8) | (uint64(buf[i+2+22]) << 16) | (uint64(buf[i+3+22]) << 24) | (uint64(buf[i+4+22]) << 32) | (uint64(buf[i+5+22]) << 40) | (uint64(buf[i+6+22]) << 48) | (uint64(buf[i+7+22]) << 56)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+30] & 0x7F)
for buf[i+30]&0x80 == 0x80 {
i++
t |= uint64(buf[i+30]&0x7F) << bs
bs += 7
}
i++
l = t
}
d.S = string(buf[i+30 : i+30+l])
i += l
}
{
if buf[i+30] == 1 {
if d.Sp == nil {
d.Sp = new(string)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+31] & 0x7F)
for buf[i+31]&0x80 == 0x80 {
i++
t |= uint64(buf[i+31]&0x7F) << bs
bs += 7
}
i++
l = t
}
(*d.Sp) = string(buf[i+31 : i+31+l])
i += l
}
i += 0
} else {
d.Sp = nil
}
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+31] & 0x7F)
for buf[i+31]&0x80 == 0x80 {
i++
t |= uint64(buf[i+31]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap(d.Sa)) >= l {
d.Sa = d.Sa[:l]
} else {
d.Sa = make([]string, l)
}
for k0 := range d.Sa {
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+31] & 0x7F)
for buf[i+31]&0x80 == 0x80 {
i++
t |= uint64(buf[i+31]&0x7F) << bs
bs += 7
}
i++
l = t
}
d.Sa[k0] = string(buf[i+31 : i+31+l])
i += l
}
}
}
{
if buf[i+31] == 1 {
if d.Sap == nil {
d.Sap = new([]string)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+32] & 0x7F)
for buf[i+32]&0x80 == 0x80 {
i++
t |= uint64(buf[i+32]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap((*d.Sap))) >= l {
(*d.Sap) = (*d.Sap)[:l]
} else {
(*d.Sap) = make([]string, l)
}
for k0 := range *d.Sap {
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+32] & 0x7F)
for buf[i+32]&0x80 == 0x80 {
i++
t |= uint64(buf[i+32]&0x7F) << bs
bs += 7
}
i++
l = t
}
(*d.Sap)[k0] = string(buf[i+32 : i+32+l])
i += l
}
}
}
i += 0
} else {
d.Sap = nil
}
}
{
d.B = buf[i+32]
}
{
if buf[i+33] == 1 {
if d.Bp == nil {
d.Bp = new(byte)
}
{
(*d.Bp) = buf[i+34]
}
i++
} else {
d.Bp = nil
}
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+34] & 0x7F)
for buf[i+34]&0x80 == 0x80 {
i++
t |= uint64(buf[i+34]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap(d.Ba)) >= l {
d.Ba = d.Ba[:l]
} else {
d.Ba = make([]byte, l)
}
copy(d.Ba, buf[i+34:])
i += l
}
{
if buf[i+34] == 1 {
if d.Bap == nil {
d.Bap = new([]byte)
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+35] & 0x7F)
for buf[i+35]&0x80 == 0x80 {
i++
t |= uint64(buf[i+35]&0x7F) << bs
bs += 7
}
i++
l = t
}
if uint64(cap((*d.Bap))) >= l {
(*d.Bap) = (*d.Bap)[:l]
} else {
(*d.Bap) = make([]byte, l)
}
copy((*d.Bap), buf[i+35:])
i += l
}
i += 0
} else {
d.Bap = nil
}
}
return i + 35, nil
}

View file

@ -0,0 +1,9 @@
package dsd
// GenCodeCompatible is an interface to identify and use gencode compatible structs.
type GenCodeCompatible interface {
// GenCodeMarshal gencode marshalls the struct into the given byte array, or a new one if its too small.
GenCodeMarshal(buf []byte) ([]byte, error)
// GenCodeUnmarshal gencode unmarshalls the struct and returns the bytes read.
GenCodeUnmarshal(buf []byte) (uint64, error)
}

23
formats/dsd/tests.gencode Normal file
View file

@ -0,0 +1,23 @@
struct SimpleTestStruct {
S string
B byte
}
struct GenCodeTestStructure {
I8 int8
I16 int16
I32 int32
I64 int64
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
S string
Sp *string
Sa []string
Sap *[]string
B byte
Bp *byte
Ba []byte
Bap *[]byte
}

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package varint package varint
import "errors" import "errors"

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package varint package varint
import ( import (

View file

@ -4,10 +4,10 @@ import "flag"
var ( var (
logLevelFlag string logLevelFlag string
fileLogLevelsFlag string pkgLogLevelsFlag string
) )
func init() { func init() {
flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]") flag.StringVar(&logLevelFlag, "log", "info", "set log level to [trace|debug|info|warning|error|critical]")
flag.StringVar(&fileLogLevelsFlag, "flog", "", "set log level of files: database=trace,firewall=debug") flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,firewall=debug")
} }

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log package log
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log package log
import ( import (
@ -11,7 +9,7 @@ import (
) )
func fastcheck(level severity) bool { func fastcheck(level severity) bool {
if fileLevelsActive.IsSet() { if pkgLevelsActive.IsSet() {
return true return true
} }
if uint32(level) < atomic.LoadUint32(logLevel) { if uint32(level) < atomic.LoadUint32(logLevel) {
@ -33,7 +31,7 @@ func log(level severity, msg string, trace *ContextTracer) {
} }
// check if level is enabled // check if level is enabled
if !fileLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) { if !pkgLevelsActive.IsSet() && uint32(level) < atomic.LoadUint32(logLevel) {
return return
} }
@ -54,12 +52,12 @@ func log(level severity, msg string, trace *ContextTracer) {
} }
// check if level is enabled for file or generally // check if level is enabled for file or generally
if fileLevelsActive.IsSet() { if pkgLevelsActive.IsSet() {
fileOnly := strings.Split(file, "/") fileOnly := strings.Split(file, "/")
if len(fileOnly) < 2 { if len(fileOnly) < 2 {
return return
} }
sev, ok := fileLevels[fileOnly[len(fileOnly)-2]] sev, ok := pkgLevels[fileOnly[len(fileOnly)-2]]
if ok { if ok {
if level < sev { if level < sev {
return return

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log package log
import ( import (
@ -75,9 +73,9 @@ var (
logLevelInt = uint32(3) logLevelInt = uint32(3)
logLevel = &logLevelInt logLevel = &logLevelInt
fileLevelsActive = abool.NewBool(false) pkgLevelsActive = abool.NewBool(false)
fileLevels = make(map[string]severity) pkgLevels = make(map[string]severity)
fileLevelsLock sync.Mutex pkgLevelsLock sync.Mutex
logsWaiting = make(chan bool, 1) logsWaiting = make(chan bool, 1)
logsWaitingFlag = abool.NewBool(false) logsWaitingFlag = abool.NewBool(false)
@ -92,15 +90,15 @@ var (
testErrors = abool.NewBool(false) testErrors = abool.NewBool(false)
) )
func SetFileLevels(levels map[string]severity) { func SetPkgLevels(levels map[string]severity) {
fileLevelsLock.Lock() pkgLevelsLock.Lock()
fileLevels = levels pkgLevels = levels
fileLevelsLock.Unlock() pkgLevelsLock.Unlock()
fileLevelsActive.Set() pkgLevelsActive.Set()
} }
func UnSetFileLevels() { func UnSetPkgLevels() {
fileLevelsActive.UnSet() pkgLevelsActive.UnSet()
} }
func SetLogLevel(level severity) { func SetLogLevel(level severity) {
@ -143,10 +141,10 @@ func Start() (err error) {
} }
// get and set file loglevels // get and set file loglevels
fileLogLevels := fileLogLevelsFlag pkgLogLevels := pkgLogLevelsFlag
if len(fileLogLevels) > 0 { if len(pkgLogLevels) > 0 {
newFileLevels := make(map[string]severity) newPkgLevels := make(map[string]severity)
for _, pair := range strings.Split(fileLogLevels, ",") { for _, pair := range strings.Split(pkgLogLevels, ",") {
splitted := strings.Split(pair, "=") splitted := strings.Split(pair, "=")
if len(splitted) != 2 { if len(splitted) != 2 {
err = fmt.Errorf("log warning: invalid file log level \"%s\", ignoring", pair) err = fmt.Errorf("log warning: invalid file log level \"%s\", ignoring", pair)
@ -159,9 +157,9 @@ func Start() (err error) {
fmt.Fprintf(os.Stderr, "%s\n", err.Error()) fmt.Fprintf(os.Stderr, "%s\n", err.Error())
break break
} }
newFileLevels[splitted[0]] = fileLevel newPkgLevels[splitted[0]] = fileLevel
} }
SetFileLevels(newFileLevels) SetPkgLevels(newPkgLevels)
} }
startWriter() startWriter()

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log package log
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package log package log
import ( import (

View file

@ -3,11 +3,12 @@ package modules
import "flag" import "flag"
var ( var (
helpFlag bool // HelpFlag triggers printing flag.Usage. It's exported for custom help handling.
HelpFlag bool
) )
func init() { func init() {
flag.BoolVar(&helpFlag, "help", false, "print help") flag.BoolVar(&HelpFlag, "help", false, "print help")
} }
func parseFlags() error { func parseFlags() error {
@ -15,7 +16,7 @@ func parseFlags() error {
// parse flags // parse flags
flag.Parse() flag.Parse()
if helpFlag { if HelpFlag {
flag.Usage() flag.Usage()
return ErrCleanExit return ErrCleanExit
} }

View file

@ -1,12 +1,14 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package modules package modules
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"time"
"github.com/safing/portbase/log"
"github.com/tevino/abool" "github.com/tevino/abool"
) )
@ -21,32 +23,103 @@ var (
// Module represents a module. // Module represents a module.
type Module struct { type Module struct {
Name string Name string
// lifecycle mgmt
Prepped *abool.AtomicBool Prepped *abool.AtomicBool
Started *abool.AtomicBool Started *abool.AtomicBool
Stopped *abool.AtomicBool Stopped *abool.AtomicBool
inTransition *abool.AtomicBool inTransition *abool.AtomicBool
// lifecycle callback functions
prep func() error prep func() error
start func() error start func() error
stop func() error stop func() error
// shutdown mgmt
Ctx context.Context
cancelCtx func()
shutdownFlag *abool.AtomicBool
workerGroup sync.WaitGroup
workerCnt *int32
// dependency mgmt
depNames []string depNames []string
depModules []*Module depModules []*Module
depReverse []*Module depReverse []*Module
} }
// AddWorkers adds workers to the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
func (m *Module) AddWorkers(n uint) {
if !m.ShutdownInProgress() {
if atomic.AddInt32(m.workerCnt, int32(n)) > 0 {
// only add to workgroup if cnt is positive (try to compensate wrong usage)
m.workerGroup.Add(int(n))
}
}
}
// FinishWorker removes a worker from the worker waitgroup. This is a failsafe wrapper for sync.Waitgroup.
func (m *Module) FinishWorker() {
// check worker cnt
if atomic.AddInt32(m.workerCnt, -1) < 0 {
log.Warningf("modules: %s module tried to finish more workers than added, this may lead to undefined behavior when shutting down", m.Name)
return
}
// also mark worker done in workgroup
m.workerGroup.Done()
}
// ShutdownInProgress returns whether the module has started shutting down. In most cases, you should use ShuttingDown instead.
func (m *Module) ShutdownInProgress() bool {
return m.shutdownFlag.IsSet()
}
// ShuttingDown lets you listen for the shutdown signal.
func (m *Module) ShuttingDown() <-chan struct{} {
return m.Ctx.Done()
}
func (m *Module) shutdown() error {
// signal shutdown
m.shutdownFlag.Set()
m.cancelCtx()
// wait for workers
done := make(chan struct{})
go func() {
m.workerGroup.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
return errors.New("timed out while waiting for module workers to finish")
}
// call shutdown function
return m.stop()
}
func dummyAction() error { func dummyAction() error {
return nil return nil
} }
// Register registers a new module. // Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
func Register(name string, prep, start, stop func() error, dependencies ...string) *Module { func Register(name string, prep, start, stop func() error, dependencies ...string) *Module {
ctx, cancelCtx := context.WithCancel(context.Background())
var workerCnt int32
newModule := &Module{ newModule := &Module{
Name: name, Name: name,
Prepped: abool.NewBool(false), Prepped: abool.NewBool(false),
Started: abool.NewBool(false), Started: abool.NewBool(false),
Stopped: abool.NewBool(false), Stopped: abool.NewBool(false),
inTransition: abool.NewBool(false), inTransition: abool.NewBool(false),
Ctx: ctx,
cancelCtx: cancelCtx,
shutdownFlag: abool.NewBool(false),
workerGroup: sync.WaitGroup{},
workerCnt: &workerCnt,
prep: prep, prep: prep,
start: start, start: start,
stop: stop, stop: stop,
@ -77,7 +150,7 @@ func initDependencies() error {
// get dependency // get dependency
depModule, ok := modules[depName] depModule, ok := modules[depName]
if !ok { if !ok {
return fmt.Errorf("modules: module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName) return fmt.Errorf("module %s declares dependency \"%s\", but this module has not been registered", m.Name, depName)
} }
// link together // link together

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package modules package modules
import ( import (
@ -204,11 +202,11 @@ func TestErrors(t *testing.T) {
startCompleteSignal = make(chan struct{}) startCompleteSignal = make(chan struct{})
// test help flag // test help flag
helpFlag = true HelpFlag = true
err = Start() err = Start()
if err == nil { if err == nil {
t.Error("should fail") t.Error("should fail")
} }
helpFlag = false HelpFlag = false
} }

View file

@ -38,6 +38,7 @@ func Start() error {
// inter-link modules // inter-link modules
err := initDependencies() err := initDependencies()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "CRITICAL ERROR: failed to initialize modules: %s\n", err)
return err return err
} }

View file

@ -74,7 +74,7 @@ func stopModules() error {
go func() { go func() {
reports <- &report{ reports <- &report{
module: execM, module: execM,
err: execM.stop(), err: execM.shutdown(),
} }
}() }()
} }

View file

@ -2,6 +2,8 @@ package notifications
import ( import (
"time" "time"
"github.com/safing/portbase/log"
) )
func cleaner() { func cleaner() {
@ -10,31 +12,55 @@ func cleaner() {
case <-shutdownSignal: case <-shutdownSignal:
shutdownWg.Done() shutdownWg.Done()
return return
case <-time.After(1 * time.Minute): case <-time.After(5 * time.Second):
cleanNotifications() cleanNotifications()
} }
} }
func cleanNotifications() { func cleanNotifications() {
threshold := time.Now().Add(-2 * time.Minute).Unix() now := time.Now().Unix()
maxThreshold := time.Now().Add(-72 * time.Hour).Unix() finishedThreshhold := time.Now().Add(-10 * time.Second).Unix()
executionTimelimit := time.Now().Add(-24 * time.Hour).Unix()
fallbackTimelimit := time.Now().Add(-72 * time.Hour).Unix()
notsLock.Lock() notsLock.Lock()
defer notsLock.Unlock() defer notsLock.Unlock()
for _, n := range nots { for _, n := range nots {
n.Lock() n.Lock()
if n.Expires != 0 && n.Expires < threshold || switch {
n.Executed != 0 && n.Executed < threshold || case n.Executed != 0: // notification was fully handled
n.Created < maxThreshold { // wait for a short time before deleting
if n.Executed < finishedThreshhold {
go deleteNotification(n)
}
case n.Responded != 0:
// waiting for execution
if n.Responded < executionTimelimit {
go deleteNotification(n)
}
case n.Expires != 0:
// expired without response
if n.Expires < now {
go deleteNotification(n)
}
case n.Created != 0:
// fallback: delete after 3 days after creation
if n.Created < fallbackTimelimit {
go deleteNotification(n)
// delete }
n.Meta().Delete() default:
delete(nots, n.ID) // invalid, impossible to determine cleanup timeframe, delete now
go deleteNotification(n)
// save (ie. propagate delete)
go n.Save()
} }
n.Unlock() n.Unlock()
} }
} }
func deleteNotification(n *Notification) {
err := n.Delete()
if err != nil {
log.Debugf("notifications: failed to delete %s: %s", n.ID, err)
}
}

View file

@ -43,6 +43,26 @@ type StorageInterface struct {
storage.InjectBase storage.InjectBase
} }
func registerAsDatabase() error {
_, err := database.Register(&database.Database{
Name: "notifications",
Description: "Notifications",
StorageType: "injected",
PrimaryAPI: "",
})
if err != nil {
return err
}
controller, err := database.InjectDatabase("notifications", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
return nil
}
// Get returns a database record. // Get returns a database record.
func (s *StorageInterface) Get(key string) (record.Record, error) { func (s *StorageInterface) Get(key string) (record.Record, error) {
notsLock.RLock() notsLock.RLock()
@ -78,6 +98,10 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
// send all notifications // send all notifications
for _, n := range nots { for _, n := range nots {
if n.Meta().IsDeleted() {
continue
}
if q.MatchesKey(n.DatabaseKey()) && q.MatchesRecord(n) { if q.MatchesKey(n.DatabaseKey()) && q.MatchesRecord(n) {
it.Next <- n it.Next <- n
} }
@ -86,26 +110,6 @@ func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) {
it.Finish(nil) it.Finish(nil)
} }
func registerAsDatabase() error {
_, err := database.Register(&database.Database{
Name: "notifications",
Description: "Notifications",
StorageType: "injected",
PrimaryAPI: "",
})
if err != nil {
return err
}
controller, err := database.InjectDatabase("notifications", &StorageInterface{})
if err != nil {
return err
}
dbController = controller
return nil
}
// Put stores a record in the database. // Put stores a record in the database.
func (s *StorageInterface) Put(r record.Record) error { func (s *StorageInterface) Put(r record.Record) error {
// record is already locked! // record is already locked!
@ -124,36 +128,75 @@ func (s *StorageInterface) Put(r record.Record) error {
} }
// continue in goroutine // continue in goroutine
go updateNotificationFromDatabasePut(n, key) go UpdateNotification(n, key)
return nil return nil
} }
func updateNotificationFromDatabasePut(n *Notification, key string) { // UpdateNotification updates a notification with input from a database action. Notification will not be saved/propagated if there is no valid change.
func UpdateNotification(n *Notification, key string) {
n.Lock()
defer n.Unlock()
// seperate goroutine in order to correctly lock notsLock // seperate goroutine in order to correctly lock notsLock
notsLock.RLock() notsLock.RLock()
origN, ok := nots[key] origN, ok := nots[key]
notsLock.RUnlock() notsLock.RUnlock()
if ok { save := false
// existing notification, update selected action ID only
n.Lock() // ignore if already deleted
defer n.Unlock() if ok && origN.Meta().IsDeleted() {
if n.SelectedActionID != "" { ok = false
log.Tracef("notifications: user selected action for %s: %s", n.ID, n.SelectedActionID)
go origN.SelectAndExecuteAction(n.SelectedActionID)
} }
if ok {
// existing notification
// only update select attributes
origN.Lock()
defer origN.Unlock()
} else { } else {
// accept new notification as is // new notification (from external source): old == new
notsLock.Lock() origN = n
nots[key] = n save = true
notsLock.Unlock() }
switch {
case n.SelectedActionID != "" && n.Responded == 0:
// select action, if not yet already handled
log.Tracef("notifications: selected action for %s: %s", n.ID, n.SelectedActionID)
origN.selectAndExecuteAction(n.SelectedActionID)
save = true
case origN.Executed == 0 && n.Executed != 0:
log.Tracef("notifications: action for %s executed externally", n.ID)
origN.Executed = n.Executed
save = true
}
if save {
// we may be locking
go origN.Save()
} }
} }
// Delete deletes a record from the database. // Delete deletes a record from the database.
func (s *StorageInterface) Delete(key string) error { func (s *StorageInterface) Delete(key string) error {
return ErrNoDelete // transform key
if strings.HasPrefix(key, "all/") {
key = strings.TrimPrefix(key, "all/")
} else {
return storage.ErrNotFound
}
// get notification
notsLock.Lock()
n, ok := nots[key]
notsLock.Unlock()
if !ok {
return storage.ErrNotFound
}
// delete
return n.Delete()
} }
// ReadOnly returns whether the database is read only. // ReadOnly returns whether the database is read only.

View file

@ -12,7 +12,7 @@ var (
) )
func init() { func init() {
modules.Register("notifications", nil, start, nil, "core") modules.Register("notifications", nil, start, nil, "base", "database")
} }
func start() error { func start() error {

View file

@ -5,8 +5,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/record" "github.com/safing/portbase/database/record"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
uuid "github.com/satori/go.uuid"
) )
// Notification types // Notification types
@ -21,6 +24,8 @@ type Notification struct {
record.Base record.Base
ID string ID string
GUID string
Message string Message string
// MessageTemplate string // MessageTemplate string
// MessageData []string // MessageData []string
@ -39,6 +44,7 @@ type Notification struct {
lock sync.Mutex lock sync.Mutex
actionFunction func(*Notification) // call function to process action actionFunction func(*Notification) // call function to process action
actionTrigger chan string // and/or send to a channel actionTrigger chan string // and/or send to a channel
expiredTrigger chan struct{} // closed on expire
} }
// Action describes an action that can be taken for a notification. // Action describes an action that can be taken for a notification.
@ -62,12 +68,6 @@ func Get(id string) *Notification {
return nil return nil
} }
// Init initializes a Notification and returns it.
func (n *Notification) Init() *Notification {
n.Created = time.Now().Unix()
return n
}
// Save saves the notification and returns it. // Save saves the notification and returns it.
func (n *Notification) Save() *Notification { func (n *Notification) Save() *Notification {
notsLock.Lock() notsLock.Lock()
@ -75,6 +75,13 @@ func (n *Notification) Save() *Notification {
n.Lock() n.Lock()
defer n.Unlock() defer n.Unlock()
// initialize
if n.Created == 0 {
n.Created = time.Now().Unix()
}
if n.GUID == "" {
n.GUID = uuid.NewV4().String()
}
// check key // check key
if n.DatabaseKey() == "" { if n.DatabaseKey() == "" {
n.SetKey(fmt.Sprintf("notifications:all/%s", n.ID)) n.SetKey(fmt.Sprintf("notifications:all/%s", n.ID))
@ -104,11 +111,12 @@ func (n *Notification) Save() *Notification {
Executed: n.Executed, Executed: n.Executed,
} }
duplicate.SetMeta(n.Meta().Duplicate()) duplicate.SetMeta(n.Meta().Duplicate())
duplicate.SetKey(fmt.Sprintf("%s/%s", persistentBasePath, n.ID)) key := fmt.Sprintf("%s/%s", persistentBasePath, n.ID)
duplicate.SetKey(key)
go func() { go func() {
err := dbInterface.Put(duplicate) err := dbInterface.Put(duplicate)
if err != nil { if err != nil {
log.Warningf("notifications: failed to persist notification %s: %s", n.Key(), err) log.Warningf("notifications: failed to persist notification %s: %s", key, err)
} }
}() }()
} }
@ -145,42 +153,85 @@ func (n *Notification) MakeAck() *Notification {
// Response waits for the user to respond to the notification and returns the selected action. // Response waits for the user to respond to the notification and returns the selected action.
func (n *Notification) Response() <-chan string { func (n *Notification) Response() <-chan string {
n.lock.Lock() n.lock.Lock()
defer n.lock.Unlock()
if n.actionTrigger == nil { if n.actionTrigger == nil {
n.actionTrigger = make(chan string) n.actionTrigger = make(chan string)
} }
n.lock.Unlock()
return n.actionTrigger return n.actionTrigger
} }
// Cancel (prematurely) destroys a notification. // Update updates/resends a notification if it was not already responded to.
func (n *Notification) Cancel() { func (n *Notification) Update(expires int64) {
responded := true
n.lock.Lock()
if n.Responded == 0 {
responded = false
n.Expires = expires
}
n.lock.Unlock()
// save if not yet responded
if !responded {
n.Save()
}
}
// Delete (prematurely) cancels and deletes a notification.
func (n *Notification) Delete() error {
notsLock.Lock() notsLock.Lock()
defer notsLock.Unlock() defer notsLock.Unlock()
n.Lock() n.Lock()
defer n.Unlock() defer n.Unlock()
// delete // mark as deleted
n.Meta().Delete() n.Meta().Delete()
// delete from internal storage
delete(nots, n.ID) delete(nots, n.ID)
// save (ie. propagate delete) // close expired
go n.Save() if n.expiredTrigger != nil {
close(n.expiredTrigger)
n.expiredTrigger = nil
}
// push update
dbController.PushUpdate(n)
// delete from persistent storage
if n.Persistent && persistentBasePath != "" {
key := fmt.Sprintf("%s/%s", persistentBasePath, n.ID)
err := dbInterface.Delete(key)
if err != nil && err != database.ErrNotFound {
return fmt.Errorf("failed to delete persisted notification %s from database: %s", key, err)
}
}
return nil
} }
// SelectAndExecuteAction sets the user response and executes/triggers the action, if possible. // Expired notifies the caller when the notification has expired.
func (n *Notification) SelectAndExecuteAction(id string) { func (n *Notification) Expired() <-chan struct{} {
n.Lock() n.lock.Lock()
defer n.Unlock() if n.expiredTrigger == nil {
n.expiredTrigger = make(chan struct{})
}
n.lock.Unlock()
// update selection return n.expiredTrigger
}
// selectAndExecuteAction sets the user response and executes/triggers the action, if possible.
func (n *Notification) selectAndExecuteAction(id string) {
// abort if already executed
if n.Executed != 0 { if n.Executed != 0 {
// we already executed
return return
} }
n.SelectedActionID = id
// set response
n.Responded = time.Now().Unix() n.Responded = time.Now().Unix()
n.SelectedActionID = id
// execute // execute
executed := false executed := false
@ -195,7 +246,7 @@ func (n *Notification) SelectAndExecuteAction(id string) {
select { select {
case n.actionTrigger <- n.SelectedActionID: case n.actionTrigger <- n.SelectedActionID:
executed = true executed = true
default: case <-time.After(100 * time.Millisecond): // mitigate race conditions
break triggerAll break triggerAll
} }
} }
@ -205,8 +256,6 @@ func (n *Notification) SelectAndExecuteAction(id string) {
if executed { if executed {
n.Executed = time.Now().Unix() n.Executed = time.Now().Unix()
} }
go n.Save()
} }
// AddDataSubject adds the data subject to the notification. This is the only way how a data subject should be added - it avoids locking problems. // AddDataSubject adds the data subject to the notification. This is the only way how a data subject should be added - it avoids locking problems.

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager package taskmanager
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager package taskmanager
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager package taskmanager
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager package taskmanager
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager package taskmanager
import ( import (

View file

@ -1,5 +1,3 @@
// Copyright Safing ICS Technologies GmbH. Use of this source code is governed by the AGPL license that can be found in the LICENSE file.
package taskmanager package taskmanager
import ( import (

View file

@ -6,6 +6,8 @@ import (
"runtime" "runtime"
) )
const isWindows = runtime.GOOS == "windows"
// EnsureDirectory ensures that the given directoy exists and that is has the given permissions set. // EnsureDirectory ensures that the given directoy exists and that is has the given permissions set.
// If path is a file, it is deleted and a directory created. // If path is a file, it is deleted and a directory created.
// If a directory is created, also all missing directories up to the required one are created with the given permissions. // If a directory is created, also all missing directories up to the required one are created with the given permissions.
@ -16,7 +18,7 @@ func EnsureDirectory(path string, perm os.FileMode) error {
// file exists // file exists
if f.IsDir() { if f.IsDir() {
// directory exists, check permissions // directory exists, check permissions
if runtime.GOOS == "windows" { if isWindows {
// TODO: set correct permission on windows // TODO: set correct permission on windows
// acl.Chmod(path, perm) // acl.Chmod(path, perm)
} else if f.Mode().Perm() != perm { } else if f.Mode().Perm() != perm {

View file

@ -20,12 +20,13 @@ func EnableColorSupport() bool {
if !colorSupportChecked { if !colorSupportChecked {
colorSupport = enableColorSupport() colorSupport = enableColorSupport()
colorSupportChecked = true
} }
return colorSupport return colorSupport
} }
func enableColorSupport() bool { func enableColorSupport() bool {
if IsWindowsVersion("10.") { if IsAtLeastWindowsNTVersionWithDefault("10", false) {
// check if windows.Stdout is file // check if windows.Stdout is file
if windows.GetFileInformationByHandle(windows.Stdout, &windows.ByHandleFileInformation{}) == nil { if windows.GetFileInformationByHandle(windows.Stdout, &windows.ByHandleFileInformation{}) == nil {

View file

@ -1,56 +1,100 @@
package osdetail package osdetail
import ( import (
"os/exec" "fmt"
"regexp" "regexp"
"strings"
"sync" "sync"
)
// FIXME: use https://godoc.org/github.com/shirou/gopsutil/host#PlatformInformation instead "github.com/hashicorp/go-version"
versionCmp "github.com/hashicorp/go-version"
"github.com/shirou/gopsutil/host"
)
var ( var (
versionRe = regexp.MustCompile(`[0-9\.]+`) versionRe = regexp.MustCompile(`[0-9\.]+`)
windowsVersion string windowsNTVersion string
windowsNTVersionForCmp *versionCmp.Version
fetching sync.Mutex fetching sync.Mutex
fetched bool fetched bool
) )
func fetchVersion() { // WindowsNTVersion returns the current Windows version.
func WindowsNTVersion() (string, error) {
var err error
fetching.Lock()
defer fetching.Unlock()
if !fetched { if !fetched {
fetched = true _, _, windowsNTVersion, err = host.PlatformInformation()
output, err := exec.Command("cmd", "ver").Output()
if err != nil { if err != nil {
return return "", fmt.Errorf("failed to obtain Windows-Version: %s", err)
} }
match := versionRe.Find(output) windowsNTVersionForCmp, err = version.NewVersion(windowsNTVersion)
if match == nil {
return if err != nil {
return "", fmt.Errorf("failed to parse Windows-Version %s: %s", windowsNTVersion, err)
} }
windowsVersion = string(match) fetched = true
} }
return windowsNTVersion, err
} }
// WindowsVersion returns the current Windows version. // IsAtLeastWindowsNTVersion returns whether the current WindowsNT version is at least the given version or newer.
func WindowsVersion() string { func IsAtLeastWindowsNTVersion(version string) (bool, error) {
fetching.Lock() _, err := WindowsNTVersion()
defer fetching.Unlock() if err != nil {
fetchVersion() return false, err
}
return windowsVersion versionForCmp, err := versionCmp.NewVersion(version)
if err != nil {
return false, err
}
return windowsNTVersionForCmp.GreaterThanOrEqual(versionForCmp), nil
} }
// IsWindowsVersion returns whether the given version matches (HasPrefix) the current Windows version. // IsAtLeastWindowsNTVersionWithDefault is like IsAtLeastWindowsNTVersion(), but keeps the Error and returns the default Value in Errorcase
func IsWindowsVersion(version string) bool { func IsAtLeastWindowsNTVersionWithDefault(v string, defaultValue bool) bool {
fetching.Lock() val, err := IsAtLeastWindowsNTVersion(v)
defer fetching.Unlock() if err != nil {
fetchVersion() return defaultValue
}
// TODO: we can do better. return val
return strings.HasPrefix(windowsVersion, version) }
// IsAtLeastWindowsVersion returns whether the current Windows version is at least the given version or newer.
func IsAtLeastWindowsVersion(v string) (bool, error) {
var (
NTVersion string
)
switch v {
case "7":
NTVersion = "6.1"
case "8":
NTVersion = "6.2"
case "8.1":
NTVersion = "6.3"
case "10":
NTVersion = "10"
default:
return false, fmt.Errorf("failed to compare Windows-Version: Windows %s is unknown", v)
}
return IsAtLeastWindowsNTVersion(NTVersion)
}
// IsAtLeastWindowsVersionWithDefault is like IsAtLeastWindowsVersion(), but keeps the Error and returns the default Value in Errorcase
func IsAtLeastWindowsVersionWithDefault(v string, defaultValue bool) bool {
val, err := IsAtLeastWindowsVersion(v)
if err != nil {
return defaultValue
}
return val
} }

View file

@ -2,8 +2,28 @@ package osdetail
import "testing" import "testing"
func TestWindowsVersion(t *testing.T) { func TestWindowsNTVersion(t *testing.T) {
if WindowsVersion() == "" { if str, err := WindowsNTVersion(); str == "" || err != nil {
t.Fatal("could not get windows version") t.Fatalf("failed to obtain windows version: %s", err)
}
}
func TestIsAtLeastWindowsNTVersion(t *testing.T) {
ret, err := IsAtLeastWindowsNTVersion("6")
if err != nil {
t.Fatalf("failed to compare windows versions: %s", err)
}
if !ret {
t.Fatalf("WindowsNTVersion is less than 6 (Vista)")
}
}
func TestIsAtLeastWindowsVersion(t *testing.T) {
ret, err := IsAtLeastWindowsVersion("7")
if err != nil {
t.Fatalf("failed to compare windows versions: %s", err)
}
if !ret {
t.Fatalf("WindowsVersion is less than 7")
} }
} }

139
utils/structure.go Normal file
View file

@ -0,0 +1,139 @@
package utils
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
// DirStructure represents a directory structure with permissions that should be enforced.
type DirStructure struct {
sync.Mutex
Path string
Dir string
Perm os.FileMode
Parent *DirStructure
Children map[string]*DirStructure
}
// NewDirStructure returns a new DirStructure.
func NewDirStructure(path string, perm os.FileMode) *DirStructure {
return &DirStructure{
Path: path,
Perm: perm,
Children: make(map[string]*DirStructure),
}
}
// ChildDir adds a new child DirStructure and returns it. Should the child already exist, the existing child is returned and the permissions are updated.
func (ds *DirStructure) ChildDir(dirName string, perm os.FileMode) (child *DirStructure) {
ds.Lock()
defer ds.Unlock()
// if exists, update
child, ok := ds.Children[dirName]
if ok {
child.Perm = perm
return child
}
// create new
new := &DirStructure{
Path: filepath.Join(ds.Path, dirName),
Dir: dirName,
Perm: perm,
Parent: ds,
Children: make(map[string]*DirStructure),
}
ds.Children[dirName] = new
return new
}
// Ensure ensures that the specified directory structure (from the first parent on) exists.
func (ds *DirStructure) Ensure() error {
return ds.EnsureAbsPath(ds.Path)
}
// EnsureRelPath ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists.
func (ds *DirStructure) EnsureRelPath(dirPath string) error {
return ds.EnsureAbsPath(filepath.Join(ds.Path, dirPath))
}
// EnsureRelDir ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists.
func (ds *DirStructure) EnsureRelDir(dirNames ...string) error {
return ds.EnsureAbsPath(filepath.Join(append([]string{ds.Path}, dirNames...)...))
}
// EnsureAbsPath ensures that the specified directory structure (from the first parent on) and the given absolute path exists.
// If the given path is outside the DirStructure, an error will be returned.
func (ds *DirStructure) EnsureAbsPath(dirPath string) error {
// always start at the top
if ds.Parent != nil {
return ds.Parent.EnsureAbsPath(dirPath)
}
// check if root
if dirPath == ds.Path {
return ds.ensure(nil)
}
// check scope
slashedPath := ds.Path
// add slash to end
if !strings.HasSuffix(slashedPath, string(filepath.Separator)) {
slashedPath += string(filepath.Separator)
}
// check if given path is in scope
if !strings.HasPrefix(dirPath, slashedPath) {
return fmt.Errorf(`path "%s" is outside of DirStructure scope`, dirPath)
}
// get relative path
relPath, err := filepath.Rel(ds.Path, dirPath)
if err != nil {
return fmt.Errorf("failed to get relative path: %s", err)
}
// split to path elements
pathDirs := strings.Split(filepath.ToSlash(relPath), "/")
// start checking
return ds.ensure(pathDirs)
}
func (ds *DirStructure) ensure(pathDirs []string) error {
ds.Lock()
defer ds.Unlock()
// check current dir
err := EnsureDirectory(ds.Path, ds.Perm)
if err != nil {
return err
}
if len(pathDirs) == 0 {
// we reached the end!
return nil
}
child, ok := ds.Children[pathDirs[0]]
if !ok {
// we have reached the end of the defined dir structure
// ensure all remaining dirs
dirPath := ds.Path
for _, dir := range pathDirs {
dirPath = filepath.Join(dirPath, dir)
err := EnsureDirectory(dirPath, ds.Perm)
if err != nil {
return err
}
}
return nil
}
// we got a child, continue
return child.ensure(pathDirs[1:])
}

72
utils/structure_test.go Normal file
View file

@ -0,0 +1,72 @@
// +build !windows
package utils
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
)
func ExampleDirStructure() {
// output:
// / [755]
// /repo [777]
// /repo/b [755]
// /repo/b/c [750]
// /repo/b/d [755]
// /repo/b/d/e [755]
// /repo/b/d/f [755]
// /secret [700]
basePath, err := ioutil.TempDir("", "")
if err != nil {
fmt.Println(err)
return
}
ds := NewDirStructure(basePath, 0755)
secret := ds.ChildDir("secret", 0700)
repo := ds.ChildDir("repo", 0777)
_ = repo.ChildDir("a", 0700)
b := repo.ChildDir("b", 0755)
c := b.ChildDir("c", 0750)
err = ds.Ensure()
if err != nil {
fmt.Println(err)
}
err = c.Ensure()
if err != nil {
fmt.Println(err)
}
err = secret.Ensure()
if err != nil {
fmt.Println(err)
}
err = b.EnsureRelDir("d", "e")
if err != nil {
fmt.Println(err)
}
err = b.EnsureRelPath("d/f")
if err != nil {
fmt.Println(err)
}
filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
if err == nil {
dir := strings.TrimPrefix(path, basePath)
if dir == "" {
dir = "/"
}
fmt.Printf("%s [%o]\n", dir, info.Mode().Perm())
}
return nil
})
}