diff --git a/.golangci.yml b/.golangci.yml index 7b51b70..3c2d6b3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -4,3 +4,8 @@ linters: - lll - gochecknoinits - gochecknoglobals + - funlen + - whitespace + - wsl + - godox + diff --git a/.travis.yml b/.travis.yml index 4d85a1b..3571007 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,8 @@ language: go +go: +- 1.x + os: - linux - windows diff --git a/Gopkg.lock b/Gopkg.lock index 307a13e..11cdb8e 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -123,6 +123,14 @@ revision = "ac23dc3fea5d1a983c43f6a0f6e2c13f0195d8bd" version = "v1.2.0" +[[projects]] + digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be" + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + pruneopts = "UT" + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + [[projects]] branch = "master" digest = "1:7e8b852581596acce37bcb939a05d7d5ff27156045b50057e659e299c16fc1ca" @@ -208,6 +216,22 @@ pruneopts = "UT" revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b" +[[projects]] + digest = "1:e096613fb7cf34743d49af87d197663cfccd61876e2219853005a57baedfa562" + name = "github.com/spf13/cobra" + packages = ["."] + pruneopts = "UT" + revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5" + version = "v0.0.5" + +[[projects]] + digest = "1:524b71991fc7d9246cc7dc2d9e0886ccb97648091c63e30eef619e6862c955dd" + name = "github.com/spf13/pflag" + packages = ["."] + pruneopts = "UT" + revision = "2e9d26c8c37aae03e3f9d4e90b7116f5accb7cab" + version = "v1.0.5" + [[projects]] branch = "master" digest = "1:93d6687fc19da8a35c7352d72117a6acd2072dfb7e9bfd65646227bf2a913b2a" @@ -312,6 +336,7 @@ "github.com/satori/go.uuid", "github.com/seehuhn/fortuna", "github.com/shirou/gopsutil/host", + "github.com/spf13/cobra", "github.com/tevino/abool", "github.com/tidwall/gjson", "github.com/tidwall/sjson", diff --git a/api/main.go b/api/main.go index 830fe79..f0ea824 100644 --- a/api/main.go +++ b/api/main.go @@ -7,13 +7,17 @@ import ( "github.com/safing/portbase/modules" ) +var ( + module *modules.Module +) + // API Errors var ( ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set") ) func init() { - modules.Register("api", prep, start, stop, "base", "database", "config") + module = modules.Register("api", prep, start, stop, "base", "database", "config") } func prep() error { diff --git a/api/router.go b/api/router.go index bce3884..bed878e 100644 --- a/api/router.go +++ b/api/router.go @@ -1,8 +1,10 @@ package api import ( + "context" "net/http" "sync" + "time" "github.com/gorilla/mux" @@ -56,8 +58,20 @@ func Serve() { // start serving log.Infof("api: starting to listen on %s", server.Addr) - // TODO: retry if failed - log.Errorf("api: failed to listen on %s: %s", server.Addr, server.ListenAndServe()) + backoffDuration := 10 * time.Second + for { + // always returns an error + err := module.RunWorker("http endpoint", func(ctx context.Context) error { + return server.ListenAndServe() + }) + // return on shutdown error + if err == http.ErrServerClosed { + return + } + // log error and restart + log.Errorf("api: http endpoint failed: %s - restarting in %s", err, backoffDuration) + time.Sleep(backoffDuration) + } } // GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects. diff --git a/config/doc.go b/config/doc.go new file mode 100644 index 0000000..6c023dc --- /dev/null +++ b/config/doc.go @@ -0,0 +1,2 @@ +// Package config provides a versatile configuration management system. +package config diff --git a/config/expertise.go b/config/expertise.go new file mode 100644 index 0000000..9f515d4 --- /dev/null +++ b/config/expertise.go @@ -0,0 +1,72 @@ +// Package config ... (linter fix) +//nolint:dupl +package config + +import ( + "fmt" + "sync/atomic" +) + +// Expertise Level constants +const ( + ExpertiseLevelUser uint8 = 0 + ExpertiseLevelExpert uint8 = 1 + ExpertiseLevelDeveloper uint8 = 2 + + ExpertiseLevelNameUser = "user" + ExpertiseLevelNameExpert = "expert" + ExpertiseLevelNameDeveloper = "developer" + + expertiseLevelKey = "core/expertiseLevel" +) + +var ( + expertiseLevel *int32 +) + +func init() { + var expertiseLevelVal int32 + expertiseLevel = &expertiseLevelVal + + registerExpertiseLevelOption() +} + +func registerExpertiseLevelOption() { + err := Register(&Option{ + Name: "Expertise Level", + Key: expertiseLevelKey, + Description: "The Expertise Level controls the perceived complexity. Higher settings will show you more complex settings and information. This might also affect various other things relying on this setting. Modified settings in higher expertise levels stay in effect when switching back. (Unlike the Release Level)", + + OptType: OptTypeString, + ExpertiseLevel: ExpertiseLevelUser, + ReleaseLevel: ExpertiseLevelUser, + + RequiresRestart: false, + DefaultValue: ExpertiseLevelNameUser, + + ExternalOptType: "string list", + ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ExpertiseLevelNameUser, ExpertiseLevelNameExpert, ExpertiseLevelNameDeveloper), + }) + if err != nil { + panic(err) + } +} + +func updateExpertiseLevel() { + new := findStringValue(expertiseLevelKey, "") + switch new { + case ExpertiseLevelNameUser: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser)) + case ExpertiseLevelNameExpert: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelExpert)) + case ExpertiseLevelNameDeveloper: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelDeveloper)) + default: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser)) + } +} + +// GetExpertiseLevel returns the current active expertise level. +func GetExpertiseLevel() uint8 { + return uint8(atomic.LoadInt32(expertiseLevel)) +} diff --git a/config/get.go b/config/get.go index 4862ae5..7bc0501 100644 --- a/config/get.go +++ b/config/get.go @@ -81,21 +81,7 @@ func findValue(key string) interface{} { option.Lock() defer option.Unlock() - // check if option is active - optionActive := true - switch getReleaseLevel() { - case ReleaseLevelStable: - // In stable, only stable is active - optionActive = option.ReleaseLevel == ReleaseLevelStable - case ReleaseLevelBeta: - // In beta, only stable and beta are active - optionActive = option.ReleaseLevel == ReleaseLevelStable || option.ReleaseLevel == ReleaseLevelBeta - case ReleaseLevelExperimental: - // In experimental, everything is active - optionActive = true - } - - if optionActive && option.activeValue != nil { + if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil { return option.activeValue } diff --git a/config/get_test.go b/config/get_test.go index d092c85..3028619 100644 --- a/config/get_test.go +++ b/config/get_test.go @@ -160,21 +160,21 @@ func TestReleaseLevel(t *testing.T) { // test option level stable subsystemOption.ReleaseLevel = ReleaseLevelStable - err = SetConfigOption(releaseLevelKey, ReleaseLevelStable) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable) if err != nil { t.Fatal(err) } if !testSubsystem() { t.Error("should be active") } - err = SetConfigOption(releaseLevelKey, ReleaseLevelBeta) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta) if err != nil { t.Fatal(err) } if !testSubsystem() { t.Error("should be active") } - err = SetConfigOption(releaseLevelKey, ReleaseLevelExperimental) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental) if err != nil { t.Fatal(err) } @@ -184,21 +184,21 @@ func TestReleaseLevel(t *testing.T) { // test option level beta subsystemOption.ReleaseLevel = ReleaseLevelBeta - err = SetConfigOption(releaseLevelKey, ReleaseLevelStable) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable) if err != nil { t.Fatal(err) } if testSubsystem() { - t.Errorf("should be inactive: opt=%s system=%s", subsystemOption.ReleaseLevel, releaseLevel) + t.Errorf("should be inactive: opt=%d system=%d", subsystemOption.ReleaseLevel, getReleaseLevel()) } - err = SetConfigOption(releaseLevelKey, ReleaseLevelBeta) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta) if err != nil { t.Fatal(err) } if !testSubsystem() { t.Error("should be active") } - err = SetConfigOption(releaseLevelKey, ReleaseLevelExperimental) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental) if err != nil { t.Fatal(err) } @@ -208,21 +208,21 @@ func TestReleaseLevel(t *testing.T) { // test option level experimental subsystemOption.ReleaseLevel = ReleaseLevelExperimental - err = SetConfigOption(releaseLevelKey, ReleaseLevelStable) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable) if err != nil { t.Fatal(err) } if testSubsystem() { t.Error("should be inactive") } - err = SetConfigOption(releaseLevelKey, ReleaseLevelBeta) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta) if err != nil { t.Fatal(err) } if testSubsystem() { t.Error("should be inactive") } - err = SetConfigOption(releaseLevelKey, ReleaseLevelExperimental) + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental) if err != nil { t.Fatal(err) } diff --git a/config/main.go b/config/main.go index 5267634..a37f8f6 100644 --- a/config/main.go +++ b/config/main.go @@ -10,7 +10,12 @@ import ( "github.com/safing/portmaster/core/structure" ) +const ( + configChangeEvent = "config change" +) + var ( + module *modules.Module dataRoot *utils.DirStructure ) @@ -22,7 +27,8 @@ func SetDataRoot(root *utils.DirStructure) { } func init() { - modules.Register("config", prep, start, nil, "base", "database") + module = modules.Register("config", prep, start, nil, "base", "database") + module.RegisterEvent(configChangeEvent) } func prep() error { diff --git a/config/option.go b/config/option.go index 85322db..db33e3a 100644 --- a/config/option.go +++ b/config/option.go @@ -17,14 +17,6 @@ const ( OptTypeStringArray uint8 = 2 OptTypeInt uint8 = 3 OptTypeBool uint8 = 4 - - ExpertiseLevelUser uint8 = 1 - ExpertiseLevelExpert uint8 = 2 - ExpertiseLevelDeveloper uint8 = 3 - - ReleaseLevelStable = "stable" - ReleaseLevelBeta = "beta" - ReleaseLevelExperimental = "experimental" ) func getTypeName(t uint8) string { @@ -50,9 +42,9 @@ type Option struct { Key string // in path format: category/sub/key Description string - ReleaseLevel string - ExpertiseLevel uint8 OptType uint8 + ExpertiseLevel uint8 + ReleaseLevel uint8 RequiresRestart bool DefaultValue interface{} diff --git a/config/registry.go b/config/registry.go index d837529..0b2c221 100644 --- a/config/registry.go +++ b/config/registry.go @@ -1,7 +1,6 @@ package config import ( - "errors" "fmt" "regexp" "sync" @@ -10,21 +9,21 @@ import ( var ( optionsLock sync.RWMutex options = make(map[string]*Option) - - // ErrIncompleteCall is return when RegisterOption is called with empty mandatory values. - ErrIncompleteCall = errors.New("could not register config option: all fields, except for the validationRegex are mandatory") ) // Register registers a new configuration option. func Register(option *Option) error { - - if option.Name == "" || - option.Key == "" || - option.Description == "" || - option.OptType == 0 || - option.ExpertiseLevel == 0 || - option.ReleaseLevel == "" { - return ErrIncompleteCall + if option.Name == "" { + return fmt.Errorf("failed to register option: please set option.Name") + } + if option.Key == "" { + return fmt.Errorf("failed to register option: please set option.Key") + } + if option.Description == "" { + return fmt.Errorf("failed to register option: please set option.Description") + } + if option.OptType == 0 { + return fmt.Errorf("failed to register option: please set option.OptType") } if option.ValidationRegex != "" { @@ -37,7 +36,6 @@ func Register(option *Option) error { optionsLock.Lock() defer optionsLock.Unlock() - options[option.Key] = option return nil diff --git a/config/release.go b/config/release.go index 7804421..14b239c 100644 --- a/config/release.go +++ b/config/release.go @@ -1,38 +1,51 @@ +// Package config ... (linter fix) +//nolint:dupl package config import ( "fmt" - "sync" + "sync/atomic" ) +// Release Level constants const ( - releaseLevelKey = "core/release_level" + ReleaseLevelStable uint8 = 0 + ReleaseLevelBeta uint8 = 1 + ReleaseLevelExperimental uint8 = 2 + + ReleaseLevelNameStable = "stable" + ReleaseLevelNameBeta = "beta" + ReleaseLevelNameExperimental = "experimental" + + releaseLevelKey = "core/releaseLevel" ) var ( - releaseLevel = ReleaseLevelStable - releaseLevelLock sync.Mutex + releaseLevel *int32 ) func init() { + var releaseLevelVal int32 + releaseLevel = &releaseLevelVal + registerReleaseLevelOption() } func registerReleaseLevelOption() { err := Register(&Option{ - Name: "Release Selection", + Name: "Release Level", Key: releaseLevelKey, - Description: "Select maturity level of features that should be available", + Description: "The Release Level changes which features are available to you. Some beta or experimental features are also available in the stable release channel. Unavailable settings are set to the default value.", OptType: OptTypeString, ExpertiseLevel: ExpertiseLevelExpert, ReleaseLevel: ReleaseLevelStable, RequiresRestart: false, - DefaultValue: ReleaseLevelStable, + DefaultValue: ReleaseLevelNameStable, ExternalOptType: "string list", - ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelStable, ReleaseLevelBeta, ReleaseLevelExperimental), + ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelNameStable, ReleaseLevelNameBeta, ReleaseLevelNameExperimental), }) if err != nil { panic(err) @@ -41,17 +54,18 @@ func registerReleaseLevelOption() { func updateReleaseLevel() { new := findStringValue(releaseLevelKey, "") - releaseLevelLock.Lock() - if new == "" { - releaseLevel = ReleaseLevelStable - } else { - releaseLevel = new + switch new { + case ReleaseLevelNameStable: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable)) + case ReleaseLevelNameBeta: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelBeta)) + case ReleaseLevelNameExperimental: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelExperimental)) + default: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable)) } - releaseLevelLock.Unlock() } -func getReleaseLevel() string { - releaseLevelLock.Lock() - defer releaseLevelLock.Unlock() - return releaseLevel +func getReleaseLevel() uint8 { + return uint8(atomic.LoadInt32(releaseLevel)) } diff --git a/config/set.go b/config/set.go index 7bd3dc3..8e55714 100644 --- a/config/set.go +++ b/config/set.go @@ -17,9 +17,6 @@ var ( validityFlag = abool.NewBool(true) validityFlagLock sync.RWMutex - - changedSignal = make(chan struct{}) - changedSignalLock sync.Mutex ) func getValidityFlag() *abool.AtomicBool { @@ -28,16 +25,10 @@ func getValidityFlag() *abool.AtomicBool { return validityFlag } -// Changed signals if any config option was changed. -func Changed() <-chan struct{} { - changedSignalLock.Lock() - defer changedSignalLock.Unlock() - return changedSignal -} - func signalChanges() { - // refetch and save release level + // refetch and save release level and expertise level updateReleaseLevel() + updateExpertiseLevel() // reset validity flag validityFlagLock.Lock() @@ -45,11 +36,7 @@ func signalChanges() { validityFlag = abool.NewBool(true) validityFlagLock.Unlock() - // trigger change signal: signal listeners that a config option was changed. - changedSignalLock.Lock() - close(changedSignal) - changedSignal = make(chan struct{}) - changedSignalLock.Unlock() + module.TriggerEvent(configChangeEvent, nil) } // setConfig sets the (prioritized) user defined config. diff --git a/database/query/parser.go b/database/query/parser.go index 42c07f7..54615e8 100644 --- a/database/query/parser.go +++ b/database/query/parser.go @@ -14,6 +14,7 @@ type snippet struct { } // ParseQuery parses a plaintext query. Special characters (that must be escaped with a '\') are: `\()` and any whitespaces. +//nolint:gocognit func ParseQuery(query string) (*Query, error) { snippets, err := extractSnippets(query) if err != nil { @@ -195,6 +196,7 @@ func extractSnippets(text string) (snippets []*snippet, err error) { } +//nolint:gocognit func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) { var isOr = false var typeSet = false diff --git a/database/record/meta-gencode.go b/database/record/meta-gencode.go index c0a2142..6494a1d 100644 --- a/database/record/meta-gencode.go +++ b/database/record/meta-gencode.go @@ -14,14 +14,14 @@ var ( ) // GenCodeSize returns the size of the gencode marshalled byte slice -func (d *Meta) GenCodeSize() (s int) { +func (m *Meta) GenCodeSize() (s int) { s += 34 return } // GenCodeMarshal gencode marshalls Meta into the given byte array, or a new one if its too small. -func (d *Meta) GenCodeMarshal(buf []byte) ([]byte, error) { - size := d.GenCodeSize() +func (m *Meta) GenCodeMarshal(buf []byte) ([]byte, error) { + size := m.GenCodeSize() { if cap(buf) >= size { buf = buf[:size] @@ -33,89 +33,89 @@ func (d *Meta) GenCodeMarshal(buf []byte) ([]byte, error) { { - buf[0+0] = byte(d.Created >> 0) + buf[0+0] = byte(m.Created >> 0) - buf[1+0] = byte(d.Created >> 8) + buf[1+0] = byte(m.Created >> 8) - buf[2+0] = byte(d.Created >> 16) + buf[2+0] = byte(m.Created >> 16) - buf[3+0] = byte(d.Created >> 24) + buf[3+0] = byte(m.Created >> 24) - buf[4+0] = byte(d.Created >> 32) + buf[4+0] = byte(m.Created >> 32) - buf[5+0] = byte(d.Created >> 40) + buf[5+0] = byte(m.Created >> 40) - buf[6+0] = byte(d.Created >> 48) + buf[6+0] = byte(m.Created >> 48) - buf[7+0] = byte(d.Created >> 56) + buf[7+0] = byte(m.Created >> 56) } { - buf[0+8] = byte(d.Modified >> 0) + buf[0+8] = byte(m.Modified >> 0) - buf[1+8] = byte(d.Modified >> 8) + buf[1+8] = byte(m.Modified >> 8) - buf[2+8] = byte(d.Modified >> 16) + buf[2+8] = byte(m.Modified >> 16) - buf[3+8] = byte(d.Modified >> 24) + buf[3+8] = byte(m.Modified >> 24) - buf[4+8] = byte(d.Modified >> 32) + buf[4+8] = byte(m.Modified >> 32) - buf[5+8] = byte(d.Modified >> 40) + buf[5+8] = byte(m.Modified >> 40) - buf[6+8] = byte(d.Modified >> 48) + buf[6+8] = byte(m.Modified >> 48) - buf[7+8] = byte(d.Modified >> 56) + buf[7+8] = byte(m.Modified >> 56) } { - buf[0+16] = byte(d.Expires >> 0) + buf[0+16] = byte(m.Expires >> 0) - buf[1+16] = byte(d.Expires >> 8) + buf[1+16] = byte(m.Expires >> 8) - buf[2+16] = byte(d.Expires >> 16) + buf[2+16] = byte(m.Expires >> 16) - buf[3+16] = byte(d.Expires >> 24) + buf[3+16] = byte(m.Expires >> 24) - buf[4+16] = byte(d.Expires >> 32) + buf[4+16] = byte(m.Expires >> 32) - buf[5+16] = byte(d.Expires >> 40) + buf[5+16] = byte(m.Expires >> 40) - buf[6+16] = byte(d.Expires >> 48) + buf[6+16] = byte(m.Expires >> 48) - buf[7+16] = byte(d.Expires >> 56) + buf[7+16] = byte(m.Expires >> 56) } { - buf[0+24] = byte(d.Deleted >> 0) + buf[0+24] = byte(m.Deleted >> 0) - buf[1+24] = byte(d.Deleted >> 8) + buf[1+24] = byte(m.Deleted >> 8) - buf[2+24] = byte(d.Deleted >> 16) + buf[2+24] = byte(m.Deleted >> 16) - buf[3+24] = byte(d.Deleted >> 24) + buf[3+24] = byte(m.Deleted >> 24) - buf[4+24] = byte(d.Deleted >> 32) + buf[4+24] = byte(m.Deleted >> 32) - buf[5+24] = byte(d.Deleted >> 40) + buf[5+24] = byte(m.Deleted >> 40) - buf[6+24] = byte(d.Deleted >> 48) + buf[6+24] = byte(m.Deleted >> 48) - buf[7+24] = byte(d.Deleted >> 56) + buf[7+24] = byte(m.Deleted >> 56) } { - if d.secret { + if m.secret { buf[32] = 1 } else { buf[32] = 0 } } { - if d.cronjewel { + if m.cronjewel { buf[33] = 1 } else { buf[33] = 0 @@ -125,38 +125,38 @@ func (d *Meta) GenCodeMarshal(buf []byte) ([]byte, error) { } // GenCodeUnmarshal gencode unmarshalls Meta and returns the bytes read. -func (d *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) { - if len(buf) < d.GenCodeSize() { - return 0, fmt.Errorf("insufficient data: got %d out of %d bytes", len(buf), d.GenCodeSize()) +func (m *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) { + if len(buf) < m.GenCodeSize() { + return 0, fmt.Errorf("insufficient data: got %d out of %d bytes", len(buf), m.GenCodeSize()) } i := uint64(0) { - d.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56) + m.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56) } { - d.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56) + m.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56) } { - d.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56) + m.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56) } { - d.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56) + m.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56) } { - d.secret = buf[32] == 1 + m.secret = buf[32] == 1 } { - d.cronjewel = buf[33] == 1 + m.cronjewel = buf[33] == 1 } return i + 34, nil } diff --git a/database/registry.go b/database/registry.go index 23d382c..b6d1984 100644 --- a/database/registry.go +++ b/database/registry.go @@ -139,6 +139,7 @@ func saveRegistry(lock bool) error { } // write file + // FIXME: write atomically (best effort) filePath := path.Join(rootStructure.Path, registryFileName) return ioutil.WriteFile(filePath, data, 0600) } diff --git a/database/storage/badger/badger.go b/database/storage/badger/badger.go index da2276d..8f473cd 100644 --- a/database/storage/badger/badger.go +++ b/database/storage/badger/badger.go @@ -118,6 +118,7 @@ func (b *Badger) Query(q *query.Query, local, internal bool) (*iterator.Iterator return queryIter, nil } +//nolint:gocognit func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) { err := b.db.View(func(txn *badger.Txn) error { it := txn.NewIterator(badger.DefaultIteratorOptions) diff --git a/database/storage/bbolt/bbolt_test.go b/database/storage/bbolt/bbolt_test.go index 5264f56..e5cdd3e 100644 --- a/database/storage/bbolt/bbolt_test.go +++ b/database/storage/bbolt/bbolt_test.go @@ -31,7 +31,7 @@ type TestRecord struct { B bool } -func TestBadger(t *testing.T) { +func TestBBolt(t *testing.T) { testDir, err := ioutil.TempDir("", "testing-") if err != nil { t.Fatal(err) diff --git a/database/storage/hashmap/map.go b/database/storage/hashmap/map.go new file mode 100644 index 0000000..c6bb65d --- /dev/null +++ b/database/storage/hashmap/map.go @@ -0,0 +1,140 @@ +package hashmap + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/safing/portbase/database/iterator" + "github.com/safing/portbase/database/query" + "github.com/safing/portbase/database/record" + "github.com/safing/portbase/database/storage" +) + +// HashMap storage. +type HashMap struct { + name string + db map[string]record.Record + dbLock sync.RWMutex +} + +func init() { + _ = storage.Register("hashmap", NewHashMap) +} + +// NewHashMap creates a hashmap database. +func NewHashMap(name, location string) (storage.Interface, error) { + return &HashMap{ + name: name, + db: make(map[string]record.Record), + }, nil +} + +// Get returns a database record. +func (hm *HashMap) Get(key string) (record.Record, error) { + hm.dbLock.RLock() + defer hm.dbLock.RUnlock() + + r, ok := hm.db[key] + if !ok { + return nil, storage.ErrNotFound + } + return r, nil +} + +// Put stores a record in the database. +func (hm *HashMap) Put(r record.Record) error { + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + + hm.db[r.DatabaseKey()] = r + return nil +} + +// Delete deletes a record from the database. +func (hm *HashMap) Delete(key string) error { + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + + delete(hm.db, key) + return nil +} + +// Query returns a an iterator for the supplied query. +func (hm *HashMap) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, fmt.Errorf("invalid query: %s", err) + } + + queryIter := iterator.New() + + go hm.queryExecutor(queryIter, q, local, internal) + return queryIter, nil +} + +func (hm *HashMap) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) { + hm.dbLock.RLock() + defer hm.dbLock.RUnlock() + + var err error + +mapLoop: + for key, record := range hm.db { + + switch { + case !q.MatchesKey(key): + continue + case !q.MatchesRecord(record): + continue + case !record.Meta().CheckValidity(): + continue + case !record.Meta().CheckPermission(local, internal): + continue + } + + select { + case <-queryIter.Done: + break mapLoop + case queryIter.Next <- record: + default: + select { + case <-queryIter.Done: + break mapLoop + case queryIter.Next <- record: + case <-time.After(1 * time.Second): + err = errors.New("query timeout") + break mapLoop + } + } + + } + + queryIter.Finish(err) +} + +// ReadOnly returns whether the database is read only. +func (hm *HashMap) ReadOnly() bool { + return false +} + +// Injected returns whether the database is injected. +func (hm *HashMap) Injected() bool { + return false +} + +// Maintain runs a light maintenance operation on the database. +func (hm *HashMap) Maintain() error { + return nil +} + +// MaintainThorough runs a thorough maintenance operation on the database. +func (hm *HashMap) MaintainThorough() (err error) { + return nil +} + +// Shutdown shuts down the database. +func (hm *HashMap) Shutdown() error { + return nil +} diff --git a/database/storage/hashmap/map_test.go b/database/storage/hashmap/map_test.go new file mode 100644 index 0000000..911dc5b --- /dev/null +++ b/database/storage/hashmap/map_test.go @@ -0,0 +1,147 @@ +//nolint:unparam,maligned +package hashmap + +import ( + "reflect" + "sync" + "testing" + + "github.com/safing/portbase/database/query" + "github.com/safing/portbase/database/record" +) + +type TestRecord struct { + record.Base + sync.Mutex + S string + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + B bool +} + +func TestHashMap(t *testing.T) { + // start + db, err := NewHashMap("test", "") + if err != nil { + t.Fatal(err) + } + + a := &TestRecord{ + S: "banana", + I: 42, + I8: 42, + I16: 42, + I32: 42, + I64: 42, + UI: 42, + UI8: 42, + UI16: 42, + UI32: 42, + UI64: 42, + F32: 42.42, + F64: 42.42, + B: true, + } + a.SetMeta(&record.Meta{}) + a.Meta().Update() + a.SetKey("test:A") + + // put record + err = db.Put(a) + if err != nil { + t.Fatal(err) + } + + // get and compare + a1, err := db.Get("A") + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(a, a1) { + t.Fatalf("mismatch, got %v", a1) + } + + // setup query test records + qA := &TestRecord{} + qA.SetKey("test:path/to/A") + qA.CreateMeta() + qB := &TestRecord{} + qB.SetKey("test:path/to/B") + qB.CreateMeta() + qC := &TestRecord{} + qC.SetKey("test:path/to/C") + qC.CreateMeta() + qZ := &TestRecord{} + qZ.SetKey("test:z") + qZ.CreateMeta() + // put + err = db.Put(qA) + if err == nil { + err = db.Put(qB) + } + if err == nil { + err = db.Put(qC) + } + if err == nil { + err = db.Put(qZ) + } + if err != nil { + t.Fatal(err) + } + + // test query + q := query.New("test:path/to/").MustBeValid() + it, err := db.Query(q, true, true) + if err != nil { + t.Fatal(err) + } + cnt := 0 + for range it.Next { + cnt++ + } + if it.Err() != nil { + t.Fatal(it.Err()) + } + if cnt != 3 { + t.Fatalf("unexpected query result count: %d", cnt) + } + + // delete + err = db.Delete("A") + if err != nil { + t.Fatal(err) + } + + // check if its gone + _, err = db.Get("A") + if err == nil { + t.Fatal("should fail") + } + + // maintenance + err = db.Maintain() + if err != nil { + t.Fatal(err) + } + err = db.MaintainThorough() + if err != nil { + t.Fatal(err) + } + + // shutdown + err = db.Shutdown() + if err != nil { + t.Fatal(err) + } +} diff --git a/formats/dsd/dsd_test.go b/formats/dsd/dsd_test.go index e46df20..db81e76 100644 --- a/formats/dsd/dsd_test.go +++ b/formats/dsd/dsd_test.go @@ -1,4 +1,4 @@ -//nolint:maligned,unparam,gocyclo +//nolint:maligned,unparam,gocyclo,gocognit package dsd import ( diff --git a/formats/dsd/gencode_test.go b/formats/dsd/gencode_test.go index 9d24946..cb35f09 100644 --- a/formats/dsd/gencode_test.go +++ b/formats/dsd/gencode_test.go @@ -1,4 +1,4 @@ -//nolint:nakedret,unconvert +//nolint:nakedret,unconvert,gocognit package dsd import ( diff --git a/formats/varint/varint.go b/formats/varint/varint.go index d0f6129..478e803 100644 --- a/formats/varint/varint.go +++ b/formats/varint/varint.go @@ -1,7 +1,9 @@ package varint -import "errors" -import "encoding/binary" +import ( + "encoding/binary" + "errors" +) // Pack8 packs a uint8 into a VarInt. func Pack8(n uint8) []byte { diff --git a/formats/varint/varint_test.go b/formats/varint/varint_test.go index 0de1741..5e049ad 100644 --- a/formats/varint/varint_test.go +++ b/formats/varint/varint_test.go @@ -1,3 +1,4 @@ +//nolint:gocognit package varint import ( diff --git a/info/flags.go b/info/flags.go index f2700b3..6c677b4 100644 --- a/info/flags.go +++ b/info/flags.go @@ -15,7 +15,7 @@ var ( ) func init() { - modules.Register("info", prep, nil, nil, "base") + modules.Register("info", prep, nil, nil) flag.BoolVar(&showVersion, "version", false, "show version and exit") } @@ -35,8 +35,10 @@ func prep() error { // CheckVersion checks if the metadata is ok. func CheckVersion() error { if !strings.HasSuffix(os.Args[0], ".test") { - if name == "[NAME]" || - version == "[version unknown]" || + if name == "[NAME]" { + return errors.New("must call SetInfo() before calling CheckVersion()") + } + if version == "[version unknown]" || commit == "[commit unknown]" || license == "[license unknown]" || buildOptions == "[options unknown]" || diff --git a/log/input.go b/log/input.go index e8bff6c..e8b6579 100644 --- a/log/input.go +++ b/log/input.go @@ -75,15 +75,21 @@ func log(level Severity, msg string, tracer *ContextTracer) { select { case logBuffer <- log: default: - forceEmptyingOfBuffer <- struct{}{} - logBuffer <- log + forceEmptyingLoop: + // force empty buffer until we can send to it + for { + select { + case forceEmptyingOfBuffer <- struct{}{}: + case logBuffer <- log: + break forceEmptyingLoop + } + } } // wake up writer if necessary if logsWaitingFlag.SetToIf(false, true) { logsWaiting <- struct{}{} } - } func fastcheck(level Severity) bool { diff --git a/log/logging.go b/log/logging.go index 0ca6fac..2d605c8 100644 --- a/log/logging.go +++ b/log/logging.go @@ -70,7 +70,7 @@ const ( var ( logBuffer chan *logLine - forceEmptyingOfBuffer chan struct{} + forceEmptyingOfBuffer = make(chan struct{}) logLevelInt = uint32(3) logLevel = &logLevelInt @@ -79,7 +79,7 @@ var ( pkgLevels = make(map[string]Severity) pkgLevelsLock sync.Mutex - logsWaiting = make(chan struct{}, 1) + logsWaiting = make(chan struct{}, 4) logsWaitingFlag = abool.NewBool(false) shutdownSignal = make(chan struct{}) @@ -90,7 +90,7 @@ var ( startedSignal = make(chan struct{}) ) -// SetPkgLevels sets individual log levels for packages. +// SetPkgLevels sets individual log levels for packages. Only effective after Start(). func SetPkgLevels(levels map[string]Severity) { pkgLevelsLock.Lock() pkgLevels = levels @@ -103,7 +103,7 @@ func UnSetPkgLevels() { pkgLevelsActive.UnSet() } -// SetLogLevel sets a new log level. +// SetLogLevel sets a new log level. Only effective after Start(). func SetLogLevel(level Severity) { atomic.StoreUint32(logLevel, uint32(level)) } @@ -135,11 +135,10 @@ func Start() (err error) { } logBuffer = make(chan *logLine, 1024) - forceEmptyingOfBuffer = make(chan struct{}, 16) initialLogLevel := ParseLevel(logLevelFlag) if initialLogLevel > 0 { - atomic.StoreUint32(logLevel, uint32(initialLogLevel)) + SetLogLevel(initialLogLevel) } else { err = fmt.Errorf("log warning: invalid log level \"%s\", falling back to level info", logLevelFlag) fmt.Fprintf(os.Stderr, "%s\n", err.Error()) diff --git a/log/output.go b/log/output.go index 5b44fb5..28bde37 100644 --- a/log/output.go +++ b/log/output.go @@ -2,6 +2,8 @@ package log import ( "fmt" + "os" + "runtime/debug" "time" ) @@ -39,37 +41,73 @@ func writeLine(line *logLine, duplicates uint64) { } func startWriter() { - shutdownWaitGroup.Add(1) fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor())) - go writer() + + shutdownWaitGroup.Add(1) + go writerManager() } -func writer() { - var line *logLine - var lastLine *logLine - var duplicates uint64 +func writerManager() { defer shutdownWaitGroup.Done() + for { + err := writer() + if err != nil { + Errorf("log: writer failed: %s", err) + } else { + return + } + } +} + +func writer() (err error) { + defer func() { + // recover from panic + panicVal := recover() + if panicVal != nil { + err = fmt.Errorf("%s", panicVal) + + // write stack to stderr + fmt.Fprintf( + os.Stderr, + `===== Error Report ===== +Message: %s +StackTrace: + +%s +===== End of Report ===== +`, + err, + string(debug.Stack()), + ) + } + }() + + var currentLine *logLine + var nextLine *logLine + var duplicates uint64 + for { // reset - line = nil - lastLine = nil //nolint:ineffassign // only ineffectual in first loop + currentLine = nil + nextLine = nil duplicates = 0 // wait until logs need to be processed select { - case <-logsWaiting: + case <-logsWaiting: // normal process logsWaitingFlag.UnSet() - case <-shutdownSignal: + case <-forceEmptyingOfBuffer: // log buffer is full! + case <-shutdownSignal: // shutting down finalizeWriting() return } - // wait for timeslot to log, or when buffer is full + // wait for timeslot to log select { - case <-writeTrigger: - case <-forceEmptyingOfBuffer: - case <-shutdownSignal: + case <-writeTrigger: // normal process + case <-forceEmptyingOfBuffer: // log buffer is full! + case <-shutdownSignal: // shutting down finalizeWriting() return } @@ -78,38 +116,41 @@ func writer() { writeLoop: for { select { - case line = <-logBuffer: - - // look-ahead for deduplication (best effort) - dedupLoop: - for { - // check if there is another line waiting - select { - case nextLine := <-logBuffer: - lastLine = line - line = nextLine - default: - break dedupLoop - } - - // deduplication - if !line.Equal(lastLine) { - // no duplicate - break dedupLoop - } - - // duplicate - duplicates++ + case nextLine = <-logBuffer: + // first line we process, just assign to currentLine + if currentLine == nil { + currentLine = nextLine + continue writeLoop } - // write actual line - writeLine(line, duplicates) + // we now have currentLine and nextLine + + // if currentLine and nextLine are equal, do not print, just increase counter and continue + if nextLine.Equal(currentLine) { + duplicates++ + continue writeLoop + } + + // if currentLine and line are _not_ equal, output currentLine + writeLine(currentLine, duplicates) + // reset duplicate counter duplicates = 0 + // set new currentLine + currentLine = nextLine default: break writeLoop } } + // write final line + if currentLine != nil { + writeLine(currentLine, duplicates) + } + // reset state + currentLine = nil //nolint:ineffassign + nextLine = nil + duplicates = 0 //nolint:ineffassign + // back down a little select { case <-time.After(10 * time.Millisecond): diff --git a/log/trace.go b/log/trace.go index 6f77c37..a345f5d 100644 --- a/log/trace.go +++ b/log/trace.go @@ -83,7 +83,7 @@ func Tracer(ctx context.Context) *ContextTracer { // Submit collected logs on the context for further processing/outputting. Does nothing if called on a nil ContextTracer. func (tracer *ContextTracer) Submit() { - if tracer != nil { + if tracer == nil { return } @@ -119,15 +119,21 @@ func (tracer *ContextTracer) Submit() { select { case logBuffer <- log: default: - forceEmptyingOfBuffer <- struct{}{} - logBuffer <- log + forceEmptyingLoop: + // force empty buffer until we can send to it + for { + select { + case forceEmptyingOfBuffer <- struct{}{}: + case logBuffer <- log: + break forceEmptyingLoop + } + } } // wake up writer if necessary if logsWaitingFlag.SetToIf(false, true) { logsWaiting <- struct{}{} } - } func (tracer *ContextTracer) log(level Severity, msg string) { diff --git a/modules/error.go b/modules/error.go index 80d025e..04bda1c 100644 --- a/modules/error.go +++ b/modules/error.go @@ -2,11 +2,16 @@ package modules import ( "fmt" + "os" "runtime/debug" + "sync" + "time" ) var ( errorReportingChannel chan *ModuleError + reportToStdErr = true + reportingLock sync.RWMutex ) // ModuleError wraps a panic, error or message into an error that can be reported. @@ -64,12 +69,43 @@ func (me *ModuleError) Error() string { // Report reports the error through the configured reporting channel. func (me *ModuleError) Report() { + reportingLock.RLock() + defer reportingLock.RUnlock() + if errorReportingChannel != nil { select { case errorReportingChannel <- me: default: } } + + if reportToStdErr { + // default to writing to stderr + fmt.Fprintf( + os.Stderr, + `===== Error Report ===== +Message: %s +Timestamp: %s +ModuleName: %s +TaskName: %s +TaskType: %s +Severity: %s +PanicValue: %s +StackTrace: + +%s +===== End of Report ===== +`, + me.Message, + time.Now(), + me.ModuleName, + me.TaskName, + me.TaskType, + me.Severity, + me.PanicValue, + me.StackTrace, + ) + } } // IsPanic returns whether the given error is a wrapped panic by the modules package and additionally returns it, if true. @@ -84,7 +120,16 @@ func IsPanic(err error) (bool, *ModuleError) { // SetErrorReportingChannel sets the channel to report module errors through. By default only panics are reported, all other errors need to be manually wrapped into a *ModuleError and reported. func SetErrorReportingChannel(reportingChannel chan *ModuleError) { - if errorReportingChannel == nil { - errorReportingChannel = reportingChannel - } + reportingLock.Lock() + defer reportingLock.Unlock() + + errorReportingChannel = reportingChannel +} + +// SetStdErrReporting controls error reporting to stderr. +func SetStdErrReporting(on bool) { + reportingLock.Lock() + defer reportingLock.Unlock() + + reportToStdErr = on } diff --git a/modules/events.go b/modules/events.go new file mode 100644 index 0000000..0768ebd --- /dev/null +++ b/modules/events.go @@ -0,0 +1,103 @@ +package modules + +import ( + "context" + "fmt" + + "github.com/safing/portbase/log" +) + +type eventHookFn func(context.Context, interface{}) error + +type eventHook struct { + description string + hookingModule *Module + hookFn eventHookFn +} + +// TriggerEvent executes all hook functions registered to the specified event. +func (m *Module) TriggerEvent(event string, data interface{}) { + go m.processEventTrigger(event, data) +} + +func (m *Module) processEventTrigger(event string, data interface{}) { + m.eventHooksLock.RLock() + defer m.eventHooksLock.RUnlock() + + hooks, ok := m.eventHooks[event] + if !ok { + log.Warningf(`%s: tried to trigger non-existent event "%s"`, m.Name, event) + return + } + + for _, hook := range hooks { + if !hook.hookingModule.ShutdownInProgress() { + go m.runEventHook(hook, event, data) + } + } +} + +func (m *Module) runEventHook(hook *eventHook, event string, data interface{}) { + if !hook.hookingModule.Started.IsSet() { + // target module has not yet fully started, wait until start is complete + select { + case <-startCompleteSignal: + case <-shutdownSignal: + return + } + } + + err := hook.hookingModule.RunWorker( + fmt.Sprintf("event hook %s/%s -> %s/%s", m.Name, event, hook.hookingModule.Name, hook.description), + func(ctx context.Context) error { + return hook.hookFn(ctx, data) + }, + ) + if err != nil { + log.Warningf("%s: failed to execute event hook %s/%s -> %s/%s: %s", hook.hookingModule.Name, m.Name, event, hook.hookingModule.Name, hook.description, err) + } +} + +// RegisterEvent registers a new event to allow for registering hooks. +func (m *Module) RegisterEvent(event string) { + m.eventHooksLock.Lock() + defer m.eventHooksLock.Unlock() + + _, ok := m.eventHooks[event] + if !ok { + m.eventHooks[event] = make([]*eventHook, 0, 1) + } +} + +// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until all modules completed starting. +func (m *Module) RegisterEventHook(module string, event string, description string, fn func(context.Context, interface{}) error) error { + // get target module + var eventModule *Module + if module == m.Name { + eventModule = m + } else { + var ok bool + modulesLock.RLock() + eventModule, ok = modules[module] + modulesLock.RUnlock() + if !ok { + return fmt.Errorf(`module "%s" does not exist`, module) + } + } + + // get target event + eventModule.eventHooksLock.Lock() + defer eventModule.eventHooksLock.Unlock() + hooks, ok := eventModule.eventHooks[event] + if !ok { + return fmt.Errorf(`event "%s/%s" does not exist`, eventModule.Name, event) + } + + // add hook + eventModule.eventHooks[event] = append(hooks, &eventHook{ + description: description, + hookingModule: m, + hookFn: fn, + }) + return nil +} diff --git a/modules/exit.go b/modules/exit.go new file mode 100644 index 0000000..67d38bc --- /dev/null +++ b/modules/exit.go @@ -0,0 +1,16 @@ +package modules + +var ( + exitStatusCode int +) + +// SetExitStatusCode sets the exit code that the program shell return to the host after shutdown. +func SetExitStatusCode(n int) { + exitStatusCode = n +} + +// GetExitStatusCode waits for the shutdown to complete and then returns the exit code +func GetExitStatusCode() int { + <-shutdownCompleteSignal + return exitStatusCode +} diff --git a/modules/microtasks.go b/modules/microtasks.go index f874e4f..a43d427 100644 --- a/modules/microtasks.go +++ b/modules/microtasks.go @@ -45,14 +45,54 @@ func SetMaxConcurrentMicroTasks(n int) { } } -// StartMicroTask starts a new MicroTask with high priority. It will start immediately. The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied. -func (m *Module) StartMicroTask(name *string, fn func(context.Context) error) error { - atomic.AddInt32(microTasks, 1) +// StartMicroTask starts a new MicroTask with high priority. It will start immediately. The call starts a new goroutine and returns immediately. The given function will be executed and panics caught. The supplied name must not be changed. +func (m *Module) StartMicroTask(name *string, fn func(context.Context) error) { + go func() { + err := m.RunMicroTask(name, fn) + if err != nil { + log.Warningf("%s: microtask %s failed: %s", m.Name, *name, err) + } + }() +} + +// StartMediumPriorityMicroTask starts a new MicroTask with medium priority. The call starts a new goroutine and returns immediately. It will wait until a slot becomes available (max 3 seconds). The given function will be executed and panics caught. The supplied name must not be changed. +func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Context) error) { + go func() { + err := m.RunMediumPriorityMicroTask(name, fn) + if err != nil { + log.Warningf("%s: microtask %s failed: %s", m.Name, *name, err) + } + }() +} + +// StartLowPriorityMicroTask starts a new MicroTask with low priority. The call starts a new goroutine and returns immediately. It will wait until a slot becomes available (max 15 seconds). The given function will be executed and panics caught. The supplied name must not be changed. +func (m *Module) StartLowPriorityMicroTask(name *string, fn func(context.Context) error) { + go func() { + err := m.RunLowPriorityMicroTask(name, fn) + if err != nil { + log.Warningf("%s: microtask %s failed: %s", m.Name, *name, err) + } + }() +} + +// RunMicroTask runs a new MicroTask with high priority. It will start immediately. The call blocks until finished. The given function will be executed and panics caught. The supplied name must not be changed. +func (m *Module) RunMicroTask(name *string, fn func(context.Context) error) error { + if m == nil { + log.Errorf(`modules: cannot start microtask "%s" with nil module`, *name) + return errNoModule + } + + atomic.AddInt32(microTasks, 1) // increase global counter here, as high priority tasks are not started by the scheduler, where this counter is usually increased return m.runMicroTask(name, fn) } -// StartMediumPriorityMicroTask starts a new MicroTask with medium priority. It will wait until given a go (max 3 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied. -func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Context) error) error { +// RunMediumPriorityMicroTask runs a new MicroTask with medium priority. It will wait until a slot becomes available (max 3 seconds). The call blocks until finished. The given function will be executed and panics caught. The supplied name must not be changed. +func (m *Module) RunMediumPriorityMicroTask(name *string, fn func(context.Context) error) error { + if m == nil { + log.Errorf(`modules: cannot start microtask "%s" with nil module`, *name) + return errNoModule + } + // check if we can go immediately select { case <-mediumPriorityClearance: @@ -66,8 +106,13 @@ func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Cont return m.runMicroTask(name, fn) } -// StartLowPriorityMicroTask starts a new MicroTask with low priority. It will wait until given a go (max 15 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied. -func (m *Module) StartLowPriorityMicroTask(name *string, fn func(context.Context) error) error { +// RunLowPriorityMicroTask runs a new MicroTask with low priority. It will wait until a slot becomes available (max 15 seconds). The call blocks until finished. The given function will be executed and panics caught. The supplied name must not be changed. +func (m *Module) RunLowPriorityMicroTask(name *string, fn func(context.Context) error) error { + if m == nil { + log.Errorf(`modules: cannot start microtask "%s" with nil module`, *name) + return errNoModule + } + // check if we can go immediately select { case <-lowPriorityClearance: @@ -94,7 +139,7 @@ func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err if panicVal != nil { me := m.NewPanicError(*name, "microtask", panicVal) me.Report() - log.Errorf("%s: microtask %s panicked: %s\n%s", m.Name, *name, panicVal, me.StackTrace) + log.Errorf("%s: microtask %s panicked: %s", m.Name, *name, panicVal) err = me } @@ -150,7 +195,10 @@ microTaskManageLoop: atomic.AddInt32(microTasks, 1) } else { // wait for signal that a task was completed - <-microTaskFinished + select { + case <-microTaskFinished: + case <-time.After(1 * time.Second): + } } } diff --git a/modules/microtasks_test.go b/modules/microtasks_test.go index 1e21e03..7cde095 100644 --- a/modules/microtasks_test.go +++ b/modules/microtasks_test.go @@ -42,7 +42,7 @@ func TestMicroTaskWaiting(t *testing.T) { go func() { defer mtwWaitGroup.Done() // exec at slot 1 - _ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error { + _ = mtModule.RunMicroTask(&mtTestName, func(ctx context.Context) error { mtwOutputChannel <- "1" // slot 1 time.Sleep(mtwSleepDuration * 5) mtwOutputChannel <- "2" // slot 5 @@ -53,7 +53,7 @@ func TestMicroTaskWaiting(t *testing.T) { time.Sleep(mtwSleepDuration * 1) // clear clearances - _ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error { + _ = mtModule.RunMicroTask(&mtTestName, func(ctx context.Context) error { return nil }) @@ -61,7 +61,7 @@ func TestMicroTaskWaiting(t *testing.T) { go func() { defer mtwWaitGroup.Done() // exec at slot 2 - _ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error { + _ = mtModule.RunLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error { mtwOutputChannel <- "7" // slot 16 return nil }) @@ -74,7 +74,7 @@ func TestMicroTaskWaiting(t *testing.T) { defer mtwWaitGroup.Done() time.Sleep(mtwSleepDuration * 8) // exec at slot 10 - _ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error { + _ = mtModule.RunMicroTask(&mtTestName, func(ctx context.Context) error { mtwOutputChannel <- "4" // slot 10 time.Sleep(mtwSleepDuration * 5) mtwOutputChannel <- "6" // slot 15 @@ -86,7 +86,7 @@ func TestMicroTaskWaiting(t *testing.T) { go func() { defer mtwWaitGroup.Done() // exec at slot 3 - _ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error { + _ = mtModule.RunMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error { mtwOutputChannel <- "3" // slot 6 time.Sleep(mtwSleepDuration * 7) mtwOutputChannel <- "5" // slot 13 @@ -122,7 +122,7 @@ var mtoWaitCh chan struct{} func mediumPrioTaskTester() { defer mtoWaitGroup.Done() <-mtoWaitCh - _ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error { + _ = mtModule.RunMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error { mtoOutputChannel <- "1" time.Sleep(2 * time.Millisecond) return nil @@ -132,7 +132,7 @@ func mediumPrioTaskTester() { func lowPrioTaskTester() { defer mtoWaitGroup.Done() <-mtoWaitCh - _ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error { + _ = mtModule.RunLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error { mtoOutputChannel <- "2" time.Sleep(2 * time.Millisecond) return nil diff --git a/modules/modules.go b/modules/modules.go index d58e315..66c3bd4 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -13,7 +13,7 @@ import ( ) var ( - modulesLock sync.Mutex + modulesLock sync.RWMutex modules = make(map[string]*Module) // ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag. @@ -46,6 +46,10 @@ type Module struct { microTaskCnt *int32 waitGroup sync.WaitGroup + // events + eventHooks map[string][]*eventHook + eventHooksLock sync.RWMutex + // dependency mgmt depNames []string depModules []*Module @@ -67,15 +71,25 @@ func (m *Module) shutdown() error { m.shutdownFlag.Set() m.cancelCtx() + // start shutdown function + m.waitGroup.Add(1) + stopFnError := make(chan error, 1) + go func() { + stopFnError <- m.runCtrlFn("stop module", m.stop) + m.waitGroup.Done() + }() + // wait for workers done := make(chan struct{}) go func() { m.waitGroup.Wait() close(done) }() + + // wait for results select { case <-done: - case <-time.After(3 * time.Second): + case <-time.After(30 * time.Second): log.Warningf( "%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...", m.Name, @@ -85,12 +99,17 @@ func (m *Module) shutdown() error { ) } - // call shutdown function - return m.stop() -} - -func dummyAction() error { - return nil + // collect error + select { + case err := <-stopFnError: + return err + default: + log.Warningf( + "%s: timed out while waiting for stop function to finish, continuing shutdown...", + m.Name, + ) + return nil + } } // Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished. @@ -99,7 +118,14 @@ func Register(name string, prep, start, stop func() error, dependencies ...strin modulesLock.Lock() defer modulesLock.Unlock() + // check for already existing module + _, ok := modules[name] + if ok { + panic(fmt.Sprintf("modules: module %s is already registered", name)) + } + // add new module modules[name] = newModule + return newModule } @@ -125,20 +151,10 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ... prep: prep, start: start, stop: stop, + eventHooks: make(map[string][]*eventHook), depNames: dependencies, } - // replace nil arguments with dummy action - if newModule.prep == nil { - newModule.prep = dummyAction - } - if newModule.start == nil { - newModule.start = dummyAction - } - if newModule.stop == nil { - newModule.stop = dummyAction - } - return newModule } diff --git a/modules/start.go b/modules/start.go index bfbe36d..6c945f3 100644 --- a/modules/start.go +++ b/modules/start.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "runtime" + "time" "github.com/safing/portbase/log" "github.com/tevino/abool" @@ -26,8 +27,8 @@ func WaitForStartCompletion() <-chan struct{} { // Start starts all modules in the correct order. In case of an error, it will automatically shutdown again. func Start() error { - modulesLock.Lock() - defer modulesLock.Unlock() + modulesLock.RLock() + defer modulesLock.RUnlock() // start microtask scheduler go microTaskScheduler() @@ -106,7 +107,11 @@ func prepareModules() error { go func() { reports <- &report{ module: execM, - err: execM.prep(), + err: execM.runCtrlFnWithTimeout( + "prep module", + 10*time.Second, + execM.prep, + ), } }() } @@ -154,7 +159,11 @@ func startModules() error { go func() { reports <- &report{ module: execM, - err: execM.start(), + err: execM.runCtrlFnWithTimeout( + "start module", + 60*time.Second, + execM.start, + ), } }() } diff --git a/modules/stop.go b/modules/stop.go index ad1a79e..17e8b7b 100644 --- a/modules/stop.go +++ b/modules/stop.go @@ -12,6 +12,8 @@ import ( var ( shutdownSignal = make(chan struct{}) shutdownSignalClosed = abool.NewBool(false) + + shutdownCompleteSignal = make(chan struct{}) ) // ShuttingDown returns a channel read on the global shutdown signal. @@ -45,6 +47,7 @@ func Shutdown() error { } log.Shutdown() + close(shutdownCompleteSignal) return err } diff --git a/modules/tasks.go b/modules/tasks.go index 7e695d8..2af4fb2 100644 --- a/modules/tasks.go +++ b/modules/tasks.go @@ -3,6 +3,7 @@ package modules import ( "container/list" "context" + "fmt" "sync" "sync/atomic" "time" @@ -59,6 +60,10 @@ const ( // NewTask creates a new task with a descriptive name (non-unique), a optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed. func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task { + if m == nil { + log.Errorf(`modules: cannot create task "%s" with nil module`, name) + } + return &Task{ name: name, module: m, @@ -68,6 +73,10 @@ func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task { } func (t *Task) isActive() bool { + if t.module == nil { + return false + } + return !t.canceled && !t.module.ShutdownInProgress() } @@ -178,6 +187,7 @@ func (t *Task) Repeat(interval time.Duration) *Task { t.lock.Lock() t.repeat = interval t.executeAt = time.Now().Add(t.repeat) + t.addToSchedule() t.lock.Unlock() return t @@ -194,6 +204,10 @@ func (t *Task) Cancel() { } func (t *Task) runWithLocking() { + if t.module == nil { + return + } + // wait for good timeslot regarding microtasks select { case <-taskTimeslot: @@ -308,6 +322,7 @@ func (t *Task) getExecuteAtWithLocking() time.Time { func (t *Task) addToSchedule() { scheduleLock.Lock() defer scheduleLock.Unlock() + // defer printTaskList(taskSchedule) // for debugging // notify scheduler defer func() { @@ -439,3 +454,23 @@ func taskScheduleHandler() { } } } + +func printTaskList(*list.List) { //nolint:unused,deadcode // for debugging, NOT production use + fmt.Println("Modules Task List:") + for e := taskSchedule.Front(); e != nil; e = e.Next() { + t, ok := e.Value.(*Task) + if ok { + fmt.Printf( + "%s:%s qu=%v ca=%v exec=%v at=%s rep=%s delay=%s\n", + t.module.Name, + t.name, + t.queued, + t.canceled, + t.executing, + t.executeAt, + t.repeat, + t.maxDelay, + ) + } + } +} diff --git a/modules/worker.go b/modules/worker.go index c385c6d..d0481f9 100644 --- a/modules/worker.go +++ b/modules/worker.go @@ -2,6 +2,8 @@ package modules import ( "context" + "errors" + "fmt" "sync/atomic" "time" @@ -13,8 +15,29 @@ const ( DefaultBackoffDuration = 2 * time.Second ) +var ( + // ErrRestartNow may be returned (wrapped) by service workers to request an immediate restart. + ErrRestartNow = errors.New("requested restart") + errNoModule = errors.New("missing module (is nil!)") +) + +// StartWorker directly starts a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. A call to StartWorker starts a new goroutine and returns immediately. +func (m *Module) StartWorker(name string, fn func(context.Context) error) { + go func() { + err := m.RunWorker(name, fn) + if err != nil { + log.Warningf("%s: worker %s failed: %s", m.Name, name, err) + } + }() +} + // RunWorker directly runs a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. A call to RunWorker blocks until the worker is finished. func (m *Module) RunWorker(name string, fn func(context.Context) error) error { + if m == nil { + log.Errorf(`modules: cannot start worker "%s" with nil module`, name) + return errNoModule + } + atomic.AddInt32(m.workerCnt, 1) m.waitGroup.Add(1) defer func() { @@ -27,6 +50,11 @@ func (m *Module) RunWorker(name string, fn func(context.Context) error) error { // StartServiceWorker starts a generic worker, which is automatically restarted in case of an error. A call to StartServiceWorker runs the service-worker in a new goroutine and returns immediately. `backoffDuration` specifies how to long to wait before restarts, multiplied by the number of failed attempts. Pass `0` for the default backoff duration. For custom error remediation functionality, build your own error handling procedure using calls to RunWorker. func (m *Module) StartServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) { + if m == nil { + log.Errorf(`modules: cannot start service worker "%s" with nil module`, name) + return + } + go m.runServiceWorker(name, backoffDuration, fn) } @@ -42,6 +70,7 @@ func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn backoffDuration = DefaultBackoffDuration } failCnt := 0 + lastFail := time.Now() for { if m.ShutdownInProgress() { @@ -50,11 +79,22 @@ func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn err := m.runWorker(name, fn) if err != nil { - // log error and restart - failCnt++ - sleepFor := time.Duration(failCnt) * backoffDuration - log.Errorf("%s: service-worker %s failed (%d): %s - restarting in %s", m.Name, name, failCnt, err, sleepFor) - time.Sleep(sleepFor) + if !errors.Is(err, ErrRestartNow) { + // reset fail counter if running without error for some time + if time.Now().Add(-5 * time.Minute).After(lastFail) { + failCnt = 0 + } + // increase fail counter and set last failed time + failCnt++ + lastFail = time.Now() + // log error + sleepFor := time.Duration(failCnt) * backoffDuration + log.Errorf("%s: service-worker %s failed (%d): %s - restarting in %s", m.Name, name, failCnt, err, sleepFor) + time.Sleep(sleepFor) + // loop to restart + } else { + log.Infof("%s: service-worker %s %s - restarting now", m.Name, name, err) + } } else { // finish return @@ -77,3 +117,39 @@ func (m *Module) runWorker(name string, fn func(context.Context) error) (err err err = fn(m.Ctx) return } + +func (m *Module) runCtrlFnWithTimeout(name string, timeout time.Duration, fn func() error) error { + + stopFnError := make(chan error) + go func() { + stopFnError <- m.runCtrlFn(name, fn) + }() + + // wait for results + select { + case err := <-stopFnError: + return err + case <-time.After(timeout): + return fmt.Errorf("timed out (%s)", timeout) + } +} + +func (m *Module) runCtrlFn(name string, fn func() error) (err error) { + if fn == nil { + return + } + + defer func() { + // recover from panic + panicVal := recover() + if panicVal != nil { + me := m.NewPanicError(name, "module-control", panicVal) + me.Report() + err = me + } + }() + + // run + err = fn() + return +} diff --git a/notifications/module.go b/notifications/module.go index c0ce406..39985ae 100644 --- a/notifications/module.go +++ b/notifications/module.go @@ -11,7 +11,7 @@ var ( ) func init() { - modules.Register("notifications", nil, start, nil, "base", "database") + module = modules.Register("notifications", nil, start, nil, "base", "database") } func start() error { diff --git a/rng/test/main.go b/rng/test/main.go index 601ac95..a45e222 100644 --- a/rng/test/main.go +++ b/rng/test/main.go @@ -38,6 +38,7 @@ func noise() { } +//nolint:gocognit func main() { // generates 1MB and writes to stdout diff --git a/run/main.go b/run/main.go new file mode 100644 index 0000000..c5d3917 --- /dev/null +++ b/run/main.go @@ -0,0 +1,126 @@ +package run + +import ( + "bufio" + "flag" + "fmt" + "os" + "os/signal" + "runtime/pprof" + "syscall" + "time" + + "github.com/safing/portbase/log" + "github.com/safing/portbase/modules" +) + +var ( + printStackOnExit bool + enableInputSignals bool + + sigUSR1 = syscall.Signal(0xa) // dummy for windows +) + +func init() { + flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down") + flag.BoolVar(&enableInputSignals, "input-signals", false, "emulate signals using stdin") +} + +// Run execute a full program lifecycle (including signal handling) based on modules. Just empty-import required packages and do os.Exit(run.Run()). +func Run() int { + + // Start + err := modules.Start() + if err != nil { + if err == modules.ErrCleanExit { + return 0 + } + + _ = modules.Shutdown() + return modules.GetExitStatusCode() + } + + // Shutdown + // catch interrupt for clean shutdown + signalCh := make(chan os.Signal) + if enableInputSignals { + go inputSignals(signalCh) + } + signal.Notify( + signalCh, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + sigUSR1, + ) + +signalLoop: + for { + select { + case sig := <-signalCh: + // only print and continue to wait if SIGUSR1 + if sig == sigUSR1 { + _ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 2) + continue signalLoop + } + + fmt.Println(" ") + log.Warning("main: program was interrupted, shutting down.") + + // catch signals during shutdown + go func() { + for { + <-signalCh + fmt.Println(" again, but already shutting down") + } + }() + + if printStackOnExit { + fmt.Println("=== PRINTING TRACES ===") + fmt.Println("=== GOROUTINES ===") + _ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 2) + fmt.Println("=== BLOCKING ===") + _ = pprof.Lookup("block").WriteTo(os.Stdout, 2) + fmt.Println("=== MUTEXES ===") + _ = pprof.Lookup("mutex").WriteTo(os.Stdout, 2) + fmt.Println("=== END TRACES ===") + } + + go func() { + time.Sleep(60 * time.Second) + fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====") + _ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 2) + os.Exit(1) + }() + + _ = modules.Shutdown() + break signalLoop + + case <-modules.ShuttingDown(): + break signalLoop + } + } + + // wait for shutdown to complete, then exit + return modules.GetExitStatusCode() +} + +func inputSignals(signalCh chan os.Signal) { + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + switch scanner.Text() { + case "SIGHUP": + signalCh <- syscall.SIGHUP + case "SIGINT": + signalCh <- syscall.SIGINT + case "SIGQUIT": + signalCh <- syscall.SIGQUIT + case "SIGTERM": + signalCh <- syscall.SIGTERM + case "SIGUSR1": + signalCh <- sigUSR1 + } + } +} diff --git a/updater/doc.go b/updater/doc.go new file mode 100644 index 0000000..829a5bd --- /dev/null +++ b/updater/doc.go @@ -0,0 +1,2 @@ +// Package updater is an update registry that manages updates and versions. +package updater diff --git a/updater/export.go b/updater/export.go new file mode 100644 index 0000000..af5ef1d --- /dev/null +++ b/updater/export.go @@ -0,0 +1,15 @@ +package updater + +// Export exports the list of resources. All resources must be locked when accessed. +func (reg *ResourceRegistry) Export() map[string]*Resource { + reg.RLock() + defer reg.RUnlock() + + // copy the map + new := make(map[string]*Resource) + for key, val := range reg.resources { + new[key] = val + } + + return new +} diff --git a/updater/fetch.go b/updater/fetch.go new file mode 100644 index 0000000..e7e1373 --- /dev/null +++ b/updater/fetch.go @@ -0,0 +1,120 @@ +package updater + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "time" + + "github.com/google/renameio" + + "github.com/safing/portbase/log" +) + +func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error { + // backoff when retrying + if tries > 0 { + time.Sleep(time.Duration(tries*tries) * time.Second) + } + + // create URL + downloadURL, err := joinURLandPath(reg.UpdateURLs[tries%len(reg.UpdateURLs)], rv.versionedPath()) + if err != nil { + return fmt.Errorf("error build url (%s + %s): %s", reg.UpdateURLs[tries%len(reg.UpdateURLs)], rv.versionedPath(), err) + } + + // check destination dir + dirPath := filepath.Dir(rv.storagePath()) + + err = reg.storageDir.EnsureAbsPath(dirPath) + if err != nil { + return fmt.Errorf("could not create updates folder: %s", dirPath) + } + + // open file for writing + atomicFile, err := renameio.TempFile(reg.tmpDir.Path, rv.storagePath()) + if err != nil { + return fmt.Errorf("could not create temp file for download: %s", err) + } + defer atomicFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway + + // start file download + resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose + if err != nil { + return fmt.Errorf("error fetching url (%s): %s", downloadURL, err) + } + defer resp.Body.Close() + + // download and write file + n, err := io.Copy(atomicFile, resp.Body) + if err != nil { + return fmt.Errorf("failed downloading %s: %s", downloadURL, err) + } + if resp.ContentLength != n { + return fmt.Errorf("download unfinished, written %d out of %d bytes", n, resp.ContentLength) + } + + // finalize file + err = atomicFile.CloseAtomicallyReplace() + if err != nil { + return fmt.Errorf("%s: failed to finalize file %s: %s", reg.Name, rv.storagePath(), err) + } + // set permissions + if !onWindows { + // TODO: only set executable files to 0755, set other to 0644 + err = os.Chmod(rv.storagePath(), 0755) + if err != nil { + log.Warningf("%s: failed to set permissions on downloaded file %s: %s", reg.Name, rv.storagePath(), err) + } + } + + log.Infof("%s: fetched %s (stored to %s)", reg.Name, downloadURL, rv.storagePath()) + return nil +} + +func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte, error) { + // backoff when retrying + if tries > 0 { + time.Sleep(time.Duration(tries*tries) * time.Second) + } + + // create URL + downloadURL, err := joinURLandPath(reg.UpdateURLs[tries%len(reg.UpdateURLs)], downloadPath) + if err != nil { + return nil, fmt.Errorf("error build url (%s + %s): %s", reg.UpdateURLs[tries%len(reg.UpdateURLs)], downloadPath, err) + } + + // start file download + resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose + if err != nil { + return nil, fmt.Errorf("error fetching url (%s): %s", downloadURL, err) + } + defer resp.Body.Close() + + // download and write file + buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength)) + n, err := io.Copy(buf, resp.Body) + if err != nil { + return nil, fmt.Errorf("failed downloading %s: %s", downloadURL, err) + } + if resp.ContentLength != n { + return nil, fmt.Errorf("download unfinished, written %d out of %d bytes", n, resp.ContentLength) + } + + return buf.Bytes(), nil +} + +func joinURLandPath(baseURL, urlPath string) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", err + } + + u.Path = path.Join(u.Path, urlPath) + return u.String(), nil +} diff --git a/updater/file.go b/updater/file.go new file mode 100644 index 0000000..14addd7 --- /dev/null +++ b/updater/file.go @@ -0,0 +1,41 @@ +package updater + +// File represents a file from the update system. +type File struct { + resource *Resource + version *ResourceVersion + notifier *notifier + versionedPath string + storagePath string +} + +// Identifier returns the identifier of the file. +func (file *File) Identifier() string { + return file.resource.Identifier +} + +// Version returns the version of the file. +func (file *File) Version() string { + return file.version.VersionNumber +} + +// Path returns the absolute filepath of the file. +func (file *File) Path() string { + return file.storagePath +} + +// Blacklist notifies the update system that this file is somehow broken, and should be ignored from now on, until restarted. +func (file *File) Blacklist() error { + return file.resource.Blacklist(file.version.VersionNumber) +} + +// used marks the file as active +func (file *File) markActiveWithLocking() { + file.resource.Lock() + defer file.resource.Unlock() + + // update last used version + if file.resource.ActiveVersion != file.version { + file.resource.ActiveVersion = file.version + } +} diff --git a/updater/filename.go b/updater/filename.go new file mode 100644 index 0000000..8767b4d --- /dev/null +++ b/updater/filename.go @@ -0,0 +1,46 @@ +package updater + +import ( + "fmt" + "regexp" + "strings" +) + +var ( + fileVersionRegex = regexp.MustCompile(`_v([0-9]+-[0-9]+-[0-9]+b?|0)`) + rawVersionRegex = regexp.MustCompile(`^([0-9]+\.[0-9]+\.[0-9]+b?\*?|0)$`) +) + +// GetIdentifierAndVersion splits the given file path into its identifier and version. +func GetIdentifierAndVersion(versionedPath string) (identifier, version string, ok bool) { + // extract version + rawVersion := fileVersionRegex.FindString(versionedPath) + if rawVersion == "" { + return "", "", false + } + + // replace - with . and trim _ + version = strings.Replace(strings.TrimLeft(rawVersion, "_v"), "-", ".", -1) + + // put together without version + i := strings.Index(versionedPath, rawVersion) + if i < 0 { + // extracted version not in string (impossible) + return "", "", false + } + return versionedPath[:i] + versionedPath[i+len(rawVersion):], version, true +} + +// GetVersionedPath combines the identifier and version and returns it as a file path. +func GetVersionedPath(identifier, version string) (versionedPath string) { + // split in half + splittedFilePath := strings.SplitN(identifier, ".", 2) + // replace . with - + transformedVersion := strings.Replace(version, ".", "-", -1) + + // put together + if len(splittedFilePath) == 1 { + return fmt.Sprintf("%s_v%s", splittedFilePath[0], transformedVersion) + } + return fmt.Sprintf("%s_v%s.%s", splittedFilePath[0], transformedVersion, splittedFilePath[1]) +} diff --git a/updater/filename_test.go b/updater/filename_test.go new file mode 100644 index 0000000..c3dee13 --- /dev/null +++ b/updater/filename_test.go @@ -0,0 +1,53 @@ +package updater + +import ( + "regexp" + "testing" +) + +func testRegexMatch(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) { + if testRegex.MatchString(testString) != shouldMatch { + if shouldMatch { + t.Errorf("regex %s should match %s", testRegex, testString) + } else { + t.Errorf("regex %s should not match %s", testRegex, testString) + } + } +} + +func testRegexFind(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) { + if (testRegex.FindString(testString) != "") != shouldMatch { + if shouldMatch { + t.Errorf("regex %s should find %s", testRegex, testString) + } else { + t.Errorf("regex %s should not find %s", testRegex, testString) + } + } +} + +func TestRegexes(t *testing.T) { + testRegexMatch(t, rawVersionRegex, "0", true) + testRegexMatch(t, rawVersionRegex, "0.1.2", true) + testRegexMatch(t, rawVersionRegex, "0.1.2*", true) + testRegexMatch(t, rawVersionRegex, "0.1.2b", true) + testRegexMatch(t, rawVersionRegex, "0.1.2b*", true) + testRegexMatch(t, rawVersionRegex, "12.13.14", true) + + testRegexMatch(t, rawVersionRegex, "v0.1.2", false) + testRegexMatch(t, rawVersionRegex, "0.", false) + testRegexMatch(t, rawVersionRegex, "0.1", false) + testRegexMatch(t, rawVersionRegex, "0.1.", false) + testRegexMatch(t, rawVersionRegex, ".1.2", false) + testRegexMatch(t, rawVersionRegex, ".1.", false) + testRegexMatch(t, rawVersionRegex, "012345", false) + + testRegexFind(t, fileVersionRegex, "/path/to/file_v0", true) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3", true) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3.exe", true) + + testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1.2.3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_1-2-3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2", false) + testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false) +} diff --git a/updater/get.go b/updater/get.go new file mode 100644 index 0000000..a36df3b --- /dev/null +++ b/updater/get.go @@ -0,0 +1,56 @@ +package updater + +import ( + "errors" + "fmt" + + "github.com/safing/portbase/log" +) + +// Errors +var ( + ErrNotFound = errors.New("the requested file could not be found") + ErrNotAvailableLocally = errors.New("the requested file is not available locally") +) + +// GetFile returns the selected (mostly newest) file with the given identifier or an error, if it fails. +func (reg *ResourceRegistry) GetFile(identifier string) (*File, error) { + reg.RLock() + res, ok := reg.resources[identifier] + reg.RUnlock() + if !ok { + return nil, ErrNotFound + } + + file := res.GetFile() + // check if file is available locally + if file.version.Available { + file.markActiveWithLocking() + return file, nil + } + + // check if online + if !reg.Online { + return nil, ErrNotAvailableLocally + } + + // check download dir + err := reg.tmpDir.Ensure() + if err != nil { + return nil, fmt.Errorf("could not prepare tmp directory for download: %s", err) + } + + // download file + log.Tracef("%s: starting download of %s", reg.Name, file.versionedPath) + for tries := 0; tries < 5; tries++ { + err = reg.fetchFile(file.version, tries) + if err != nil { + log.Tracef("%s: failed to download %s: %s, retrying (%d)", reg.Name, file.versionedPath, err, tries+1) + } else { + file.markActiveWithLocking() + return file, nil + } + } + log.Warningf("%s: failed to download %s: %s", reg.Name, file.versionedPath, err) + return nil, err +} diff --git a/updater/notifier.go b/updater/notifier.go new file mode 100644 index 0000000..66b2832 --- /dev/null +++ b/updater/notifier.go @@ -0,0 +1,33 @@ +package updater + +import ( + "github.com/tevino/abool" +) + +type notifier struct { + upgradeAvailable *abool.AtomicBool + notifyChannel chan struct{} +} + +func newNotifier() *notifier { + return ¬ifier{ + upgradeAvailable: abool.NewBool(false), + notifyChannel: make(chan struct{}), + } +} + +func (n *notifier) markAsUpgradeable() { + if n.upgradeAvailable.SetToIf(false, true) { + close(n.notifyChannel) + } +} + +// UpgradeAvailable returns whether an upgrade is available for this file. +func (file *File) UpgradeAvailable() bool { + return file.notifier.upgradeAvailable.IsSet() +} + +// WaitForAvailableUpgrade blocks (selectable) until an upgrade for this file is available. +func (file *File) WaitForAvailableUpgrade() <-chan struct{} { + return file.notifier.notifyChannel +} diff --git a/updater/registry.go b/updater/registry.go new file mode 100644 index 0000000..d593b9f --- /dev/null +++ b/updater/registry.go @@ -0,0 +1,161 @@ +package updater + +import ( + "os" + "runtime" + "sync" + + "github.com/tevino/abool" + + "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" +) + +const ( + onWindows = runtime.GOOS == "windows" +) + +// ResourceRegistry is a registry for managing update resources. +type ResourceRegistry struct { + sync.RWMutex + + Name string + storageDir *utils.DirStructure + tmpDir *utils.DirStructure + + resources map[string]*Resource + UpdateURLs []string + MandatoryUpdates []string + + Beta bool + DevMode bool + Online bool + + notifyHooks []func() + notifyHooksEnabled *abool.AtomicBool +} + +// Initialize initializes a raw registry struct and makes it ready for usage. +func (reg *ResourceRegistry) Initialize(storageDir *utils.DirStructure) error { + // check if storage dir is available + err := storageDir.Ensure() + if err != nil { + return err + } + + // set default name + if reg.Name == "" { + reg.Name = "updater" + } + + // initialize private attributes + reg.storageDir = storageDir + reg.tmpDir = storageDir.ChildDir("tmp", 0700) + reg.resources = make(map[string]*Resource) + reg.notifyHooksEnabled = abool.NewBool(true) + + return nil +} + +// StorageDir returns the main storage dir of the resource registry. +func (reg *ResourceRegistry) StorageDir() *utils.DirStructure { + return reg.storageDir +} + +// TmpDir returns the temporary working dir of the resource registry. +func (reg *ResourceRegistry) TmpDir() *utils.DirStructure { + return reg.tmpDir +} + +// SetDevMode sets the development mode flag. +func (reg *ResourceRegistry) SetDevMode(on bool) { + reg.Lock() + defer reg.Unlock() + + reg.DevMode = on +} + +// SetBeta sets the beta flag. +func (reg *ResourceRegistry) SetBeta(on bool) { + reg.Lock() + defer reg.Unlock() + + reg.Beta = on +} + +// AddResource adds a resource to the registry. Does _not_ select new version. +func (reg *ResourceRegistry) AddResource(identifier, version string, available, stableRelease, betaRelease bool) error { + reg.Lock() + defer reg.Unlock() + + err := reg.addResource(identifier, version, available, stableRelease, betaRelease) + return err +} + +func (reg *ResourceRegistry) addResource(identifier, version string, available, stableRelease, betaRelease bool) error { + res, ok := reg.resources[identifier] + if !ok { + res = reg.newResource(identifier) + reg.resources[identifier] = res + } + return res.AddVersion(version, available, stableRelease, betaRelease) +} + +// AddResources adds resources to the registry. Errors are logged, the last one is returned. Despite errors, non-failing resources are still added. Does _not_ select new versions. +func (reg *ResourceRegistry) AddResources(versions map[string]string, available, stableRelease, betaRelease bool) error { + reg.Lock() + defer reg.Unlock() + + // add versions and their flags to registry + var lastError error + for identifier, version := range versions { + lastError = reg.addResource(identifier, version, available, stableRelease, betaRelease) + if lastError != nil { + log.Warningf("%s: failed to add resource %s: %s", reg.Name, identifier, lastError) + } + } + + return lastError +} + +// SelectVersions selects new resource versions depending on the current registry state. +func (reg *ResourceRegistry) SelectVersions() { + reg.RLock() + defer reg.RUnlock() + + for _, res := range reg.resources { + res.Lock() + res.selectVersion() + res.Unlock() + } +} + +// GetSelectedVersions returns a list of the currently selected versions. +func (reg *ResourceRegistry) GetSelectedVersions() (versions map[string]string) { + reg.RLock() + defer reg.RUnlock() + + for _, res := range reg.resources { + res.Lock() + versions[res.Identifier] = res.SelectedVersion.VersionNumber + res.Unlock() + } + + return +} + +// Purge deletes old updates, retaining a certain amount, specified by the keep parameter. Will at least keep 2 updates per resource. +func (reg *ResourceRegistry) Purge(keep int) { + reg.RLock() + defer reg.RUnlock() + + for _, res := range reg.resources { + res.Purge(keep) + } +} + +// Cleanup removes temporary files. +func (reg *ResourceRegistry) Cleanup() error { + // delete download tmp dir + return os.RemoveAll(reg.tmpDir.Path) +} diff --git a/updater/registry_test.go b/updater/registry_test.go new file mode 100644 index 0000000..ca95214 --- /dev/null +++ b/updater/registry_test.go @@ -0,0 +1,38 @@ +package updater + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/safing/portbase/utils" +) + +var ( + registry *ResourceRegistry +) + +func TestMain(m *testing.M) { + // setup + tmpDir, err := ioutil.TempDir("", "ci-portmaster-") + if err != nil { + panic(err) + } + registry = &ResourceRegistry{ + Beta: true, + DevMode: true, + Online: true, + } + err = registry.Initialize(utils.NewDirStructure(tmpDir, 0777)) + if err != nil { + panic(err) + } + + // run + // call flag.Parse() here if TestMain uses flags + ret := m.Run() + + // teardown + os.RemoveAll(tmpDir) + os.Exit(ret) +} diff --git a/updater/resource.go b/updater/resource.go new file mode 100644 index 0000000..f131b2d --- /dev/null +++ b/updater/resource.go @@ -0,0 +1,330 @@ +package updater + +import ( + "errors" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/safing/portbase/log" + + semver "github.com/hashicorp/go-version" +) + +// Resource represents a resource (via an identifier) and multiple file versions. +type Resource struct { + sync.Mutex + registry *ResourceRegistry + notifier *notifier + + Identifier string + Versions []*ResourceVersion + + ActiveVersion *ResourceVersion + SelectedVersion *ResourceVersion + ForceDownload bool +} + +// ResourceVersion represents a single version of a resource. +type ResourceVersion struct { + resource *Resource + + VersionNumber string + semVer *semver.Version + Available bool + StableRelease bool + BetaRelease bool + Blacklisted bool +} + +// Len is the number of elements in the collection. (sort.Interface for Versions) +func (res *Resource) Len() int { + return len(res.Versions) +} + +// Less reports whether the element with index i should sort before the element with index j. (sort.Interface for Versions) +func (res *Resource) Less(i, j int) bool { + return res.Versions[i].semVer.GreaterThan(res.Versions[j].semVer) +} + +// Swap swaps the elements with indexes i and j. (sort.Interface for Versions) +func (res *Resource) Swap(i, j int) { + res.Versions[i], res.Versions[j] = res.Versions[j], res.Versions[i] +} + +// available returns whether any version of the resource is available. +func (res *Resource) available() bool { + for _, rv := range res.Versions { + if rv.Available { + return true + } + } + return false +} + +func (reg *ResourceRegistry) newResource(identifier string) *Resource { + return &Resource{ + registry: reg, + Identifier: identifier, + Versions: make([]*ResourceVersion, 0, 1), + } +} + +// AddVersion adds a resource version to a resource. +func (res *Resource) AddVersion(version string, available, stableRelease, betaRelease bool) error { + res.Lock() + defer res.Unlock() + + // reset stable or beta release flags + if stableRelease || betaRelease { + for _, rv := range res.Versions { + if stableRelease { + rv.StableRelease = false + } + if betaRelease { + rv.BetaRelease = false + } + } + } + + var rv *ResourceVersion + // check for existing version + for _, possibleMatch := range res.Versions { + if possibleMatch.VersionNumber == version { + rv = possibleMatch + break + } + } + + // create new version if none found + if rv == nil { + // parse to semver + sv, err := semver.NewVersion(version) + if err != nil { + return err + } + + rv = &ResourceVersion{ + resource: res, + VersionNumber: version, + semVer: sv, + } + res.Versions = append(res.Versions, rv) + } + + // set flags + if available { + rv.Available = true + } + if stableRelease { + rv.StableRelease = true + } + if betaRelease { + rv.BetaRelease = true + } + + return nil +} + +// GetFile returns the selected version as a *File. +func (res *Resource) GetFile() *File { + res.Lock() + defer res.Unlock() + + // check for notifier + if res.notifier == nil { + // create new notifier + res.notifier = newNotifier() + } + + // check if version is selected + if res.SelectedVersion == nil { + res.selectVersion() + } + + // create file + return &File{ + resource: res, + version: res.SelectedVersion, + notifier: res.notifier, + versionedPath: res.SelectedVersion.versionedPath(), + storagePath: res.SelectedVersion.storagePath(), + } +} + +//nolint:gocognit // function already kept as simlpe as possible +func (res *Resource) selectVersion() { + sort.Sort(res) + + // export after we finish + defer func() { + if res.ActiveVersion != nil && // resource has already been used + res.SelectedVersion != res.ActiveVersion && // new selected version does not match previously selected version + res.notifier != nil { + res.notifier.markAsUpgradeable() + res.notifier = nil + } + }() + + if len(res.Versions) == 0 { + // TODO: find better way to deal with an empty version slice (which should not happen) + res.SelectedVersion = nil + return + } + + // Target selection + // 1) Dev release if dev mode is active and ignore blacklisting + if res.registry.DevMode { + // get last element + rv := res.Versions[len(res.Versions)-1] + // check if it's a dev version + if rv.VersionNumber == "0" && rv.Available { + res.SelectedVersion = rv + return + } + } + + // 2) Beta release if beta is active + if res.registry.Beta { + for _, rv := range res.Versions { + if rv.BetaRelease { + if !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) { + res.SelectedVersion = rv + return + } + break + } + } + } + + // 3) Stable release + for _, rv := range res.Versions { + if rv.StableRelease { + if !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) { + res.SelectedVersion = rv + return + } + break + } + } + + // 4) Latest stable release + for _, rv := range res.Versions { + if !strings.HasSuffix(rv.VersionNumber, "b") && !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) { + res.SelectedVersion = rv + return + } + } + + // 5) Latest of any type + for _, rv := range res.Versions { + if !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) { + res.SelectedVersion = rv + return + } + } + + // 6) Default to newest + res.SelectedVersion = res.Versions[0] +} + +// Blacklist blacklists the specified version and selects a new version. +func (res *Resource) Blacklist(version string) error { + res.Lock() + defer res.Unlock() + + // count already blacklisted entries + valid := 0 + for _, rv := range res.Versions { + if rv.VersionNumber == "0" { + continue // ignore dev versions + } + if !rv.Blacklisted { + valid++ + } + } + if valid <= 1 { + return errors.New("cannot blacklist last version") // last one, cannot blacklist! + } + + // find version and blacklist + for _, rv := range res.Versions { + if rv.VersionNumber == version { + // blacklist and update + rv.Blacklisted = true + res.selectVersion() + return nil + } + } + + return errors.New("could not find version") +} + +// Purge deletes old updates, retaining a certain amount, specified by the keep parameter. Will at least keep 2 updates per resource. After purging, new versions will be selected. +func (res *Resource) Purge(keep int) { + res.Lock() + defer res.Unlock() + + // safeguard + if keep < 2 { + keep = 2 + } + + // keep versions + var validVersions int + var skippedActiveVersion bool + var skippedSelectedVersion bool + var purgeFrom int + for i, rv := range res.Versions { + // continue to purging? + if validVersions >= keep && // skip at least versions + skippedActiveVersion && // skip until active version + skippedSelectedVersion { // skip until selected version + purgeFrom = i + break + } + + // keep active version + if !skippedActiveVersion && rv == res.ActiveVersion { + skippedActiveVersion = true + } + + // keep selected version + if !skippedSelectedVersion && rv == res.SelectedVersion { + skippedSelectedVersion = true + } + + // count valid (not blacklisted) versions + if !rv.Blacklisted { + validVersions++ + } + } + + // check if there is anything to purge + if purgeFrom < keep || purgeFrom > len(res.Versions) { + return + } + + // purge phase + for _, rv := range res.Versions[purgeFrom:] { + // delete + err := os.Remove(rv.storagePath()) + if err != nil { + log.Warningf("%s: failed to purge old resource %s: %s", res.registry.Name, rv.storagePath(), err) + } + } + // remove entries of deleted files + res.Versions = res.Versions[purgeFrom:] + + res.selectVersion() +} + +func (rv *ResourceVersion) versionedPath() string { + return GetVersionedPath(rv.resource.Identifier, rv.VersionNumber) +} + +func (rv *ResourceVersion) storagePath() string { + return filepath.Join(rv.resource.registry.storageDir.Path, filepath.FromSlash(rv.versionedPath())) +} diff --git a/updater/resource_test.go b/updater/resource_test.go new file mode 100644 index 0000000..f6bc4a6 --- /dev/null +++ b/updater/resource_test.go @@ -0,0 +1,81 @@ +package updater + +import ( + "testing" +) + +func TestVersionSelection(t *testing.T) { + res := registry.newResource("test/a") + + err := res.AddVersion("1.2.3", true, true, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.4b", true, false, true) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.2", true, false, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.5", false, true, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("0", true, false, false) + if err != nil { + t.Fatal(err) + } + + registry.Online = true + registry.Beta = true + registry.DevMode = true + res.selectVersion() + if res.SelectedVersion.VersionNumber != "0" { + t.Errorf("selected version should be 0, not %s", res.SelectedVersion.VersionNumber) + } + + registry.DevMode = false + res.selectVersion() + if res.SelectedVersion.VersionNumber != "1.2.4b" { + t.Errorf("selected version should be 1.2.4b, not %s", res.SelectedVersion.VersionNumber) + } + + registry.Beta = false + res.selectVersion() + if res.SelectedVersion.VersionNumber != "1.2.5" { + t.Errorf("selected version should be 1.2.5, not %s", res.SelectedVersion.VersionNumber) + } + + registry.Online = false + res.selectVersion() + if res.SelectedVersion.VersionNumber != "1.2.3" { + t.Errorf("selected version should be 1.2.3, not %s", res.SelectedVersion.VersionNumber) + } + + f123 := res.GetFile() + f123.markActiveWithLocking() + + err = res.Blacklist("1.2.3") + if err != nil { + t.Fatal(err) + } + if res.SelectedVersion.VersionNumber != "1.2.2" { + t.Errorf("selected version should be 1.2.2, not %s", res.SelectedVersion.VersionNumber) + } + + if !f123.UpgradeAvailable() { + t.Error("upgrade should be available (flag)") + } + select { + case <-f123.WaitForAvailableUpgrade(): + default: + t.Error("upgrade should be available (chan)") + } + + t.Logf("resource: %+v", res) + for _, rv := range res.Versions { + t.Logf("version %s: %+v", rv.VersionNumber, rv) + } +} diff --git a/updater/storage.go b/updater/storage.go new file mode 100644 index 0000000..35a6e6b --- /dev/null +++ b/updater/storage.go @@ -0,0 +1,159 @@ +package updater + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "github.com/safing/portbase/log" + "github.com/safing/portbase/utils" +) + +// ScanStorage scans root within the storage dir and adds found resources to the registry. If an error occurred, it is logged and the last error is returned. Everything that was found despite errors is added to the registry anyway. Leave root empty to scan the full storage dir. +func (reg *ResourceRegistry) ScanStorage(root string) error { + var lastError error + + // prep root + if root == "" { + root = reg.storageDir.Path + } else { + var err error + root, err = filepath.Abs(root) + if err != nil { + return err + } + if !strings.HasPrefix(root, reg.storageDir.Path) { + return errors.New("supplied scan root path not within storage") + } + } + + // walk fs + _ = filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + lastError = fmt.Errorf("%s: could not read %s: %s", reg.Name, path, err) + log.Warning(lastError.Error()) + return nil + } + + // get relative path to storage + relativePath, err := filepath.Rel(reg.storageDir.Path, path) + if err != nil { + lastError = fmt.Errorf("%s: could not get relative path of %s: %s", reg.Name, path, err) + log.Warning(lastError.Error()) + return nil + } + // ignore files in tmp dir + if strings.HasPrefix(relativePath, reg.tmpDir.Path) { + return nil + } + + // convert to identifier and version + relativePath = filepath.ToSlash(relativePath) + identifier, version, ok := GetIdentifierAndVersion(relativePath) + if !ok { + // file does not conform to format + return nil + } + + // save + err = reg.AddResource(identifier, version, true, false, false) + if err != nil { + lastError = fmt.Errorf("%s: could not get add resource %s v%s: %s", reg.Name, identifier, version, err) + log.Warning(lastError.Error()) + } + return nil + }) + + return lastError +} + +// LoadIndexes loads the current release indexes from disk and will fetch a new version if not available and online. +func (reg *ResourceRegistry) LoadIndexes() error { + err := reg.loadIndexFile("stable.json", true, false) + if err != nil { + err = reg.downloadIndex("stable.json", true, false) + if err != nil { + return err + } + } + + err = reg.loadIndexFile("beta.json", false, true) + if err != nil { + err = reg.downloadIndex("beta.json", false, true) + if err != nil { + return err + } + } + + return nil +} + +func (reg *ResourceRegistry) loadIndexFile(name string, stableRelease, betaRelease bool) error { + data, err := ioutil.ReadFile(filepath.Join(reg.storageDir.Path, name)) + if err != nil { + return err + } + + releases := make(map[string]string) + err = json.Unmarshal(data, &releases) + if err != nil { + return err + } + + if len(releases) == 0 { + return fmt.Errorf("%s is empty", name) + } + + err = reg.AddResources(releases, false, stableRelease, betaRelease) + if err != nil { + log.Warningf("%s: failed to add resource: %s", reg.Name, err) + } + return nil +} + +// CreateSymlinks creates a directory structure with unversions symlinks to the given updates list. +func (reg *ResourceRegistry) CreateSymlinks(symlinkRoot *utils.DirStructure) error { + err := os.RemoveAll(symlinkRoot.Path) + if err != nil { + return fmt.Errorf("failed to wipe symlink root: %s", err) + } + + err = symlinkRoot.Ensure() + if err != nil { + return fmt.Errorf("failed to create symlink root: %s", err) + } + + reg.RLock() + defer reg.RUnlock() + + for _, res := range reg.resources { + if res.SelectedVersion == nil { + return fmt.Errorf("no selected version available for %s", res.Identifier) + } + + targetPath := res.SelectedVersion.storagePath() + linkPath := filepath.Join(symlinkRoot.Path, filepath.FromSlash(res.Identifier)) + linkPathDir := filepath.Dir(linkPath) + + err = symlinkRoot.EnsureAbsPath(linkPathDir) + if err != nil { + return fmt.Errorf("failed to create dir for link: %s", err) + } + + relativeTargetPath, err := filepath.Rel(linkPathDir, targetPath) + if err != nil { + return fmt.Errorf("failed to get relative target path: %s", err) + } + + err = os.Symlink(relativeTargetPath, linkPath) + if err != nil { + return fmt.Errorf("failed to link %s: %s", res.Identifier, err) + } + } + + return nil +} diff --git a/updater/storage_test.go b/updater/storage_test.go new file mode 100644 index 0000000..eafabe4 --- /dev/null +++ b/updater/storage_test.go @@ -0,0 +1,68 @@ +package updater + +/* +func testLoadLatestScope(t *testing.T, basePath, filePath, expectedIdentifier, expectedVersion string) { + fullPath := filepath.Join(basePath, filePath) + + // create dir + dirPath := filepath.Dir(fullPath) + err := os.MkdirAll(dirPath, 0755) + if err != nil { + t.Fatalf("could not create test dir: %s\n", err) + return + } + + // touch file + err = ioutil.WriteFile(fullPath, []byte{}, 0644) + if err != nil { + t.Fatalf("could not create test file: %s\n", err) + return + } + + // run loadLatestScope + latest, err := ScanForLatest(basePath, true) + if err != nil { + t.Errorf("could not update latest: %s\n", err) + return + } + for key, val := range latest { + localUpdates[key] = val + } + + // test result + version, ok := localUpdates[expectedIdentifier] + if !ok { + t.Errorf("identifier %s not in map", expectedIdentifier) + t.Errorf("current map: %v", localUpdates) + } + if version != expectedVersion { + t.Errorf("unexpected version for %s: %s", filePath, version) + } +} + +func TestLoadLatestScope(t *testing.T) { + + updatesLock.Lock() + defer updatesLock.Unlock() + + tmpDir, err := ioutil.TempDir("", "testing_") + if err != nil { + t.Fatalf("could not create test dir: %s\n", err) + return + } + defer os.RemoveAll(tmpDir) + + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "1.2.3") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4b.zip", "all/ui/assets.zip", "1.2.4b") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-5.zip", "all/ui/assets.zip", "1.2.5") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "1.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v2-3-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "1.2.3") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v2-1-1", "os_platform/portmaster/portmaster", "2.1.1") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "2.1.1") + +} +*/ diff --git a/updater/updating.go b/updater/updating.go new file mode 100644 index 0000000..81c8240 --- /dev/null +++ b/updater/updating.go @@ -0,0 +1,125 @@ +package updater + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + + "github.com/safing/portbase/utils" + + "github.com/safing/portbase/log" +) + +// UpdateIndexes downloads the current update indexes. +func (reg *ResourceRegistry) UpdateIndexes() error { + err := reg.downloadIndex("stable.json", true, false) + if err != nil { + return err + } + + return reg.downloadIndex("beta.json", false, true) +} + +func (reg *ResourceRegistry) downloadIndex(name string, stableRelease, betaRelease bool) error { + var err error + var data []byte + + // download new index + for tries := 0; tries < 3; tries++ { + data, err = reg.fetchData(name, tries) + if err == nil { + break + } + } + if err != nil { + return fmt.Errorf("failed to download index %s: %s", name, err) + } + + // parse + new := make(map[string]string) + err = json.Unmarshal(data, &new) + if err != nil { + return fmt.Errorf("failed to parse index %s: %s", name, err) + } + + // check for content + if len(new) == 0 { + return fmt.Errorf("index %s is empty", name) + } + + // add resources to registry + _ = reg.AddResources(new, false, stableRelease, betaRelease) + + // save index + err = ioutil.WriteFile(filepath.Join(reg.storageDir.Path, name), data, 0644) + if err != nil { + log.Warningf("%s: failed to save updated index %s: %s", reg.Name, name, err) + } + + log.Infof("%s: updated index %s", reg.Name, name) + return nil +} + +// DownloadUpdates checks if updates are available and downloads updates of used components. +func (reg *ResourceRegistry) DownloadUpdates(ctx context.Context) error { + // create list of downloads + var toUpdate []*ResourceVersion + reg.RLock() + for _, res := range reg.resources { + res.Lock() + + // check if we want to download + if res.ActiveVersion != nil || // resource is currently being used + res.available() || // resource was used in the past + utils.StringInSlice(reg.MandatoryUpdates, res.Identifier) { // resource is mandatory + + // add all non-available and eligible versions to update queue + for _, rv := range res.Versions { + if !rv.Available && (rv.StableRelease || reg.Beta && rv.BetaRelease) { + toUpdate = append(toUpdate, rv) + } + } + } + + res.Unlock() + } + reg.RUnlock() + + // nothing to update + if len(toUpdate) == 0 { + log.Infof("%s: everything up to date", reg.Name) + return nil + } + + // check download dir + err := reg.tmpDir.Ensure() + if err != nil { + return fmt.Errorf("could not prepare tmp directory for download: %s", err) + } + + // download updates + log.Infof("%s: starting to download %d updates", reg.Name, len(toUpdate)) + for _, rv := range toUpdate { + for tries := 0; tries < 3; tries++ { + err = reg.fetchFile(rv, tries) + if err == nil { + break + } + } + if err != nil { + log.Warningf("%s: failed to download %s version %s: %s", reg.Name, rv.resource.Identifier, rv.VersionNumber, err) + } + } + log.Infof("%s: finished downloading updates", reg.Name) + + // remove tmp folder after we are finished + err = os.RemoveAll(reg.tmpDir.Path) + if err != nil { + log.Tracef("%s: failed to remove tmp dir %s after downloading updates: %s", reg.Name, reg.tmpDir.Path, err) + } + + return nil +} diff --git a/updater/uptool/.gitignore b/updater/uptool/.gitignore new file mode 100644 index 0000000..c5074cf --- /dev/null +++ b/updater/uptool/.gitignore @@ -0,0 +1 @@ +uptool diff --git a/updater/uptool/root.go b/updater/uptool/root.go new file mode 100644 index 0000000..3595563 --- /dev/null +++ b/updater/uptool/root.go @@ -0,0 +1,37 @@ +package main + +import ( + "os" + "path/filepath" + + "github.com/safing/portbase/updater" + "github.com/safing/portbase/utils" + "github.com/spf13/cobra" +) + +var registry *updater.ResourceRegistry + +var rootCmd = &cobra.Command{ + Use: "uptool", + Short: "helper tool for the update process", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Usage() + }, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + absPath, err := filepath.Abs(args[0]) + if err != nil { + return err + } + + registry = &updater.ResourceRegistry{} + return registry.Initialize(utils.NewDirStructure(absPath, 0755)) + }, + SilenceUsage: true, +} + +func main() { + if err := rootCmd.Execute(); err != nil { + os.Exit(1) + } +} diff --git a/updater/uptool/scan.go b/updater/uptool/scan.go new file mode 100644 index 0000000..8ec1590 --- /dev/null +++ b/updater/uptool/scan.go @@ -0,0 +1,80 @@ +package main + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(scanCmd) +} + +var scanCmd = &cobra.Command{ + Use: "scan", + Short: "Scan the specified directory and print the result", + Args: cobra.ExactArgs(1), + RunE: scan, +} + +func scan(cmd *cobra.Command, args []string) error { + err := scanStorage() + if err != nil { + return err + } + + // export beta + data, err := json.MarshalIndent(exportSelected(true), "", " ") + if err != nil { + return err + } + // print + fmt.Println("beta:") + fmt.Println(string(data)) + + // export stable + data, err = json.MarshalIndent(exportSelected(false), "", " ") + if err != nil { + return err + } + // print + fmt.Println("\nstable:") + fmt.Println(string(data)) + + return nil +} + +func scanStorage() error { + files, err := ioutil.ReadDir(registry.StorageDir().Path) + if err != nil { + return err + } + + // scan "all" and all "os_platform" dirs + for _, file := range files { + if file.IsDir() && (file.Name() == "all" || strings.Contains(file.Name(), "_")) { + err := registry.ScanStorage(filepath.Join(registry.StorageDir().Path, file.Name())) + if err != nil { + return err + } + } + } + + return nil +} + +func exportSelected(beta bool) map[string]string { + registry.SetBeta(beta) + registry.SelectVersions() + export := registry.Export() + + versions := make(map[string]string) + for _, rv := range export { + versions[rv.Identifier] = rv.SelectedVersion.VersionNumber + } + return versions +} diff --git a/updater/uptool/update.go b/updater/uptool/update.go new file mode 100644 index 0000000..62ec159 --- /dev/null +++ b/updater/uptool/update.go @@ -0,0 +1,64 @@ +package main + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "path/filepath" + + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(updateCmd) +} + +var updateCmd = &cobra.Command{ + Use: "update", + Short: "Update scans the specified directory and registry the index and symlink structure", + Args: cobra.ExactArgs(1), + RunE: update, +} + +func update(cmd *cobra.Command, args []string) error { + err := scanStorage() + if err != nil { + return err + } + + // export beta + data, err := json.MarshalIndent(exportSelected(true), "", " ") + if err != nil { + return err + } + // print + fmt.Println("beta:") + fmt.Println(string(data)) + // write index + err = ioutil.WriteFile(filepath.Join(registry.StorageDir().Dir, "beta.json"), data, 0755) + if err != nil { + return err + } + + // export stable + data, err = json.MarshalIndent(exportSelected(false), "", " ") + if err != nil { + return err + } + // print + fmt.Println("\nstable:") + fmt.Println(string(data)) + // write index + err = ioutil.WriteFile(filepath.Join(registry.StorageDir().Dir, "stable.json"), data, 0755) + if err != nil { + return err + } + // create symlinks + err = registry.CreateSymlinks(registry.StorageDir().ChildDir("latest", 0755)) + if err != nil { + return err + } + fmt.Println("\nstable symlinks created") + + return nil +}