Merge pull request #20 from safing/develop

Release to master
Daniel 2019-10-30 14:05:07 +01:00 committed by GitHub
commit 9eb80646c7
62 changed files with 2683 additions and 239 deletions

@ -4,3 +4,8 @@ linters:
 - lll
 - gochecknoinits
 - gochecknoglobals
+- funlen
+- whitespace
+- wsl
+- godox


@ -1,5 +1,8 @@
 language: go
+go:
+- 1.x
 os:
 - linux
 - windows

Gopkg.lock generated

@ -123,6 +123,14 @@
   revision = "ac23dc3fea5d1a983c43f6a0f6e2c13f0195d8bd"
   version = "v1.2.0"

+[[projects]]
+  digest = "1:870d441fe217b8e689d7949fef6e43efbc787e50f200cb1e70dbca9204a1d6be"
+  name = "github.com/inconshreveable/mousetrap"
+  packages = ["."]
+  pruneopts = "UT"
+  revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75"
+  version = "v1.0"
+
 [[projects]]
   branch = "master"
   digest = "1:7e8b852581596acce37bcb939a05d7d5ff27156045b50057e659e299c16fc1ca"
@ -208,6 +216,22 @@
   pruneopts = "UT"
   revision = "bb4de0191aa41b5507caa14b0650cdbddcd9280b"

+[[projects]]
+  digest = "1:e096613fb7cf34743d49af87d197663cfccd61876e2219853005a57baedfa562"
+  name = "github.com/spf13/cobra"
+  packages = ["."]
+  pruneopts = "UT"
+  revision = "f2b07da1e2c38d5f12845a4f607e2e1018cbb1f5"
+  version = "v0.0.5"
+
+[[projects]]
+  digest = "1:524b71991fc7d9246cc7dc2d9e0886ccb97648091c63e30eef619e6862c955dd"
+  name = "github.com/spf13/pflag"
+  packages = ["."]
+  pruneopts = "UT"
+  revision = "2e9d26c8c37aae03e3f9d4e90b7116f5accb7cab"
+  version = "v1.0.5"
+
 [[projects]]
   branch = "master"
   digest = "1:93d6687fc19da8a35c7352d72117a6acd2072dfb7e9bfd65646227bf2a913b2a"
@ -312,6 +336,7 @@
     "github.com/satori/go.uuid",
     "github.com/seehuhn/fortuna",
     "github.com/shirou/gopsutil/host",
+    "github.com/spf13/cobra",
     "github.com/tevino/abool",
     "github.com/tidwall/gjson",
     "github.com/tidwall/sjson",


@ -7,13 +7,17 @@ import (
 	"github.com/safing/portbase/modules"
 )

+var (
+	module *modules.Module
+)
+
 // API Errors
 var (
 	ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set")
 )

 func init() {
-	modules.Register("api", prep, start, stop, "base", "database", "config")
+	module = modules.Register("api", prep, start, stop, "base", "database", "config")
 }

 func prep() error {


@ -1,8 +1,10 @@
 package api

 import (
+	"context"
 	"net/http"
 	"sync"
+	"time"

 	"github.com/gorilla/mux"
@ -56,8 +58,20 @@ func Serve() {
 	// start serving
 	log.Infof("api: starting to listen on %s", server.Addr)
-	// TODO: retry if failed
-	log.Errorf("api: failed to listen on %s: %s", server.Addr, server.ListenAndServe())
+	backoffDuration := 10 * time.Second
+	for {
+		// always returns an error
+		err := module.RunWorker("http endpoint", func(ctx context.Context) error {
+			return server.ListenAndServe()
+		})
+		// return on shutdown error
+		if err == http.ErrServerClosed {
+			return
+		}
+		// log error and restart
+		log.Errorf("api: http endpoint failed: %s - restarting in %s", err, backoffDuration)
+		time.Sleep(backoffDuration)
+	}
 }

 // GetMuxVars wraps github.com/gorilla/mux.Vars in order to mitigate context key issues in multi-repo projects.
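The new Serve loop wraps ListenAndServe in a module worker and restarts it after a fixed backoff unless the server shut down cleanly. A minimal standalone sketch of the same restart-on-failure pattern, assuming a module registered via modules.Register and a hypothetical listen address:

package api

import (
	"context"
	"net/http"
	"time"

	"github.com/safing/portbase/log"
	"github.com/safing/portbase/modules"
)

var (
	module = modules.Register("api-sketch", nil, nil, nil, "base")
	server = &http.Server{Addr: "127.0.0.1:8080"} // hypothetical address
)

// serveForever keeps the HTTP endpoint alive: a clean shutdown ends the loop,
// any other error is logged and the listener is restarted after a backoff.
func serveForever() {
	backoff := 10 * time.Second
	for {
		err := module.RunWorker("http endpoint", func(ctx context.Context) error {
			return server.ListenAndServe() // always returns a non-nil error
		})
		if err == http.ErrServerClosed {
			return // regular shutdown
		}
		log.Errorf("api: http endpoint failed: %s - restarting in %s", err, backoff)
		time.Sleep(backoff)
	}
}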

config/doc.go Normal file

@ -0,0 +1,2 @@
// Package config provides a versatile configuration management system.
package config

config/expertise.go Normal file

@ -0,0 +1,72 @@
// Package config ... (linter fix)
//nolint:dupl
package config
import (
"fmt"
"sync/atomic"
)
// Expertise Level constants
const (
ExpertiseLevelUser uint8 = 0
ExpertiseLevelExpert uint8 = 1
ExpertiseLevelDeveloper uint8 = 2
ExpertiseLevelNameUser = "user"
ExpertiseLevelNameExpert = "expert"
ExpertiseLevelNameDeveloper = "developer"
expertiseLevelKey = "core/expertiseLevel"
)
var (
expertiseLevel *int32
)
func init() {
var expertiseLevelVal int32
expertiseLevel = &expertiseLevelVal
registerExpertiseLevelOption()
}
func registerExpertiseLevelOption() {
err := Register(&Option{
Name: "Expertise Level",
Key: expertiseLevelKey,
Description: "The Expertise Level controls the perceived complexity. Higher settings will show you more complex settings and information. This might also affect various other things relying on this setting. Modified settings in higher expertise levels stay in effect when switching back. (Unlike the Release Level)",
OptType: OptTypeString,
ExpertiseLevel: ExpertiseLevelUser,
ReleaseLevel: ExpertiseLevelUser,
RequiresRestart: false,
DefaultValue: ExpertiseLevelNameUser,
ExternalOptType: "string list",
ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ExpertiseLevelNameUser, ExpertiseLevelNameExpert, ExpertiseLevelNameDeveloper),
})
if err != nil {
panic(err)
}
}
func updateExpertiseLevel() {
new := findStringValue(expertiseLevelKey, "")
switch new {
case ExpertiseLevelNameUser:
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser))
case ExpertiseLevelNameExpert:
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelExpert))
case ExpertiseLevelNameDeveloper:
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelDeveloper))
default:
atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser))
}
}
// GetExpertiseLevel returns the current active expertise level.
func GetExpertiseLevel() uint8 {
return uint8(atomic.LoadInt32(expertiseLevel))
}
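GetExpertiseLevel returns the currently selected level as a uint8, so callers can gate functionality with a plain comparison. A small illustrative sketch; the helper below is not part of this commit:

package config

// showAdvancedSettings is a hypothetical helper: because the levels are
// ordered numerically (user=0, expert=1, developer=2), comparing against
// GetExpertiseLevel() is all that is needed for gating.
func showAdvancedSettings() bool {
	return GetExpertiseLevel() >= ExpertiseLevelExpert
}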


@ -81,21 +81,7 @@ func findValue(key string) interface{} {
 	option.Lock()
 	defer option.Unlock()

-	// check if option is active
-	optionActive := true
-	switch getReleaseLevel() {
-	case ReleaseLevelStable:
-		// In stable, only stable is active
-		optionActive = option.ReleaseLevel == ReleaseLevelStable
-	case ReleaseLevelBeta:
-		// In beta, only stable and beta are active
-		optionActive = option.ReleaseLevel == ReleaseLevelStable || option.ReleaseLevel == ReleaseLevelBeta
-	case ReleaseLevelExperimental:
-		// In experimental, everything is active
-		optionActive = true
-	}
-	if optionActive && option.activeValue != nil {
+	if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil {
 		return option.activeValue
 	}
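With release levels now being ordered integers (stable=0, beta=1, experimental=2), the whole switch collapses into one comparison: an option is active whenever its own ReleaseLevel does not exceed the globally selected level. A tiny sketch of that property, using a hypothetical standalone helper:

package config

// releaseLevelAllows mirrors the check in findValue: stable options (0) are
// visible everywhere, experimental options (2) only when the global level is
// experimental as well.
func releaseLevelAllows(optionLevel, globalLevel uint8) bool {
	return optionLevel <= globalLevel
}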


@ -160,21 +160,21 @@ func TestReleaseLevel(t *testing.T) {
 	// test option level stable
 	subsystemOption.ReleaseLevel = ReleaseLevelStable
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelStable)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable)
 	if err != nil {
 		t.Fatal(err)
 	}
 	if !testSubsystem() {
 		t.Error("should be active")
 	}
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelBeta)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta)
 	if err != nil {
 		t.Fatal(err)
 	}
 	if !testSubsystem() {
 		t.Error("should be active")
 	}
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelExperimental)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental)
 	if err != nil {
 		t.Fatal(err)
 	}
@ -184,21 +184,21 @@ func TestReleaseLevel(t *testing.T) {
 	// test option level beta
 	subsystemOption.ReleaseLevel = ReleaseLevelBeta
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelStable)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable)
 	if err != nil {
 		t.Fatal(err)
 	}
 	if testSubsystem() {
-		t.Errorf("should be inactive: opt=%s system=%s", subsystemOption.ReleaseLevel, releaseLevel)
+		t.Errorf("should be inactive: opt=%d system=%d", subsystemOption.ReleaseLevel, getReleaseLevel())
 	}
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelBeta)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta)
 	if err != nil {
 		t.Fatal(err)
 	}
 	if !testSubsystem() {
 		t.Error("should be active")
 	}
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelExperimental)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental)
 	if err != nil {
 		t.Fatal(err)
 	}
@ -208,21 +208,21 @@ func TestReleaseLevel(t *testing.T) {
 	// test option level experimental
 	subsystemOption.ReleaseLevel = ReleaseLevelExperimental
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelStable)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable)
 	if err != nil {
 		t.Fatal(err)
 	}
 	if testSubsystem() {
 		t.Error("should be inactive")
 	}
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelBeta)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta)
 	if err != nil {
 		t.Fatal(err)
 	}
 	if testSubsystem() {
 		t.Error("should be inactive")
 	}
-	err = SetConfigOption(releaseLevelKey, ReleaseLevelExperimental)
+	err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental)
 	if err != nil {
 		t.Fatal(err)
 	}


@ -10,7 +10,12 @@ import (
 	"github.com/safing/portmaster/core/structure"
 )

+const (
+	configChangeEvent = "config change"
+)
+
 var (
+	module   *modules.Module
 	dataRoot *utils.DirStructure
 )
@ -22,7 +27,8 @@ func SetDataRoot(root *utils.DirStructure) {
 }

 func init() {
-	modules.Register("config", prep, start, nil, "base", "database")
+	module = modules.Register("config", prep, start, nil, "base", "database")
+	module.RegisterEvent(configChangeEvent)
 }

 func prep() error {


@ -17,14 +17,6 @@ const (
 	OptTypeStringArray uint8 = 2
 	OptTypeInt         uint8 = 3
 	OptTypeBool        uint8 = 4
-
-	ExpertiseLevelUser      uint8 = 1
-	ExpertiseLevelExpert    uint8 = 2
-	ExpertiseLevelDeveloper uint8 = 3
-
-	ReleaseLevelStable       = "stable"
-	ReleaseLevelBeta         = "beta"
-	ReleaseLevelExperimental = "experimental"
 )

 func getTypeName(t uint8) string {
@ -50,9 +42,9 @@ type Option struct {
 	Key         string // in path format: category/sub/key
 	Description string

-	ReleaseLevel   string
-	ExpertiseLevel uint8
 	OptType        uint8
+	ExpertiseLevel uint8
+	ReleaseLevel   uint8

 	RequiresRestart bool
 	DefaultValue    interface{}


@ -1,7 +1,6 @@
 package config

 import (
-	"errors"
 	"fmt"
 	"regexp"
 	"sync"
@ -10,21 +9,21 @@ import (
 var (
 	optionsLock sync.RWMutex
 	options     = make(map[string]*Option)
-
-	// ErrIncompleteCall is return when RegisterOption is called with empty mandatory values.
-	ErrIncompleteCall = errors.New("could not register config option: all fields, except for the validationRegex are mandatory")
 )

 // Register registers a new configuration option.
 func Register(option *Option) error {
-	if option.Name == "" ||
-		option.Key == "" ||
-		option.Description == "" ||
-		option.OptType == 0 ||
-		option.ExpertiseLevel == 0 ||
-		option.ReleaseLevel == "" {
-		return ErrIncompleteCall
+	if option.Name == "" {
+		return fmt.Errorf("failed to register option: please set option.Name")
+	}
+	if option.Key == "" {
+		return fmt.Errorf("failed to register option: please set option.Key")
+	}
+	if option.Description == "" {
+		return fmt.Errorf("failed to register option: please set option.Description")
+	}
+	if option.OptType == 0 {
+		return fmt.Errorf("failed to register option: please set option.OptType")
 	}

 	if option.ValidationRegex != "" {
@ -37,7 +36,6 @@ func Register(option *Option) error {
 	optionsLock.Lock()
 	defer optionsLock.Unlock()
 	options[option.Key] = option
-
 	return nil
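Register now reports exactly which mandatory field is missing instead of returning a single ErrIncompleteCall. A hedged registration sketch using the Option fields as they look after this PR; the key name and texts are made up:

package example

import (
	"github.com/safing/portbase/config"
)

// registerExampleOption registers a hypothetical boolean option.
func registerExampleOption() error {
	return config.Register(&config.Option{
		Name:            "Example Toggle",        // hypothetical option
		Key:             "example/enableFeature", // hypothetical key
		Description:     "Enables the example feature.",
		OptType:         config.OptTypeBool,
		ExpertiseLevel:  config.ExpertiseLevelExpert,
		ReleaseLevel:    config.ReleaseLevelBeta,
		RequiresRestart: false,
		DefaultValue:    false,
	})
}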


@ -1,38 +1,51 @@
+// Package config ... (linter fix)
+//nolint:dupl
 package config

 import (
 	"fmt"
-	"sync"
+	"sync/atomic"
 )

+// Release Level constants
 const (
-	releaseLevelKey = "core/release_level"
+	ReleaseLevelStable       uint8 = 0
+	ReleaseLevelBeta         uint8 = 1
+	ReleaseLevelExperimental uint8 = 2
+
+	ReleaseLevelNameStable       = "stable"
+	ReleaseLevelNameBeta         = "beta"
+	ReleaseLevelNameExperimental = "experimental"
+
+	releaseLevelKey = "core/releaseLevel"
 )

 var (
-	releaseLevel     = ReleaseLevelStable
-	releaseLevelLock sync.Mutex
+	releaseLevel *int32
 )

 func init() {
+	var releaseLevelVal int32
+	releaseLevel = &releaseLevelVal
+
 	registerReleaseLevelOption()
 }

 func registerReleaseLevelOption() {
 	err := Register(&Option{
-		Name:            "Release Selection",
+		Name:            "Release Level",
 		Key:             releaseLevelKey,
-		Description:     "Select maturity level of features that should be available",
+		Description:     "The Release Level changes which features are available to you. Some beta or experimental features are also available in the stable release channel. Unavailable settings are set to the default value.",
 		OptType:         OptTypeString,
 		ExpertiseLevel:  ExpertiseLevelExpert,
 		ReleaseLevel:    ReleaseLevelStable,
 		RequiresRestart: false,
-		DefaultValue:    ReleaseLevelStable,
+		DefaultValue:    ReleaseLevelNameStable,
 		ExternalOptType: "string list",
-		ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelStable, ReleaseLevelBeta, ReleaseLevelExperimental),
+		ValidationRegex: fmt.Sprintf("^(%s|%s|%s)$", ReleaseLevelNameStable, ReleaseLevelNameBeta, ReleaseLevelNameExperimental),
 	})
 	if err != nil {
 		panic(err)
@ -41,17 +54,18 @@ func registerReleaseLevelOption() {
 func updateReleaseLevel() {
 	new := findStringValue(releaseLevelKey, "")

-	releaseLevelLock.Lock()
-	if new == "" {
-		releaseLevel = ReleaseLevelStable
-	} else {
-		releaseLevel = new
+	switch new {
+	case ReleaseLevelNameStable:
+		atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable))
+	case ReleaseLevelNameBeta:
+		atomic.StoreInt32(releaseLevel, int32(ReleaseLevelBeta))
+	case ReleaseLevelNameExperimental:
+		atomic.StoreInt32(releaseLevel, int32(ReleaseLevelExperimental))
+	default:
+		atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable))
 	}
-	releaseLevelLock.Unlock()
 }

-func getReleaseLevel() string {
-	releaseLevelLock.Lock()
-	defer releaseLevelLock.Unlock()
-	return releaseLevel
+func getReleaseLevel() uint8 {
+	return uint8(atomic.LoadInt32(releaseLevel))
 }


@ -17,9 +17,6 @@ var (
 	validityFlag     = abool.NewBool(true)
 	validityFlagLock sync.RWMutex
-
-	changedSignal     = make(chan struct{})
-	changedSignalLock sync.Mutex
 )

 func getValidityFlag() *abool.AtomicBool {
@ -28,16 +25,10 @@ func getValidityFlag() *abool.AtomicBool {
 	return validityFlag
 }

-// Changed signals if any config option was changed.
-func Changed() <-chan struct{} {
-	changedSignalLock.Lock()
-	defer changedSignalLock.Unlock()
-	return changedSignal
-}
-
 func signalChanges() {
-	// refetch and save release level
+	// refetch and save release level and expertise level
 	updateReleaseLevel()
+	updateExpertiseLevel()

 	// reset validity flag
 	validityFlagLock.Lock()
@ -45,11 +36,7 @@ func signalChanges() {
 	validityFlag = abool.NewBool(true)
 	validityFlagLock.Unlock()

-	// trigger change signal: signal listeners that a config option was changed.
-	changedSignalLock.Lock()
-	close(changedSignal)
-	changedSignal = make(chan struct{})
-	changedSignalLock.Unlock()
+	module.TriggerEvent(configChangeEvent, nil)
 }

 // setConfig sets the (prioritized) user defined config.
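The dedicated Changed() channel is gone; consumers now hook into the config module's "config change" event via the modules event system added in this PR. A hedged sketch of how another module could subscribe; module and hook names are invented:

package example

import (
	"context"

	"github.com/safing/portbase/log"
	"github.com/safing/portbase/modules"
)

var myModule = modules.Register("example", nil, startExample, nil, "config")

func startExample() error {
	// Re-apply settings whenever any config option changes.
	return myModule.RegisterEventHook(
		"config",        // module that owns the event
		"config change", // event name registered by the config module
		"reload example settings",
		func(ctx context.Context, _ interface{}) error {
			log.Infof("example: config changed, reloading")
			return nil
		},
	)
}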


@ -14,6 +14,7 @@ type snippet struct {
 }

 // ParseQuery parses a plaintext query. Special characters (that must be escaped with a '\') are: `\()` and any whitespaces.
+//nolint:gocognit
 func ParseQuery(query string) (*Query, error) {
 	snippets, err := extractSnippets(query)
 	if err != nil {
@ -195,6 +196,7 @@ func extractSnippets(text string) (snippets []*snippet, err error) {
 }

+//nolint:gocognit
 func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) {
 	var isOr = false
 	var typeSet = false


@ -14,14 +14,14 @@ var (
 )

 // GenCodeSize returns the size of the gencode marshalled byte slice
-func (d *Meta) GenCodeSize() (s int) {
+func (m *Meta) GenCodeSize() (s int) {
 	s += 34
 	return
 }

 // GenCodeMarshal gencode marshalls Meta into the given byte array, or a new one if its too small.
-func (d *Meta) GenCodeMarshal(buf []byte) ([]byte, error) {
-	size := d.GenCodeSize()
+func (m *Meta) GenCodeMarshal(buf []byte) ([]byte, error) {
+	size := m.GenCodeSize()
 	{
 		if cap(buf) >= size {
 			buf = buf[:size]
@ -33,89 +33,89 @@ func (d *Meta) GenCodeMarshal(buf []byte) ([]byte, error) {
 	{
-		buf[0+0] = byte(d.Created >> 0)
-		buf[1+0] = byte(d.Created >> 8)
-		buf[2+0] = byte(d.Created >> 16)
-		buf[3+0] = byte(d.Created >> 24)
-		buf[4+0] = byte(d.Created >> 32)
-		buf[5+0] = byte(d.Created >> 40)
-		buf[6+0] = byte(d.Created >> 48)
-		buf[7+0] = byte(d.Created >> 56)
+		buf[0+0] = byte(m.Created >> 0)
+		buf[1+0] = byte(m.Created >> 8)
+		buf[2+0] = byte(m.Created >> 16)
+		buf[3+0] = byte(m.Created >> 24)
+		buf[4+0] = byte(m.Created >> 32)
+		buf[5+0] = byte(m.Created >> 40)
+		buf[6+0] = byte(m.Created >> 48)
+		buf[7+0] = byte(m.Created >> 56)
 	}
 	{
-		buf[0+8] = byte(d.Modified >> 0)
-		buf[1+8] = byte(d.Modified >> 8)
-		buf[2+8] = byte(d.Modified >> 16)
-		buf[3+8] = byte(d.Modified >> 24)
-		buf[4+8] = byte(d.Modified >> 32)
-		buf[5+8] = byte(d.Modified >> 40)
-		buf[6+8] = byte(d.Modified >> 48)
-		buf[7+8] = byte(d.Modified >> 56)
+		buf[0+8] = byte(m.Modified >> 0)
+		buf[1+8] = byte(m.Modified >> 8)
+		buf[2+8] = byte(m.Modified >> 16)
+		buf[3+8] = byte(m.Modified >> 24)
+		buf[4+8] = byte(m.Modified >> 32)
+		buf[5+8] = byte(m.Modified >> 40)
+		buf[6+8] = byte(m.Modified >> 48)
+		buf[7+8] = byte(m.Modified >> 56)
 	}
 	{
-		buf[0+16] = byte(d.Expires >> 0)
-		buf[1+16] = byte(d.Expires >> 8)
-		buf[2+16] = byte(d.Expires >> 16)
-		buf[3+16] = byte(d.Expires >> 24)
-		buf[4+16] = byte(d.Expires >> 32)
-		buf[5+16] = byte(d.Expires >> 40)
-		buf[6+16] = byte(d.Expires >> 48)
-		buf[7+16] = byte(d.Expires >> 56)
+		buf[0+16] = byte(m.Expires >> 0)
+		buf[1+16] = byte(m.Expires >> 8)
+		buf[2+16] = byte(m.Expires >> 16)
+		buf[3+16] = byte(m.Expires >> 24)
+		buf[4+16] = byte(m.Expires >> 32)
+		buf[5+16] = byte(m.Expires >> 40)
+		buf[6+16] = byte(m.Expires >> 48)
+		buf[7+16] = byte(m.Expires >> 56)
 	}
 	{
-		buf[0+24] = byte(d.Deleted >> 0)
-		buf[1+24] = byte(d.Deleted >> 8)
-		buf[2+24] = byte(d.Deleted >> 16)
-		buf[3+24] = byte(d.Deleted >> 24)
-		buf[4+24] = byte(d.Deleted >> 32)
-		buf[5+24] = byte(d.Deleted >> 40)
-		buf[6+24] = byte(d.Deleted >> 48)
-		buf[7+24] = byte(d.Deleted >> 56)
+		buf[0+24] = byte(m.Deleted >> 0)
+		buf[1+24] = byte(m.Deleted >> 8)
+		buf[2+24] = byte(m.Deleted >> 16)
+		buf[3+24] = byte(m.Deleted >> 24)
+		buf[4+24] = byte(m.Deleted >> 32)
+		buf[5+24] = byte(m.Deleted >> 40)
+		buf[6+24] = byte(m.Deleted >> 48)
+		buf[7+24] = byte(m.Deleted >> 56)
 	}
 	{
-		if d.secret {
+		if m.secret {
 			buf[32] = 1
 		} else {
 			buf[32] = 0
 		}
 	}
 	{
-		if d.cronjewel {
+		if m.cronjewel {
 			buf[33] = 1
 		} else {
 			buf[33] = 0
@ -125,38 +125,38 @@ func (d *Meta) GenCodeMarshal(buf []byte) ([]byte, error) {
 }

 // GenCodeUnmarshal gencode unmarshalls Meta and returns the bytes read.
-func (d *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) {
-	if len(buf) < d.GenCodeSize() {
-		return 0, fmt.Errorf("insufficient data: got %d out of %d bytes", len(buf), d.GenCodeSize())
+func (m *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) {
+	if len(buf) < m.GenCodeSize() {
+		return 0, fmt.Errorf("insufficient data: got %d out of %d bytes", len(buf), m.GenCodeSize())
 	}
 	i := uint64(0)
 	{
-		d.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56)
+		m.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56)
 	}
 	{
-		d.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56)
+		m.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56)
 	}
 	{
-		d.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56)
+		m.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56)
 	}
 	{
-		d.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56)
+		m.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56)
 	}
 	{
-		d.secret = buf[32] == 1
+		m.secret = buf[32] == 1
 	}
 	{
-		d.cronjewel = buf[33] == 1
+		m.cronjewel = buf[33] == 1
 	}
 	return i + 34, nil
 }


@ -139,6 +139,7 @@ func saveRegistry(lock bool) error {
 	}

 	// write file
+	// FIXME: write atomically (best effort)
 	filePath := path.Join(rootStructure.Path, registryFileName)
 	return ioutil.WriteFile(filePath, data, 0600)
 }


@ -118,6 +118,7 @@ func (b *Badger) Query(q *query.Query, local, internal bool) (*iterator.Iterator
 	return queryIter, nil
 }

+//nolint:gocognit
 func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
 	err := b.db.View(func(txn *badger.Txn) error {
 		it := txn.NewIterator(badger.DefaultIteratorOptions)


@ -31,7 +31,7 @@ type TestRecord struct {
 	B bool
 }

-func TestBadger(t *testing.T) {
+func TestBBolt(t *testing.T) {
 	testDir, err := ioutil.TempDir("", "testing-")
 	if err != nil {
 		t.Fatal(err)


@ -0,0 +1,140 @@
package hashmap
import (
"errors"
"fmt"
"sync"
"time"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
)
// HashMap storage.
type HashMap struct {
name string
db map[string]record.Record
dbLock sync.RWMutex
}
func init() {
_ = storage.Register("hashmap", NewHashMap)
}
// NewHashMap creates a hashmap database.
func NewHashMap(name, location string) (storage.Interface, error) {
return &HashMap{
name: name,
db: make(map[string]record.Record),
}, nil
}
// Get returns a database record.
func (hm *HashMap) Get(key string) (record.Record, error) {
hm.dbLock.RLock()
defer hm.dbLock.RUnlock()
r, ok := hm.db[key]
if !ok {
return nil, storage.ErrNotFound
}
return r, nil
}
// Put stores a record in the database.
func (hm *HashMap) Put(r record.Record) error {
hm.dbLock.Lock()
defer hm.dbLock.Unlock()
hm.db[r.DatabaseKey()] = r
return nil
}
// Delete deletes a record from the database.
func (hm *HashMap) Delete(key string) error {
hm.dbLock.Lock()
defer hm.dbLock.Unlock()
delete(hm.db, key)
return nil
}
// Query returns an iterator for the supplied query.
func (hm *HashMap) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
_, err := q.Check()
if err != nil {
return nil, fmt.Errorf("invalid query: %s", err)
}
queryIter := iterator.New()
go hm.queryExecutor(queryIter, q, local, internal)
return queryIter, nil
}
func (hm *HashMap) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
hm.dbLock.RLock()
defer hm.dbLock.RUnlock()
var err error
mapLoop:
for key, record := range hm.db {
switch {
case !q.MatchesKey(key):
continue
case !q.MatchesRecord(record):
continue
case !record.Meta().CheckValidity():
continue
case !record.Meta().CheckPermission(local, internal):
continue
}
select {
case <-queryIter.Done:
break mapLoop
case queryIter.Next <- record:
default:
select {
case <-queryIter.Done:
break mapLoop
case queryIter.Next <- record:
case <-time.After(1 * time.Second):
err = errors.New("query timeout")
break mapLoop
}
}
}
queryIter.Finish(err)
}
// ReadOnly returns whether the database is read only.
func (hm *HashMap) ReadOnly() bool {
return false
}
// Injected returns whether the database is injected.
func (hm *HashMap) Injected() bool {
return false
}
// Maintain runs a light maintenance operation on the database.
func (hm *HashMap) Maintain() error {
return nil
}
// MaintainThorough runs a thorough maintenance operation on the database.
func (hm *HashMap) MaintainThorough() (err error) {
return nil
}
// Shutdown shuts down the database.
func (hm *HashMap) Shutdown() error {
return nil
}


@ -0,0 +1,147 @@
//nolint:unparam,maligned
package hashmap
import (
"reflect"
"sync"
"testing"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
)
type TestRecord struct {
record.Base
sync.Mutex
S string
I int
I8 int8
I16 int16
I32 int32
I64 int64
UI uint
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
F32 float32
F64 float64
B bool
}
func TestHashMap(t *testing.T) {
// start
db, err := NewHashMap("test", "")
if err != nil {
t.Fatal(err)
}
a := &TestRecord{
S: "banana",
I: 42,
I8: 42,
I16: 42,
I32: 42,
I64: 42,
UI: 42,
UI8: 42,
UI16: 42,
UI32: 42,
UI64: 42,
F32: 42.42,
F64: 42.42,
B: true,
}
a.SetMeta(&record.Meta{})
a.Meta().Update()
a.SetKey("test:A")
// put record
err = db.Put(a)
if err != nil {
t.Fatal(err)
}
// get and compare
a1, err := db.Get("A")
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(a, a1) {
t.Fatalf("mismatch, got %v", a1)
}
// setup query test records
qA := &TestRecord{}
qA.SetKey("test:path/to/A")
qA.CreateMeta()
qB := &TestRecord{}
qB.SetKey("test:path/to/B")
qB.CreateMeta()
qC := &TestRecord{}
qC.SetKey("test:path/to/C")
qC.CreateMeta()
qZ := &TestRecord{}
qZ.SetKey("test:z")
qZ.CreateMeta()
// put
err = db.Put(qA)
if err == nil {
err = db.Put(qB)
}
if err == nil {
err = db.Put(qC)
}
if err == nil {
err = db.Put(qZ)
}
if err != nil {
t.Fatal(err)
}
// test query
q := query.New("test:path/to/").MustBeValid()
it, err := db.Query(q, true, true)
if err != nil {
t.Fatal(err)
}
cnt := 0
for range it.Next {
cnt++
}
if it.Err() != nil {
t.Fatal(it.Err())
}
if cnt != 3 {
t.Fatalf("unexpected query result count: %d", cnt)
}
// delete
err = db.Delete("A")
if err != nil {
t.Fatal(err)
}
// check if its gone
_, err = db.Get("A")
if err == nil {
t.Fatal("should fail")
}
// maintenance
err = db.Maintain()
if err != nil {
t.Fatal(err)
}
err = db.MaintainThorough()
if err != nil {
t.Fatal(err)
}
// shutdown
err = db.Shutdown()
if err != nil {
t.Fatal(err)
}
}


@ -1,4 +1,4 @@
-//nolint:maligned,unparam,gocyclo
+//nolint:maligned,unparam,gocyclo,gocognit
 package dsd

 import (


@ -1,4 +1,4 @@
-//nolint:nakedret,unconvert
+//nolint:nakedret,unconvert,gocognit
 package dsd

 import (


@ -1,7 +1,9 @@
 package varint

-import "errors"
-import "encoding/binary"
+import (
+	"encoding/binary"
+	"errors"
+)

 // Pack8 packs a uint8 into a VarInt.
 func Pack8(n uint8) []byte {
View file

@ -1,3 +1,4 @@
+//nolint:gocognit
 package varint

 import (


@ -15,7 +15,7 @@ var (
 )

 func init() {
-	modules.Register("info", prep, nil, nil, "base")
+	modules.Register("info", prep, nil, nil)
 	flag.BoolVar(&showVersion, "version", false, "show version and exit")
 }
@ -35,8 +35,10 @@ func prep() error {
 // CheckVersion checks if the metadata is ok.
 func CheckVersion() error {
 	if !strings.HasSuffix(os.Args[0], ".test") {
-		if name == "[NAME]" ||
-			version == "[version unknown]" ||
+		if name == "[NAME]" {
+			return errors.New("must call SetInfo() before calling CheckVersion()")
+		}
+		if version == "[version unknown]" ||
 			commit == "[commit unknown]" ||
 			license == "[license unknown]" ||
 			buildOptions == "[options unknown]" ||


@ -75,15 +75,21 @@ func log(level Severity, msg string, tracer *ContextTracer) {
 	select {
 	case logBuffer <- log:
 	default:
-		forceEmptyingOfBuffer <- struct{}{}
-		logBuffer <- log
+	forceEmptyingLoop:
+		// force empty buffer until we can send to it
+		for {
+			select {
+			case forceEmptyingOfBuffer <- struct{}{}:
+			case logBuffer <- log:
+				break forceEmptyingLoop
+			}
+		}
 	}

 	// wake up writer if necessary
 	if logsWaitingFlag.SetToIf(false, true) {
 		logsWaiting <- struct{}{}
 	}
 }

 func fastcheck(level Severity) bool {


@ -70,7 +70,7 @@ const (
 var (
 	logBuffer             chan *logLine
-	forceEmptyingOfBuffer chan struct{}
+	forceEmptyingOfBuffer = make(chan struct{})

 	logLevelInt = uint32(3)
 	logLevel    = &logLevelInt
@ -79,7 +79,7 @@ var (
 	pkgLevels     = make(map[string]Severity)
 	pkgLevelsLock sync.Mutex

-	logsWaiting     = make(chan struct{}, 1)
+	logsWaiting     = make(chan struct{}, 4)
 	logsWaitingFlag = abool.NewBool(false)

 	shutdownSignal = make(chan struct{})
@ -90,7 +90,7 @@ var (
 	startedSignal = make(chan struct{})
 )

-// SetPkgLevels sets individual log levels for packages.
+// SetPkgLevels sets individual log levels for packages. Only effective after Start().
 func SetPkgLevels(levels map[string]Severity) {
 	pkgLevelsLock.Lock()
 	pkgLevels = levels
@ -103,7 +103,7 @@ func UnSetPkgLevels() {
 	pkgLevelsActive.UnSet()
 }

-// SetLogLevel sets a new log level.
+// SetLogLevel sets a new log level. Only effective after Start().
 func SetLogLevel(level Severity) {
 	atomic.StoreUint32(logLevel, uint32(level))
 }
@ -135,11 +135,10 @@ func Start() (err error) {
 	}

 	logBuffer = make(chan *logLine, 1024)
-	forceEmptyingOfBuffer = make(chan struct{}, 16)

 	initialLogLevel := ParseLevel(logLevelFlag)
 	if initialLogLevel > 0 {
-		atomic.StoreUint32(logLevel, uint32(initialLogLevel))
+		SetLogLevel(initialLogLevel)
 	} else {
 		err = fmt.Errorf("log warning: invalid log level \"%s\", falling back to level info", logLevelFlag)
 		fmt.Fprintf(os.Stderr, "%s\n", err.Error())
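Both SetLogLevel and SetPkgLevels are now documented as only taking effect after Start(), and Start itself routes the initial level through SetLogLevel. A short usage sketch; the package name in the map is an example:

package main

import "github.com/safing/portbase/log"

func main() {
	if err := log.Start(); err != nil {
		panic(err)
	}

	// Only effective after Start():
	log.SetLogLevel(log.InfoLevel)
	log.SetPkgLevels(map[string]log.Severity{
		"api": log.InfoLevel, // example package name
	})

	log.Infof("example: logging is up")
	log.Shutdown()
}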


@ -2,6 +2,8 @@ package log
 import (
 	"fmt"
+	"os"
+	"runtime/debug"
 	"time"
 )
@ -39,37 +41,73 @@ func writeLine(line *logLine, duplicates uint64) {
 }

 func startWriter() {
-	shutdownWaitGroup.Add(1)
 	fmt.Println(fmt.Sprintf("%s%s %s BOF%s", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()))
-	go writer()
+
+	shutdownWaitGroup.Add(1)
+	go writerManager()
 }

-func writer() {
-	var line *logLine
-	var lastLine *logLine
-	var duplicates uint64
+func writerManager() {
 	defer shutdownWaitGroup.Done()

+	for {
+		err := writer()
+		if err != nil {
+			Errorf("log: writer failed: %s", err)
+		} else {
+			return
+		}
+	}
+}
+
+func writer() (err error) {
+	defer func() {
+		// recover from panic
+		panicVal := recover()
+		if panicVal != nil {
+			err = fmt.Errorf("%s", panicVal)
+			// write stack to stderr
+			fmt.Fprintf(
+				os.Stderr,
+				`===== Error Report =====
+Message: %s
+StackTrace:
+%s
+===== End of Report =====
+`,
+				err,
+				string(debug.Stack()),
+			)
+		}
+	}()
+
+	var currentLine *logLine
+	var nextLine *logLine
+	var duplicates uint64
+
 	for {
 		// reset
-		line = nil
-		lastLine = nil //nolint:ineffassign // only ineffectual in first loop
+		currentLine = nil
+		nextLine = nil
 		duplicates = 0

 		// wait until logs need to be processed
 		select {
-		case <-logsWaiting:
+		case <-logsWaiting: // normal process
 			logsWaitingFlag.UnSet()
-		case <-shutdownSignal:
+		case <-forceEmptyingOfBuffer: // log buffer is full!
+		case <-shutdownSignal: // shutting down
 			finalizeWriting()
 			return
 		}

-		// wait for timeslot to log, or when buffer is full
+		// wait for timeslot to log
 		select {
-		case <-writeTrigger:
-		case <-forceEmptyingOfBuffer:
-		case <-shutdownSignal:
+		case <-writeTrigger: // normal process
+		case <-forceEmptyingOfBuffer: // log buffer is full!
+		case <-shutdownSignal: // shutting down
 			finalizeWriting()
 			return
 		}
@ -78,38 +116,41 @@
 	writeLoop:
 		for {
 			select {
-			case line = <-logBuffer:
-				// look-ahead for deduplication (best effort)
-			dedupLoop:
-				for {
-					// check if there is another line waiting
-					select {
-					case nextLine := <-logBuffer:
-						lastLine = line
-						line = nextLine
-					default:
-						break dedupLoop
-					}
-					// deduplication
-					if !line.Equal(lastLine) {
-						// no duplicate
-						break dedupLoop
-					}
-					// duplicate
-					duplicates++
+			case nextLine = <-logBuffer:
+				// first line we process, just assign to currentLine
+				if currentLine == nil {
+					currentLine = nextLine
+					continue writeLoop
 				}

-				// write actual line
-				writeLine(line, duplicates)
+				// we now have currentLine and nextLine
+
+				// if currentLine and nextLine are equal, do not print, just increase counter and continue
+				if nextLine.Equal(currentLine) {
+					duplicates++
+					continue writeLoop
+				}
+
+				// if currentLine and line are _not_ equal, output currentLine
+				writeLine(currentLine, duplicates)
+				// reset duplicate counter
 				duplicates = 0
+				// set new currentLine
+				currentLine = nextLine
 			default:
 				break writeLoop
 			}
 		}

+		// write final line
+		if currentLine != nil {
+			writeLine(currentLine, duplicates)
+		}
+
+		// reset state
+		currentLine = nil //nolint:ineffassign
+		nextLine = nil
+		duplicates = 0 //nolint:ineffassign
+
 		// back down a little
 		select {
 		case <-time.After(10 * time.Millisecond):


@ -83,7 +83,7 @@ func Tracer(ctx context.Context) *ContextTracer {
 // Submit collected logs on the context for further processing/outputting. Does nothing if called on a nil ContextTracer.
 func (tracer *ContextTracer) Submit() {
-	if tracer != nil {
+	if tracer == nil {
 		return
 	}
@ -119,15 +119,21 @@ func (tracer *ContextTracer) Submit() {
 	select {
 	case logBuffer <- log:
 	default:
-		forceEmptyingOfBuffer <- struct{}{}
-		logBuffer <- log
+	forceEmptyingLoop:
+		// force empty buffer until we can send to it
+		for {
+			select {
+			case forceEmptyingOfBuffer <- struct{}{}:
+			case logBuffer <- log:
+				break forceEmptyingLoop
+			}
+		}
 	}

 	// wake up writer if necessary
 	if logsWaitingFlag.SetToIf(false, true) {
 		logsWaiting <- struct{}{}
 	}
 }

 func (tracer *ContextTracer) log(level Severity, msg string) {


@ -2,11 +2,16 @@ package modules
 import (
 	"fmt"
+	"os"
 	"runtime/debug"
+	"sync"
+	"time"
 )

 var (
 	errorReportingChannel chan *ModuleError
+	reportToStdErr        = true
+	reportingLock         sync.RWMutex
 )

 // ModuleError wraps a panic, error or message into an error that can be reported.
@ -64,12 +69,43 @@ func (me *ModuleError) Error() string {
 // Report reports the error through the configured reporting channel.
 func (me *ModuleError) Report() {
+	reportingLock.RLock()
+	defer reportingLock.RUnlock()
+
 	if errorReportingChannel != nil {
 		select {
 		case errorReportingChannel <- me:
 		default:
 		}
 	}
+
+	if reportToStdErr {
+		// default to writing to stderr
+		fmt.Fprintf(
+			os.Stderr,
+			`===== Error Report =====
+Message: %s
+Timestamp: %s
+ModuleName: %s
+TaskName: %s
+TaskType: %s
+Severity: %s
+PanicValue: %s
+StackTrace:
+%s
+===== End of Report =====
+`,
+			me.Message,
+			time.Now(),
+			me.ModuleName,
+			me.TaskName,
+			me.TaskType,
+			me.Severity,
+			me.PanicValue,
+			me.StackTrace,
+		)
+	}
 }

 // IsPanic returns whether the given error is a wrapped panic by the modules package and additionally returns it, if true.
@ -84,7 +120,16 @@ func IsPanic(err error) (bool, *ModuleError) {
 // SetErrorReportingChannel sets the channel to report module errors through. By default only panics are reported, all other errors need to be manually wrapped into a *ModuleError and reported.
 func SetErrorReportingChannel(reportingChannel chan *ModuleError) {
-	if errorReportingChannel == nil {
-		errorReportingChannel = reportingChannel
-	}
+	reportingLock.Lock()
+	defer reportingLock.Unlock()
+
+	errorReportingChannel = reportingChannel
+}
+
+// SetStdErrReporting controls error reporting to stderr.
+func SetStdErrReporting(on bool) {
+	reportingLock.Lock()
+	defer reportingLock.Unlock()
+
+	reportToStdErr = on
 }
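Error reports now go to stderr by default, and the reporting channel can be replaced at any time under a lock. A hedged wiring sketch with an invented helper name:

package example

import (
	"github.com/safing/portbase/log"
	"github.com/safing/portbase/modules"
)

// setupErrorReporting is a hypothetical helper: it routes module errors into
// a channel and disables the default stderr report added in this PR.
func setupErrorReporting() {
	reports := make(chan *modules.ModuleError, 16)
	modules.SetErrorReportingChannel(reports)
	modules.SetStdErrReporting(false)

	go func() {
		for report := range reports {
			log.Errorf("received module error: %s", report.Error())
		}
	}()
}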

modules/events.go Normal file

@ -0,0 +1,103 @@
package modules
import (
"context"
"fmt"
"github.com/safing/portbase/log"
)
type eventHookFn func(context.Context, interface{}) error
type eventHook struct {
description string
hookingModule *Module
hookFn eventHookFn
}
// TriggerEvent executes all hook functions registered to the specified event.
func (m *Module) TriggerEvent(event string, data interface{}) {
go m.processEventTrigger(event, data)
}
func (m *Module) processEventTrigger(event string, data interface{}) {
m.eventHooksLock.RLock()
defer m.eventHooksLock.RUnlock()
hooks, ok := m.eventHooks[event]
if !ok {
log.Warningf(`%s: tried to trigger non-existent event "%s"`, m.Name, event)
return
}
for _, hook := range hooks {
if !hook.hookingModule.ShutdownInProgress() {
go m.runEventHook(hook, event, data)
}
}
}
func (m *Module) runEventHook(hook *eventHook, event string, data interface{}) {
if !hook.hookingModule.Started.IsSet() {
// target module has not yet fully started, wait until start is complete
select {
case <-startCompleteSignal:
case <-shutdownSignal:
return
}
}
err := hook.hookingModule.RunWorker(
fmt.Sprintf("event hook %s/%s -> %s/%s", m.Name, event, hook.hookingModule.Name, hook.description),
func(ctx context.Context) error {
return hook.hookFn(ctx, data)
},
)
if err != nil {
log.Warningf("%s: failed to execute event hook %s/%s -> %s/%s: %s", hook.hookingModule.Name, m.Name, event, hook.hookingModule.Name, hook.description, err)
}
}
// RegisterEvent registers a new event to allow for registering hooks.
func (m *Module) RegisterEvent(event string) {
m.eventHooksLock.Lock()
defer m.eventHooksLock.Unlock()
_, ok := m.eventHooks[event]
if !ok {
m.eventHooks[event] = make([]*eventHook, 0, 1)
}
}
// RegisterEventHook registers a hook function with (another) modules' event. Whenever a hook is triggered and the receiving module has not yet fully started, hook execution will be delayed until all modules completed starting.
func (m *Module) RegisterEventHook(module string, event string, description string, fn func(context.Context, interface{}) error) error {
// get target module
var eventModule *Module
if module == m.Name {
eventModule = m
} else {
var ok bool
modulesLock.RLock()
eventModule, ok = modules[module]
modulesLock.RUnlock()
if !ok {
return fmt.Errorf(`module "%s" does not exist`, module)
}
}
// get target event
eventModule.eventHooksLock.Lock()
defer eventModule.eventHooksLock.Unlock()
hooks, ok := eventModule.eventHooks[event]
if !ok {
return fmt.Errorf(`event "%s/%s" does not exist`, eventModule.Name, event)
}
// add hook
eventModule.eventHooks[event] = append(hooks, &eventHook{
description: description,
hookingModule: m,
hookFn: fn,
})
return nil
}
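Putting the new event API together: a module registers an event, other modules hook into it, and hook execution is deferred until all modules have started. A hedged end-to-end sketch with invented module and event names:

package example

import (
	"context"

	"github.com/safing/portbase/log"
	"github.com/safing/portbase/modules"
)

var (
	producer = modules.Register("producer", prepProducer, nil, nil)
	consumer = modules.Register("consumer", nil, startConsumer, nil, "producer")
)

func prepProducer() error {
	// Events must be registered before hooks can attach to them.
	producer.RegisterEvent("something happened")
	return nil
}

func startConsumer() error {
	return consumer.RegisterEventHook(
		"producer", "something happened", "log the event payload",
		func(ctx context.Context, data interface{}) error {
			log.Infof("consumer: got event data: %v", data)
			return nil
		},
	)
}

func notify() {
	// Hooks run asynchronously in workers of the hooking module.
	producer.TriggerEvent("something happened", "payload")
}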

modules/exit.go Normal file

@ -0,0 +1,16 @@
package modules
var (
exitStatusCode int
)
// SetExitStatusCode sets the exit code that the program shall return to the host after shutdown.
func SetExitStatusCode(n int) {
exitStatusCode = n
}
// GetExitStatusCode waits for the shutdown to complete and then returns the exit code
func GetExitStatusCode() int {
<-shutdownCompleteSignal
return exitStatusCode
}
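SetExitStatusCode and GetExitStatusCode tie the module shutdown sequence to the process exit code. A minimal sketch of how a main function might use them; the error code value is made up:

package main

import (
	"os"
	"os/signal"

	"github.com/safing/portbase/modules"
)

func main() {
	if err := modules.Start(); err != nil {
		modules.SetExitStatusCode(1) // hypothetical error exit code
		_ = modules.Shutdown()
		os.Exit(modules.GetExitStatusCode())
	}

	// wait for an interrupt, then shut down
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt)
	<-sig
	_ = modules.Shutdown()

	// GetExitStatusCode blocks until the shutdown has completed.
	os.Exit(modules.GetExitStatusCode())
}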


@ -45,14 +45,54 @@ func SetMaxConcurrentMicroTasks(n int) {
 	}
 }

-// StartMicroTask starts a new MicroTask with high priority. It will start immediately. The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
-func (m *Module) StartMicroTask(name *string, fn func(context.Context) error) error {
-	atomic.AddInt32(microTasks, 1)
+// StartMicroTask starts a new MicroTask with high priority. It will start immediately. The call starts a new goroutine and returns immediately. The given function will be executed and panics caught. The supplied name must not be changed.
+func (m *Module) StartMicroTask(name *string, fn func(context.Context) error) {
+	go func() {
+		err := m.RunMicroTask(name, fn)
+		if err != nil {
+			log.Warningf("%s: microtask %s failed: %s", m.Name, *name, err)
+		}
+	}()
+}
+
+// StartMediumPriorityMicroTask starts a new MicroTask with medium priority. The call starts a new goroutine and returns immediately. It will wait until a slot becomes available (max 3 seconds). The given function will be executed and panics caught. The supplied name must not be changed.
+func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Context) error) {
+	go func() {
+		err := m.RunMediumPriorityMicroTask(name, fn)
+		if err != nil {
+			log.Warningf("%s: microtask %s failed: %s", m.Name, *name, err)
+		}
+	}()
+}
+
+// StartLowPriorityMicroTask starts a new MicroTask with low priority. The call starts a new goroutine and returns immediately. It will wait until a slot becomes available (max 15 seconds). The given function will be executed and panics caught. The supplied name must not be changed.
+func (m *Module) StartLowPriorityMicroTask(name *string, fn func(context.Context) error) {
+	go func() {
+		err := m.RunLowPriorityMicroTask(name, fn)
+		if err != nil {
+			log.Warningf("%s: microtask %s failed: %s", m.Name, *name, err)
+		}
+	}()
+}
+
+// RunMicroTask runs a new MicroTask with high priority. It will start immediately. The call blocks until finished. The given function will be executed and panics caught. The supplied name must not be changed.
+func (m *Module) RunMicroTask(name *string, fn func(context.Context) error) error {
+	if m == nil {
+		log.Errorf(`modules: cannot start microtask "%s" with nil module`, *name)
+		return errNoModule
+	}
+
+	atomic.AddInt32(microTasks, 1) // increase global counter here, as high priority tasks are not started by the scheduler, where this counter is usually increased
 	return m.runMicroTask(name, fn)
 }

-// StartMediumPriorityMicroTask starts a new MicroTask with medium priority. It will wait until given a go (max 3 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
-func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Context) error) error {
+// RunMediumPriorityMicroTask runs a new MicroTask with medium priority. It will wait until a slot becomes available (max 3 seconds). The call blocks until finished. The given function will be executed and panics caught. The supplied name must not be changed.
+func (m *Module) RunMediumPriorityMicroTask(name *string, fn func(context.Context) error) error {
+	if m == nil {
+		log.Errorf(`modules: cannot start microtask "%s" with nil module`, *name)
+		return errNoModule
+	}
+
 	// check if we can go immediately
 	select {
 	case <-mediumPriorityClearance:
@ -66,8 +106,13 @@ func (m *Module) StartMediumPriorityMicroTask(name *string, fn func(context.Cont
 	return m.runMicroTask(name, fn)
 }

-// StartLowPriorityMicroTask starts a new MicroTask with low priority. It will wait until given a go (max 15 seconds). The given function will be executed and panics caught. The supplied name should be a constant - the variable should never change as it won't be copied.
-func (m *Module) StartLowPriorityMicroTask(name *string, fn func(context.Context) error) error {
+// RunLowPriorityMicroTask runs a new MicroTask with low priority. It will wait until a slot becomes available (max 15 seconds). The call blocks until finished. The given function will be executed and panics caught. The supplied name must not be changed.
+func (m *Module) RunLowPriorityMicroTask(name *string, fn func(context.Context) error) error {
+	if m == nil {
+		log.Errorf(`modules: cannot start microtask "%s" with nil module`, *name)
+		return errNoModule
+	}
+
 	// check if we can go immediately
 	select {
 	case <-lowPriorityClearance:
@ -94,7 +139,7 @@ func (m *Module) runMicroTask(name *string, fn func(context.Context) error) (err
 	if panicVal != nil {
 		me := m.NewPanicError(*name, "microtask", panicVal)
 		me.Report()
-		log.Errorf("%s: microtask %s panicked: %s\n%s", m.Name, *name, panicVal, me.StackTrace)
+		log.Errorf("%s: microtask %s panicked: %s", m.Name, *name, panicVal)
 		err = me
 	}
@ -150,7 +195,10 @@ microTaskManageLoop:
 			atomic.AddInt32(microTasks, 1)
 		} else {
 			// wait for signal that a task was completed
-			<-microTaskFinished
+			select {
+			case <-microTaskFinished:
+			case <-time.After(1 * time.Second):
+			}
 		}
 	}
 }

@ -42,7 +42,7 @@ func TestMicroTaskWaiting(t *testing.T) {
 	go func() {
 		defer mtwWaitGroup.Done()
 		// exec at slot 1
-		_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+		_ = mtModule.RunMicroTask(&mtTestName, func(ctx context.Context) error {
 			mtwOutputChannel <- "1" // slot 1
 			time.Sleep(mtwSleepDuration * 5)
 			mtwOutputChannel <- "2" // slot 5
@ -53,7 +53,7 @@ func TestMicroTaskWaiting(t *testing.T) {
 	time.Sleep(mtwSleepDuration * 1)

 	// clear clearances
-	_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+	_ = mtModule.RunMicroTask(&mtTestName, func(ctx context.Context) error {
 		return nil
 	})
@ -61,7 +61,7 @@ func TestMicroTaskWaiting(t *testing.T) {
 	go func() {
 		defer mtwWaitGroup.Done()
 		// exec at slot 2
-		_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+		_ = mtModule.RunLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
 			mtwOutputChannel <- "7" // slot 16
 			return nil
 		})
@ -74,7 +74,7 @@ func TestMicroTaskWaiting(t *testing.T) {
 		defer mtwWaitGroup.Done()
 		time.Sleep(mtwSleepDuration * 8)
 		// exec at slot 10
-		_ = mtModule.StartMicroTask(&mtTestName, func(ctx context.Context) error {
+		_ = mtModule.RunMicroTask(&mtTestName, func(ctx context.Context) error {
 			mtwOutputChannel <- "4" // slot 10
 			time.Sleep(mtwSleepDuration * 5)
 			mtwOutputChannel <- "6" // slot 15
@ -86,7 +86,7 @@ func TestMicroTaskWaiting(t *testing.T) {
 	go func() {
 		defer mtwWaitGroup.Done()
 		// exec at slot 3
-		_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+		_ = mtModule.RunMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
 			mtwOutputChannel <- "3" // slot 6
 			time.Sleep(mtwSleepDuration * 7)
 			mtwOutputChannel <- "5" // slot 13
@ -122,7 +122,7 @@ var mtoWaitCh chan struct{}
 func mediumPrioTaskTester() {
 	defer mtoWaitGroup.Done()
 	<-mtoWaitCh
-	_ = mtModule.StartMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+	_ = mtModule.RunMediumPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
 		mtoOutputChannel <- "1"
 		time.Sleep(2 * time.Millisecond)
 		return nil
@ -132,7 +132,7 @@ func mediumPrioTaskTester() {
 func lowPrioTaskTester() {
 	defer mtoWaitGroup.Done()
 	<-mtoWaitCh
-	_ = mtModule.StartLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
+	_ = mtModule.RunLowPriorityMicroTask(&mtTestName, func(ctx context.Context) error {
 		mtoOutputChannel <- "2"
 		time.Sleep(2 * time.Millisecond)
 		return nil


@ -13,7 +13,7 @@ import (
) )
var ( var (
modulesLock sync.Mutex modulesLock sync.RWMutex
modules = make(map[string]*Module) modules = make(map[string]*Module)
// ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag. // ErrCleanExit is returned by Start() when the program is interrupted before starting. This can happen for example, when using the "--help" flag.
@ -46,6 +46,10 @@ type Module struct {
microTaskCnt *int32 microTaskCnt *int32
waitGroup sync.WaitGroup waitGroup sync.WaitGroup
// events
eventHooks map[string][]*eventHook
eventHooksLock sync.RWMutex
// dependency mgmt // dependency mgmt
depNames []string depNames []string
depModules []*Module depModules []*Module
@ -67,15 +71,25 @@ func (m *Module) shutdown() error {
m.shutdownFlag.Set() m.shutdownFlag.Set()
m.cancelCtx() m.cancelCtx()
// start shutdown function
m.waitGroup.Add(1)
stopFnError := make(chan error, 1)
go func() {
stopFnError <- m.runCtrlFn("stop module", m.stop)
m.waitGroup.Done()
}()
// wait for workers // wait for workers
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
m.waitGroup.Wait() m.waitGroup.Wait()
close(done) close(done)
}() }()
// wait for results
select { select {
case <-done: case <-done:
case <-time.After(3 * time.Second): case <-time.After(30 * time.Second):
log.Warningf( log.Warningf(
"%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...", "%s: timed out while waiting for workers/tasks to finish: workers=%d tasks=%d microtasks=%d, continuing shutdown...",
m.Name, m.Name,
@ -85,12 +99,17 @@ func (m *Module) shutdown() error {
) )
} }
// call shutdown function // collect error
return m.stop() select {
} case err := <-stopFnError:
return err
func dummyAction() error { default:
return nil log.Warningf(
"%s: timed out while waiting for stop function to finish, continuing shutdown...",
m.Name,
)
return nil
}
} }
// Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished. // Register registers a new module. The control functions `prep`, `start` and `stop` are technically optional. `stop` is called _after_ all added module workers finished.
@ -99,7 +118,14 @@ func Register(name string, prep, start, stop func() error, dependencies ...strin
modulesLock.Lock() modulesLock.Lock()
defer modulesLock.Unlock() defer modulesLock.Unlock()
// check for already existing module
_, ok := modules[name]
if ok {
panic(fmt.Sprintf("modules: module %s is already registered", name))
}
// add new module
modules[name] = newModule modules[name] = newModule
return newModule return newModule
} }
@ -125,20 +151,10 @@ func initNewModule(name string, prep, start, stop func() error, dependencies ...
prep: prep, prep: prep,
start: start, start: start,
stop: stop, stop: stop,
eventHooks: make(map[string][]*eventHook),
depNames: dependencies, depNames: dependencies,
} }
// replace nil arguments with dummy action
if newModule.prep == nil {
newModule.prep = dummyAction
}
if newModule.start == nil {
newModule.start = dummyAction
}
if newModule.stop == nil {
newModule.stop = dummyAction
}
return newModule return newModule
} }
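For orientation, a minimal consumer of the Register API above might look like this; the "example" module, its dependency on "base", and all function bodies are hypothetical and only sketch the intended flow (prep must finish within 10s, start within 60s, and shutdown now waits up to 30s for workers while the stop function runs concurrently):

package example

import (
    "context"

    "github.com/safing/portbase/log"
    "github.com/safing/portbase/modules"
)

var module *modules.Module

func init() {
    // registering the same name twice now panics, so this must run exactly once
    module = modules.Register("example", prep, start, stop, "base")
}

func prep() error { return nil }

func start() error {
    // long-running work belongs in a worker, not in start itself
    module.StartServiceWorker("example main loop", 0, func(ctx context.Context) error {
        <-ctx.Done() // the module context is canceled on shutdown
        return nil
    })
    return nil
}

func stop() error {
    log.Infof("example: stopping")
    return nil
}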


@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"runtime" "runtime"
"time"
"github.com/safing/portbase/log" "github.com/safing/portbase/log"
"github.com/tevino/abool" "github.com/tevino/abool"
@ -26,8 +27,8 @@ func WaitForStartCompletion() <-chan struct{} {
// Start starts all modules in the correct order. In case of an error, it will automatically shutdown again. // Start starts all modules in the correct order. In case of an error, it will automatically shutdown again.
func Start() error { func Start() error {
modulesLock.Lock() modulesLock.RLock()
defer modulesLock.Unlock() defer modulesLock.RUnlock()
// start microtask scheduler // start microtask scheduler
go microTaskScheduler() go microTaskScheduler()
@ -106,7 +107,11 @@ func prepareModules() error {
go func() { go func() {
reports <- &report{ reports <- &report{
module: execM, module: execM,
err: execM.prep(), err: execM.runCtrlFnWithTimeout(
"prep module",
10*time.Second,
execM.prep,
),
} }
}() }()
} }
@ -154,7 +159,11 @@ func startModules() error {
go func() { go func() {
reports <- &report{ reports <- &report{
module: execM, module: execM,
err: execM.start(), err: execM.runCtrlFnWithTimeout(
"start module",
60*time.Second,
execM.start,
),
} }
}() }()
} }


@ -12,6 +12,8 @@ import (
var ( var (
shutdownSignal = make(chan struct{}) shutdownSignal = make(chan struct{})
shutdownSignalClosed = abool.NewBool(false) shutdownSignalClosed = abool.NewBool(false)
shutdownCompleteSignal = make(chan struct{})
) )
// ShuttingDown returns a channel read on the global shutdown signal. // ShuttingDown returns a channel read on the global shutdown signal.
@ -45,6 +47,7 @@ func Shutdown() error {
} }
log.Shutdown() log.Shutdown()
close(shutdownCompleteSignal)
return err return err
} }


@ -3,6 +3,7 @@ package modules
import ( import (
"container/list" "container/list"
"context" "context"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -59,6 +60,10 @@ const (
// NewTask creates a new task with a descriptive name (non-unique), an optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed. // NewTask creates a new task with a descriptive name (non-unique), an optional deadline, and the task function to be executed. You must call one of Queue, Prioritize, StartASAP, Schedule or Repeat in order to have the Task executed.
func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task { func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task {
if m == nil {
log.Errorf(`modules: cannot create task "%s" with nil module`, name)
}
return &Task{ return &Task{
name: name, name: name,
module: m, module: m,
@ -68,6 +73,10 @@ func (m *Module) NewTask(name string, fn func(context.Context, *Task)) *Task {
} }
func (t *Task) isActive() bool { func (t *Task) isActive() bool {
if t.module == nil {
return false
}
return !t.canceled && !t.module.ShutdownInProgress() return !t.canceled && !t.module.ShutdownInProgress()
} }
@ -178,6 +187,7 @@ func (t *Task) Repeat(interval time.Duration) *Task {
t.lock.Lock() t.lock.Lock()
t.repeat = interval t.repeat = interval
t.executeAt = time.Now().Add(t.repeat) t.executeAt = time.Now().Add(t.repeat)
t.addToSchedule()
t.lock.Unlock() t.lock.Unlock()
return t return t
@ -194,6 +204,10 @@ func (t *Task) Cancel() {
} }
func (t *Task) runWithLocking() { func (t *Task) runWithLocking() {
if t.module == nil {
return
}
// wait for good timeslot regarding microtasks // wait for good timeslot regarding microtasks
select { select {
case <-taskTimeslot: case <-taskTimeslot:
@ -308,6 +322,7 @@ func (t *Task) getExecuteAtWithLocking() time.Time {
func (t *Task) addToSchedule() { func (t *Task) addToSchedule() {
scheduleLock.Lock() scheduleLock.Lock()
defer scheduleLock.Unlock() defer scheduleLock.Unlock()
// defer printTaskList(taskSchedule) // for debugging
// notify scheduler // notify scheduler
defer func() { defer func() {
@ -439,3 +454,23 @@ func taskScheduleHandler() {
} }
} }
} }
func printTaskList(tasks *list.List) { //nolint:unused,deadcode // for debugging, NOT production use
fmt.Println("Modules Task List:")
for e := tasks.Front(); e != nil; e = e.Next() {
t, ok := e.Value.(*Task)
if ok {
fmt.Printf(
"%s:%s qu=%v ca=%v exec=%v at=%s rep=%s delay=%s\n",
t.module.Name,
t.name,
t.queued,
t.canceled,
t.executing,
t.executeAt,
t.repeat,
t.maxDelay,
)
}
}
}
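As a usage sketch for the task API above (the module value and task name are hypothetical): Repeat now adds the task to the schedule itself, so a single chained call is enough for a recurring task.

package example

import (
    "context"
    "time"

    "github.com/safing/portbase/modules"
)

// scheduleCleanup is a hypothetical helper showing the NewTask/Repeat flow.
func scheduleCleanup(module *modules.Module) {
    module.NewTask("cleanup", func(ctx context.Context, t *modules.Task) {
        // periodic work goes here; check ctx to stop early on shutdown
    }).Repeat(10 * time.Minute)
}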


@ -2,6 +2,8 @@ package modules
import ( import (
"context" "context"
"errors"
"fmt"
"sync/atomic" "sync/atomic"
"time" "time"
@ -13,8 +15,29 @@ const (
DefaultBackoffDuration = 2 * time.Second DefaultBackoffDuration = 2 * time.Second
) )
var (
// ErrRestartNow may be returned (wrapped) by service workers to request an immediate restart.
ErrRestartNow = errors.New("requested restart")
errNoModule = errors.New("missing module (is nil!)")
)
// StartWorker directly starts a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. A call to StartWorker starts a new goroutine and returns immediately.
func (m *Module) StartWorker(name string, fn func(context.Context) error) {
go func() {
err := m.RunWorker(name, fn)
if err != nil {
log.Warningf("%s: worker %s failed: %s", m.Name, name, err)
}
}()
}
// RunWorker directly runs a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. A call to RunWorker blocks until the worker is finished. // RunWorker directly runs a generic worker that does not fit to be a Task or MicroTask, such as long running (and possibly mostly idle) sessions. A call to RunWorker blocks until the worker is finished.
func (m *Module) RunWorker(name string, fn func(context.Context) error) error { func (m *Module) RunWorker(name string, fn func(context.Context) error) error {
if m == nil {
log.Errorf(`modules: cannot start worker "%s" with nil module`, name)
return errNoModule
}
atomic.AddInt32(m.workerCnt, 1) atomic.AddInt32(m.workerCnt, 1)
m.waitGroup.Add(1) m.waitGroup.Add(1)
defer func() { defer func() {
@ -27,6 +50,11 @@ func (m *Module) RunWorker(name string, fn func(context.Context) error) error {
// StartServiceWorker starts a generic worker, which is automatically restarted in case of an error. A call to StartServiceWorker runs the service-worker in a new goroutine and returns immediately. `backoffDuration` specifies how long to wait before restarts, multiplied by the number of failed attempts. Pass `0` for the default backoff duration. For custom error remediation functionality, build your own error handling procedure using calls to RunWorker. // StartServiceWorker starts a generic worker, which is automatically restarted in case of an error. A call to StartServiceWorker runs the service-worker in a new goroutine and returns immediately. `backoffDuration` specifies how long to wait before restarts, multiplied by the number of failed attempts. Pass `0` for the default backoff duration. For custom error remediation functionality, build your own error handling procedure using calls to RunWorker.
func (m *Module) StartServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) { func (m *Module) StartServiceWorker(name string, backoffDuration time.Duration, fn func(context.Context) error) {
if m == nil {
log.Errorf(`modules: cannot start service worker "%s" with nil module`, name)
return
}
go m.runServiceWorker(name, backoffDuration, fn) go m.runServiceWorker(name, backoffDuration, fn)
} }
@ -42,6 +70,7 @@ func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn
backoffDuration = DefaultBackoffDuration backoffDuration = DefaultBackoffDuration
} }
failCnt := 0 failCnt := 0
lastFail := time.Now()
for { for {
if m.ShutdownInProgress() { if m.ShutdownInProgress() {
@ -50,11 +79,22 @@ func (m *Module) runServiceWorker(name string, backoffDuration time.Duration, fn
err := m.runWorker(name, fn) err := m.runWorker(name, fn)
if err != nil { if err != nil {
// log error and restart if !errors.Is(err, ErrRestartNow) {
failCnt++ // reset fail counter if running without error for some time
sleepFor := time.Duration(failCnt) * backoffDuration if time.Now().Add(-5 * time.Minute).After(lastFail) {
log.Errorf("%s: service-worker %s failed (%d): %s - restarting in %s", m.Name, name, failCnt, err, sleepFor) failCnt = 0
time.Sleep(sleepFor) }
// increase fail counter and set last failed time
failCnt++
lastFail = time.Now()
// log error
sleepFor := time.Duration(failCnt) * backoffDuration
log.Errorf("%s: service-worker %s failed (%d): %s - restarting in %s", m.Name, name, failCnt, err, sleepFor)
time.Sleep(sleepFor)
// loop to restart
} else {
log.Infof("%s: service-worker %s %s - restarting now", m.Name, name, err)
}
} else { } else {
// finish // finish
return return
@ -77,3 +117,39 @@ func (m *Module) runWorker(name string, fn func(context.Context) error) (err err
err = fn(m.Ctx) err = fn(m.Ctx)
return return
} }
func (m *Module) runCtrlFnWithTimeout(name string, timeout time.Duration, fn func() error) error {
stopFnError := make(chan error)
go func() {
stopFnError <- m.runCtrlFn(name, fn)
}()
// wait for results
select {
case err := <-stopFnError:
return err
case <-time.After(timeout):
return fmt.Errorf("timed out (%s)", timeout)
}
}
func (m *Module) runCtrlFn(name string, fn func() error) (err error) {
if fn == nil {
return
}
defer func() {
// recover from panic
panicVal := recover()
if panicVal != nil {
me := m.NewPanicError(name, "module-control", panicVal)
me.Report()
err = me
}
}()
// run
err = fn()
return
}
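A sketch of how a caller might use the new service-worker behavior (all names hypothetical): returning ErrRestartNow restarts the worker immediately, any other error restarts it after failCnt*backoffDuration, and the fail counter resets after roughly five minutes without failures.

package example

import (
    "context"
    "time"

    "github.com/safing/portbase/modules"
)

// startSession runs a hypothetical long-lived session as a service worker.
func startSession(module *modules.Module) {
    module.StartServiceWorker("session keeper", 10*time.Second, func(ctx context.Context) error {
        select {
        case <-ctx.Done():
            return nil // clean exit, no restart
        case <-time.After(time.Hour):
            return modules.ErrRestartNow // request an immediate restart
        }
    })
}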


@ -11,7 +11,7 @@ var (
) )
func init() { func init() {
modules.Register("notifications", nil, start, nil, "base", "database") module = modules.Register("notifications", nil, start, nil, "base", "database")
} }
func start() error { func start() error {


@ -38,6 +38,7 @@ func noise() {
} }
//nolint:gocognit
func main() { func main() {
// generates 1MB and writes to stdout // generates 1MB and writes to stdout

126
run/main.go Normal file

@ -0,0 +1,126 @@
package run
import (
"bufio"
"flag"
"fmt"
"os"
"os/signal"
"runtime/pprof"
"syscall"
"time"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
)
var (
printStackOnExit bool
enableInputSignals bool
sigUSR1 = syscall.Signal(0xa) // dummy for windows
)
func init() {
flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before shutting down")
flag.BoolVar(&enableInputSignals, "input-signals", false, "emulate signals using stdin")
}
// Run executes a full program lifecycle (including signal handling) based on modules. Just empty-import required packages and do os.Exit(run.Run()).
func Run() int {
// Start
err := modules.Start()
if err != nil {
if err == modules.ErrCleanExit {
return 0
}
_ = modules.Shutdown()
return modules.GetExitStatusCode()
}
// Shutdown
// catch interrupt for clean shutdown
signalCh := make(chan os.Signal, 1) // buffered, as expected by signal.Notify
if enableInputSignals {
go inputSignals(signalCh)
}
signal.Notify(
signalCh,
os.Interrupt,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
sigUSR1,
)
signalLoop:
for {
select {
case sig := <-signalCh:
// only print and continue to wait if SIGUSR1
if sig == sigUSR1 {
_ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 2)
continue signalLoop
}
fmt.Println(" <INTERRUPT>")
log.Warning("main: program was interrupted, shutting down.")
// catch signals during shutdown
go func() {
for {
<-signalCh
fmt.Println(" <INTERRUPT> again, but already shutting down")
}
}()
if printStackOnExit {
fmt.Println("=== PRINTING TRACES ===")
fmt.Println("=== GOROUTINES ===")
_ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 2)
fmt.Println("=== BLOCKING ===")
_ = pprof.Lookup("block").WriteTo(os.Stdout, 2)
fmt.Println("=== MUTEXES ===")
_ = pprof.Lookup("mutex").WriteTo(os.Stdout, 2)
fmt.Println("=== END TRACES ===")
}
go func() {
time.Sleep(60 * time.Second)
fmt.Fprintln(os.Stderr, "===== TAKING TOO LONG FOR SHUTDOWN - PRINTING STACK TRACES =====")
_ = pprof.Lookup("goroutine").WriteTo(os.Stderr, 2)
os.Exit(1)
}()
_ = modules.Shutdown()
break signalLoop
case <-modules.ShuttingDown():
break signalLoop
}
}
// wait for shutdown to complete, then exit
return modules.GetExitStatusCode()
}
func inputSignals(signalCh chan os.Signal) {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
switch scanner.Text() {
case "SIGHUP":
signalCh <- syscall.SIGHUP
case "SIGINT":
signalCh <- syscall.SIGINT
case "SIGQUIT":
signalCh <- syscall.SIGQUIT
case "SIGTERM":
signalCh <- syscall.SIGTERM
case "SIGUSR1":
signalCh <- sigUSR1
}
}
}
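Per the doc comment on Run, a program built on this package boils down to empty-importing its modules and exiting with Run's status code; a minimal sketch (the chosen modules are just examples, the import path for run is assumed from the file location):

package main

import (
    "os"

    "github.com/safing/portbase/run"

    // empty-import the modules that should be part of this program
    _ "github.com/safing/portbase/api"
    _ "github.com/safing/portbase/notifications"
)

func main() {
    os.Exit(run.Run())
}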

2
updater/doc.go Normal file

@ -0,0 +1,2 @@
// Package updater is an update registry that manages updates and versions.
package updater

15
updater/export.go Normal file

@ -0,0 +1,15 @@
package updater
// Export exports the list of resources. All resources must be locked when accessed.
func (reg *ResourceRegistry) Export() map[string]*Resource {
reg.RLock()
defer reg.RUnlock()
// copy the map
new := make(map[string]*Resource)
for key, val := range reg.resources {
new[key] = val
}
return new
}

120
updater/fetch.go Normal file

@ -0,0 +1,120 @@
package updater
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"time"
"github.com/google/renameio"
"github.com/safing/portbase/log"
)
func (reg *ResourceRegistry) fetchFile(rv *ResourceVersion, tries int) error {
// backoff when retrying
if tries > 0 {
time.Sleep(time.Duration(tries*tries) * time.Second)
}
// create URL
downloadURL, err := joinURLandPath(reg.UpdateURLs[tries%len(reg.UpdateURLs)], rv.versionedPath())
if err != nil {
return fmt.Errorf("error building url (%s + %s): %s", reg.UpdateURLs[tries%len(reg.UpdateURLs)], rv.versionedPath(), err)
}
// check destination dir
dirPath := filepath.Dir(rv.storagePath())
err = reg.storageDir.EnsureAbsPath(dirPath)
if err != nil {
return fmt.Errorf("could not create updates folder: %s", dirPath)
}
// open file for writing
atomicFile, err := renameio.TempFile(reg.tmpDir.Path, rv.storagePath())
if err != nil {
return fmt.Errorf("could not create temp file for download: %s", err)
}
defer atomicFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway
// start file download
resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose
if err != nil {
return fmt.Errorf("error fetching url (%s): %s", downloadURL, err)
}
defer resp.Body.Close()
// download and write file
n, err := io.Copy(atomicFile, resp.Body)
if err != nil {
return fmt.Errorf("failed downloading %s: %s", downloadURL, err)
}
if resp.ContentLength != n {
return fmt.Errorf("download unfinished, written %d out of %d bytes", n, resp.ContentLength)
}
// finalize file
err = atomicFile.CloseAtomicallyReplace()
if err != nil {
return fmt.Errorf("%s: failed to finalize file %s: %s", reg.Name, rv.storagePath(), err)
}
// set permissions
if !onWindows {
// TODO: only set executable files to 0755, set other to 0644
err = os.Chmod(rv.storagePath(), 0755)
if err != nil {
log.Warningf("%s: failed to set permissions on downloaded file %s: %s", reg.Name, rv.storagePath(), err)
}
}
log.Infof("%s: fetched %s (stored to %s)", reg.Name, downloadURL, rv.storagePath())
return nil
}
func (reg *ResourceRegistry) fetchData(downloadPath string, tries int) ([]byte, error) {
// backoff when retrying
if tries > 0 {
time.Sleep(time.Duration(tries*tries) * time.Second)
}
// create URL
downloadURL, err := joinURLandPath(reg.UpdateURLs[tries%len(reg.UpdateURLs)], downloadPath)
if err != nil {
return nil, fmt.Errorf("error building url (%s + %s): %s", reg.UpdateURLs[tries%len(reg.UpdateURLs)], downloadPath, err)
}
// start file download
resp, err := http.Get(downloadURL) //nolint:gosec // url is variable on purpose
if err != nil {
return nil, fmt.Errorf("error fetching url (%s): %s", downloadURL, err)
}
defer resp.Body.Close()
// download and write file
buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength))
n, err := io.Copy(buf, resp.Body)
if err != nil {
return nil, fmt.Errorf("failed downloading %s: %s", downloadURL, err)
}
if resp.ContentLength != n {
return nil, fmt.Errorf("download unfinished, written %d out of %d bytes", n, resp.ContentLength)
}
return buf.Bytes(), nil
}
func joinURLandPath(baseURL, urlPath string) (string, error) {
u, err := url.Parse(baseURL)
if err != nil {
return "", err
}
u.Path = path.Join(u.Path, urlPath)
return u.String(), nil
}
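To illustrate what joinURLandPath produces, a hypothetical in-package test sketch (the host name is made up):

package updater

import "testing"

func TestJoinURLandPath(t *testing.T) {
    got, err := joinURLandPath("https://updates.example.com/latest", "all/ui/assets_v1-2-3.zip")
    if err != nil {
        t.Fatal(err)
    }
    if got != "https://updates.example.com/latest/all/ui/assets_v1-2-3.zip" {
        t.Errorf("unexpected url: %s", got)
    }
}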

41
updater/file.go Normal file

@ -0,0 +1,41 @@
package updater
// File represents a file from the update system.
type File struct {
resource *Resource
version *ResourceVersion
notifier *notifier
versionedPath string
storagePath string
}
// Identifier returns the identifier of the file.
func (file *File) Identifier() string {
return file.resource.Identifier
}
// Version returns the version of the file.
func (file *File) Version() string {
return file.version.VersionNumber
}
// Path returns the absolute filepath of the file.
func (file *File) Path() string {
return file.storagePath
}
// Blacklist notifies the update system that this file is somehow broken, and should be ignored from now on, until restarted.
func (file *File) Blacklist() error {
return file.resource.Blacklist(file.version.VersionNumber)
}
// markActiveWithLocking marks the file as active.
func (file *File) markActiveWithLocking() {
file.resource.Lock()
defer file.resource.Unlock()
// update last used version
if file.resource.ActiveVersion != file.version {
file.resource.ActiveVersion = file.version
}
}

46
updater/filename.go Normal file

@ -0,0 +1,46 @@
package updater
import (
"fmt"
"regexp"
"strings"
)
var (
fileVersionRegex = regexp.MustCompile(`_v([0-9]+-[0-9]+-[0-9]+b?|0)`)
rawVersionRegex = regexp.MustCompile(`^([0-9]+\.[0-9]+\.[0-9]+b?\*?|0)$`)
)
// GetIdentifierAndVersion splits the given file path into its identifier and version.
func GetIdentifierAndVersion(versionedPath string) (identifier, version string, ok bool) {
// extract version
rawVersion := fileVersionRegex.FindString(versionedPath)
if rawVersion == "" {
return "", "", false
}
// replace - with . and trim _
version = strings.Replace(strings.TrimLeft(rawVersion, "_v"), "-", ".", -1)
// put together without version
i := strings.Index(versionedPath, rawVersion)
if i < 0 {
// extracted version not in string (impossible)
return "", "", false
}
return versionedPath[:i] + versionedPath[i+len(rawVersion):], version, true
}
// GetVersionedPath combines the identifier and version and returns it as a file path.
func GetVersionedPath(identifier, version string) (versionedPath string) {
// split in half
splittedFilePath := strings.SplitN(identifier, ".", 2)
// replace . with -
transformedVersion := strings.Replace(version, ".", "-", -1)
// put together
if len(splittedFilePath) == 1 {
return fmt.Sprintf("%s_v%s", splittedFilePath[0], transformedVersion)
}
return fmt.Sprintf("%s_v%s.%s", splittedFilePath[0], transformedVersion, splittedFilePath[1])
}
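A small round-trip sketch of the naming scheme using the two exported helpers (the values mirror the test data used elsewhere in this change):

package main

import (
    "fmt"

    "github.com/safing/portbase/updater"
)

func main() {
    // "all/ui/assets.zip" + "1.2.3" -> "all/ui/assets_v1-2-3.zip"
    vp := updater.GetVersionedPath("all/ui/assets.zip", "1.2.3")
    fmt.Println(vp)

    // and back again
    identifier, version, ok := updater.GetIdentifierAndVersion(vp)
    fmt.Println(identifier, version, ok) // all/ui/assets.zip 1.2.3 true
}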

53
updater/filename_test.go Normal file

@ -0,0 +1,53 @@
package updater
import (
"regexp"
"testing"
)
func testRegexMatch(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) {
if testRegex.MatchString(testString) != shouldMatch {
if shouldMatch {
t.Errorf("regex %s should match %s", testRegex, testString)
} else {
t.Errorf("regex %s should not match %s", testRegex, testString)
}
}
}
func testRegexFind(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) {
if (testRegex.FindString(testString) != "") != shouldMatch {
if shouldMatch {
t.Errorf("regex %s should find %s", testRegex, testString)
} else {
t.Errorf("regex %s should not find %s", testRegex, testString)
}
}
}
func TestRegexes(t *testing.T) {
testRegexMatch(t, rawVersionRegex, "0", true)
testRegexMatch(t, rawVersionRegex, "0.1.2", true)
testRegexMatch(t, rawVersionRegex, "0.1.2*", true)
testRegexMatch(t, rawVersionRegex, "0.1.2b", true)
testRegexMatch(t, rawVersionRegex, "0.1.2b*", true)
testRegexMatch(t, rawVersionRegex, "12.13.14", true)
testRegexMatch(t, rawVersionRegex, "v0.1.2", false)
testRegexMatch(t, rawVersionRegex, "0.", false)
testRegexMatch(t, rawVersionRegex, "0.1", false)
testRegexMatch(t, rawVersionRegex, "0.1.", false)
testRegexMatch(t, rawVersionRegex, ".1.2", false)
testRegexMatch(t, rawVersionRegex, ".1.", false)
testRegexMatch(t, rawVersionRegex, "012345", false)
testRegexFind(t, fileVersionRegex, "/path/to/file_v0", true)
testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3", true)
testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3.exe", true)
testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false)
testRegexFind(t, fileVersionRegex, "/path/to/file_v1.2.3", false)
testRegexFind(t, fileVersionRegex, "/path/to/file_1-2-3", false)
testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2", false)
testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false)
}

56
updater/get.go Normal file

@ -0,0 +1,56 @@
package updater
import (
"errors"
"fmt"
"github.com/safing/portbase/log"
)
// Errors
var (
ErrNotFound = errors.New("the requested file could not be found")
ErrNotAvailableLocally = errors.New("the requested file is not available locally")
)
// GetFile returns the selected (mostly newest) file with the given identifier or an error, if it fails.
func (reg *ResourceRegistry) GetFile(identifier string) (*File, error) {
reg.RLock()
res, ok := reg.resources[identifier]
reg.RUnlock()
if !ok {
return nil, ErrNotFound
}
file := res.GetFile()
// check if file is available locally
if file.version.Available {
file.markActiveWithLocking()
return file, nil
}
// check if online
if !reg.Online {
return nil, ErrNotAvailableLocally
}
// check download dir
err := reg.tmpDir.Ensure()
if err != nil {
return nil, fmt.Errorf("could not prepare tmp directory for download: %s", err)
}
// download file
log.Tracef("%s: starting download of %s", reg.Name, file.versionedPath)
for tries := 0; tries < 5; tries++ {
err = reg.fetchFile(file.version, tries)
if err != nil {
log.Tracef("%s: failed to download %s: %s, retrying (%d)", reg.Name, file.versionedPath, err, tries+1)
} else {
file.markActiveWithLocking()
return file, nil
}
}
log.Warningf("%s: failed to download %s: %s", reg.Name, file.versionedPath, err)
return nil, err
}
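A usage sketch for GetFile (the caller and registry setup are hypothetical; the identifier matches the test data in this change):

package example

import (
    "github.com/safing/portbase/log"
    "github.com/safing/portbase/updater"
)

// loadUIAssets fetches (or reuses) the selected version of a resource.
func loadUIAssets(registry *updater.ResourceRegistry) {
    file, err := registry.GetFile("all/ui/assets.zip")
    if err != nil {
        log.Errorf("ui: failed to get assets: %s", err)
        return
    }
    log.Infof("ui: using %s version %s at %s", file.Identifier(), file.Version(), file.Path())
    // if the file turns out to be broken, file.Blacklist() makes the registry
    // select a different version next time
}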

33
updater/notifier.go Normal file

@ -0,0 +1,33 @@
package updater
import (
"github.com/tevino/abool"
)
type notifier struct {
upgradeAvailable *abool.AtomicBool
notifyChannel chan struct{}
}
func newNotifier() *notifier {
return &notifier{
upgradeAvailable: abool.NewBool(false),
notifyChannel: make(chan struct{}),
}
}
func (n *notifier) markAsUpgradeable() {
if n.upgradeAvailable.SetToIf(false, true) {
close(n.notifyChannel)
}
}
// UpgradeAvailable returns whether an upgrade is available for this file.
func (file *File) UpgradeAvailable() bool {
return file.notifier.upgradeAvailable.IsSet()
}
// WaitForAvailableUpgrade blocks (selectable) until an upgrade for this file is available.
func (file *File) WaitForAvailableUpgrade() <-chan struct{} {
return file.notifier.notifyChannel
}
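A sketch of waiting for an upgrade from a worker (all names hypothetical); the channel is closed once a newer version than the active one gets selected:

package example

import (
    "context"

    "github.com/safing/portbase/log"
    "github.com/safing/portbase/modules"
    "github.com/safing/portbase/updater"
)

// watchForUpgrade reloads a component when a newer file version becomes available.
func watchForUpgrade(module *modules.Module, file *updater.File) {
    module.StartWorker("upgrade watcher", func(ctx context.Context) error {
        select {
        case <-ctx.Done():
        case <-file.WaitForAvailableUpgrade():
            log.Infof("upgrade available for %s, reloading", file.Identifier())
            // reload logic would go here
        }
        return nil
    })
}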

161
updater/registry.go Normal file

@ -0,0 +1,161 @@
package updater
import (
"os"
"runtime"
"sync"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)
const (
onWindows = runtime.GOOS == "windows"
)
// ResourceRegistry is a registry for managing update resources.
type ResourceRegistry struct {
sync.RWMutex
Name string
storageDir *utils.DirStructure
tmpDir *utils.DirStructure
resources map[string]*Resource
UpdateURLs []string
MandatoryUpdates []string
Beta bool
DevMode bool
Online bool
notifyHooks []func()
notifyHooksEnabled *abool.AtomicBool
}
// Initialize initializes a raw registry struct and makes it ready for usage.
func (reg *ResourceRegistry) Initialize(storageDir *utils.DirStructure) error {
// check if storage dir is available
err := storageDir.Ensure()
if err != nil {
return err
}
// set default name
if reg.Name == "" {
reg.Name = "updater"
}
// initialize private attributes
reg.storageDir = storageDir
reg.tmpDir = storageDir.ChildDir("tmp", 0700)
reg.resources = make(map[string]*Resource)
reg.notifyHooksEnabled = abool.NewBool(true)
return nil
}
// StorageDir returns the main storage dir of the resource registry.
func (reg *ResourceRegistry) StorageDir() *utils.DirStructure {
return reg.storageDir
}
// TmpDir returns the temporary working dir of the resource registry.
func (reg *ResourceRegistry) TmpDir() *utils.DirStructure {
return reg.tmpDir
}
// SetDevMode sets the development mode flag.
func (reg *ResourceRegistry) SetDevMode(on bool) {
reg.Lock()
defer reg.Unlock()
reg.DevMode = on
}
// SetBeta sets the beta flag.
func (reg *ResourceRegistry) SetBeta(on bool) {
reg.Lock()
defer reg.Unlock()
reg.Beta = on
}
// AddResource adds a resource to the registry. Does _not_ select new version.
func (reg *ResourceRegistry) AddResource(identifier, version string, available, stableRelease, betaRelease bool) error {
reg.Lock()
defer reg.Unlock()
err := reg.addResource(identifier, version, available, stableRelease, betaRelease)
return err
}
func (reg *ResourceRegistry) addResource(identifier, version string, available, stableRelease, betaRelease bool) error {
res, ok := reg.resources[identifier]
if !ok {
res = reg.newResource(identifier)
reg.resources[identifier] = res
}
return res.AddVersion(version, available, stableRelease, betaRelease)
}
// AddResources adds resources to the registry. Errors are logged, the last one is returned. Despite errors, non-failing resources are still added. Does _not_ select new versions.
func (reg *ResourceRegistry) AddResources(versions map[string]string, available, stableRelease, betaRelease bool) error {
reg.Lock()
defer reg.Unlock()
// add versions and their flags to registry
var lastError error
for identifier, version := range versions {
lastError = reg.addResource(identifier, version, available, stableRelease, betaRelease)
if lastError != nil {
log.Warningf("%s: failed to add resource %s: %s", reg.Name, identifier, lastError)
}
}
return lastError
}
// SelectVersions selects new resource versions depending on the current registry state.
func (reg *ResourceRegistry) SelectVersions() {
reg.RLock()
defer reg.RUnlock()
for _, res := range reg.resources {
res.Lock()
res.selectVersion()
res.Unlock()
}
}
// GetSelectedVersions returns a list of the currently selected versions.
func (reg *ResourceRegistry) GetSelectedVersions() (versions map[string]string) {
reg.RLock()
defer reg.RUnlock()
for _, res := range reg.resources {
res.Lock()
versions[res.Identifier] = res.SelectedVersion.VersionNumber
res.Unlock()
}
return
}
// Purge deletes old updates, retaining a certain amount, specified by the keep parameter. Will at least keep 2 updates per resource.
func (reg *ResourceRegistry) Purge(keep int) {
reg.RLock()
defer reg.RUnlock()
for _, res := range reg.resources {
res.Purge(keep)
}
}
// Cleanup removes temporary files.
func (reg *ResourceRegistry) Cleanup() error {
// delete download tmp dir
return os.RemoveAll(reg.tmpDir.Path)
}
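Setting up a registry might look roughly like this (name, URL, directory and mandatory update list are placeholders):

package example

import (
    "github.com/safing/portbase/updater"
    "github.com/safing/portbase/utils"
)

// newRegistry sketches the initialization of a ResourceRegistry.
func newRegistry() (*updater.ResourceRegistry, error) {
    registry := &updater.ResourceRegistry{
        Name:             "updates",
        UpdateURLs:       []string{"https://updates.example.com"},
        MandatoryUpdates: []string{"all/ui/assets.zip"},
        Beta:             false,
        DevMode:          false,
        Online:           true,
    }
    err := registry.Initialize(utils.NewDirStructure("/var/lib/example/updates", 0755))
    if err != nil {
        return nil, err
    }
    return registry, nil
}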

38
updater/registry_test.go Normal file

@ -0,0 +1,38 @@
package updater
import (
"io/ioutil"
"os"
"testing"
"github.com/safing/portbase/utils"
)
var (
registry *ResourceRegistry
)
func TestMain(m *testing.M) {
// setup
tmpDir, err := ioutil.TempDir("", "ci-portmaster-")
if err != nil {
panic(err)
}
registry = &ResourceRegistry{
Beta: true,
DevMode: true,
Online: true,
}
err = registry.Initialize(utils.NewDirStructure(tmpDir, 0777))
if err != nil {
panic(err)
}
// run
// call flag.Parse() here if TestMain uses flags
ret := m.Run()
// teardown
os.RemoveAll(tmpDir)
os.Exit(ret)
}

330
updater/resource.go Normal file

@ -0,0 +1,330 @@
package updater
import (
"errors"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/safing/portbase/log"
semver "github.com/hashicorp/go-version"
)
// Resource represents a resource (via an identifier) and multiple file versions.
type Resource struct {
sync.Mutex
registry *ResourceRegistry
notifier *notifier
Identifier string
Versions []*ResourceVersion
ActiveVersion *ResourceVersion
SelectedVersion *ResourceVersion
ForceDownload bool
}
// ResourceVersion represents a single version of a resource.
type ResourceVersion struct {
resource *Resource
VersionNumber string
semVer *semver.Version
Available bool
StableRelease bool
BetaRelease bool
Blacklisted bool
}
// Len is the number of elements in the collection. (sort.Interface for Versions)
func (res *Resource) Len() int {
return len(res.Versions)
}
// Less reports whether the element with index i should sort before the element with index j. (sort.Interface for Versions)
func (res *Resource) Less(i, j int) bool {
return res.Versions[i].semVer.GreaterThan(res.Versions[j].semVer)
}
// Swap swaps the elements with indexes i and j. (sort.Interface for Versions)
func (res *Resource) Swap(i, j int) {
res.Versions[i], res.Versions[j] = res.Versions[j], res.Versions[i]
}
// available returns whether any version of the resource is available.
func (res *Resource) available() bool {
for _, rv := range res.Versions {
if rv.Available {
return true
}
}
return false
}
func (reg *ResourceRegistry) newResource(identifier string) *Resource {
return &Resource{
registry: reg,
Identifier: identifier,
Versions: make([]*ResourceVersion, 0, 1),
}
}
// AddVersion adds a resource version to a resource.
func (res *Resource) AddVersion(version string, available, stableRelease, betaRelease bool) error {
res.Lock()
defer res.Unlock()
// reset stable or beta release flags
if stableRelease || betaRelease {
for _, rv := range res.Versions {
if stableRelease {
rv.StableRelease = false
}
if betaRelease {
rv.BetaRelease = false
}
}
}
var rv *ResourceVersion
// check for existing version
for _, possibleMatch := range res.Versions {
if possibleMatch.VersionNumber == version {
rv = possibleMatch
break
}
}
// create new version if none found
if rv == nil {
// parse to semver
sv, err := semver.NewVersion(version)
if err != nil {
return err
}
rv = &ResourceVersion{
resource: res,
VersionNumber: version,
semVer: sv,
}
res.Versions = append(res.Versions, rv)
}
// set flags
if available {
rv.Available = true
}
if stableRelease {
rv.StableRelease = true
}
if betaRelease {
rv.BetaRelease = true
}
return nil
}
// GetFile returns the selected version as a *File.
func (res *Resource) GetFile() *File {
res.Lock()
defer res.Unlock()
// check for notifier
if res.notifier == nil {
// create new notifier
res.notifier = newNotifier()
}
// check if version is selected
if res.SelectedVersion == nil {
res.selectVersion()
}
// create file
return &File{
resource: res,
version: res.SelectedVersion,
notifier: res.notifier,
versionedPath: res.SelectedVersion.versionedPath(),
storagePath: res.SelectedVersion.storagePath(),
}
}
//nolint:gocognit // function already kept as simple as possible
func (res *Resource) selectVersion() {
sort.Sort(res)
// export after we finish
defer func() {
if res.ActiveVersion != nil && // resource has already been used
res.SelectedVersion != res.ActiveVersion && // new selected version does not match previously selected version
res.notifier != nil {
res.notifier.markAsUpgradeable()
res.notifier = nil
}
}()
if len(res.Versions) == 0 {
// TODO: find better way to deal with an empty version slice (which should not happen)
res.SelectedVersion = nil
return
}
// Target selection
// 1) Dev release if dev mode is active and ignore blacklisting
if res.registry.DevMode {
// get last element
rv := res.Versions[len(res.Versions)-1]
// check if it's a dev version
if rv.VersionNumber == "0" && rv.Available {
res.SelectedVersion = rv
return
}
}
// 2) Beta release if beta is active
if res.registry.Beta {
for _, rv := range res.Versions {
if rv.BetaRelease {
if !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) {
res.SelectedVersion = rv
return
}
break
}
}
}
// 3) Stable release
for _, rv := range res.Versions {
if rv.StableRelease {
if !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) {
res.SelectedVersion = rv
return
}
break
}
}
// 4) Latest stable release
for _, rv := range res.Versions {
if !strings.HasSuffix(rv.VersionNumber, "b") && !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) {
res.SelectedVersion = rv
return
}
}
// 5) Latest of any type
for _, rv := range res.Versions {
if !rv.Blacklisted && (rv.Available || rv.resource.registry.Online) {
res.SelectedVersion = rv
return
}
}
// 6) Default to newest
res.SelectedVersion = res.Versions[0]
}
// Blacklist blacklists the specified version and selects a new version.
func (res *Resource) Blacklist(version string) error {
res.Lock()
defer res.Unlock()
// count already blacklisted entries
valid := 0
for _, rv := range res.Versions {
if rv.VersionNumber == "0" {
continue // ignore dev versions
}
if !rv.Blacklisted {
valid++
}
}
if valid <= 1 {
return errors.New("cannot blacklist last version") // last one, cannot blacklist!
}
// find version and blacklist
for _, rv := range res.Versions {
if rv.VersionNumber == version {
// blacklist and update
rv.Blacklisted = true
res.selectVersion()
return nil
}
}
return errors.New("could not find version")
}
// Purge deletes old updates, retaining a certain amount, specified by the keep parameter. Will at least keep 2 updates per resource. After purging, new versions will be selected.
func (res *Resource) Purge(keep int) {
res.Lock()
defer res.Unlock()
// safeguard
if keep < 2 {
keep = 2
}
// keep versions
var validVersions int
var skippedActiveVersion bool
var skippedSelectedVersion bool
var purgeFrom int
for i, rv := range res.Versions {
// continue to purging?
if validVersions >= keep && // skip at least <keep> versions
skippedActiveVersion && // skip until active version
skippedSelectedVersion { // skip until selected version
purgeFrom = i
break
}
// keep active version
if !skippedActiveVersion && rv == res.ActiveVersion {
skippedActiveVersion = true
}
// keep selected version
if !skippedSelectedVersion && rv == res.SelectedVersion {
skippedSelectedVersion = true
}
// count valid (not blacklisted) versions
if !rv.Blacklisted {
validVersions++
}
}
// check if there is anything to purge
if purgeFrom < keep || purgeFrom > len(res.Versions) {
return
}
// purge phase
for _, rv := range res.Versions[purgeFrom:] {
// delete
err := os.Remove(rv.storagePath())
if err != nil {
log.Warningf("%s: failed to purge old resource %s: %s", res.registry.Name, rv.storagePath(), err)
}
}
// remove entries of deleted files
res.Versions = res.Versions[purgeFrom:]
res.selectVersion()
}
func (rv *ResourceVersion) versionedPath() string {
return GetVersionedPath(rv.resource.Identifier, rv.VersionNumber)
}
func (rv *ResourceVersion) storagePath() string {
return filepath.Join(rv.resource.registry.storageDir.Path, filepath.FromSlash(rv.versionedPath()))
}

81
updater/resource_test.go Normal file

@ -0,0 +1,81 @@
package updater
import (
"testing"
)
func TestVersionSelection(t *testing.T) {
res := registry.newResource("test/a")
err := res.AddVersion("1.2.3", true, true, false)
if err != nil {
t.Fatal(err)
}
err = res.AddVersion("1.2.4b", true, false, true)
if err != nil {
t.Fatal(err)
}
err = res.AddVersion("1.2.2", true, false, false)
if err != nil {
t.Fatal(err)
}
err = res.AddVersion("1.2.5", false, true, false)
if err != nil {
t.Fatal(err)
}
err = res.AddVersion("0", true, false, false)
if err != nil {
t.Fatal(err)
}
registry.Online = true
registry.Beta = true
registry.DevMode = true
res.selectVersion()
if res.SelectedVersion.VersionNumber != "0" {
t.Errorf("selected version should be 0, not %s", res.SelectedVersion.VersionNumber)
}
registry.DevMode = false
res.selectVersion()
if res.SelectedVersion.VersionNumber != "1.2.4b" {
t.Errorf("selected version should be 1.2.4b, not %s", res.SelectedVersion.VersionNumber)
}
registry.Beta = false
res.selectVersion()
if res.SelectedVersion.VersionNumber != "1.2.5" {
t.Errorf("selected version should be 1.2.5, not %s", res.SelectedVersion.VersionNumber)
}
registry.Online = false
res.selectVersion()
if res.SelectedVersion.VersionNumber != "1.2.3" {
t.Errorf("selected version should be 1.2.3, not %s", res.SelectedVersion.VersionNumber)
}
f123 := res.GetFile()
f123.markActiveWithLocking()
err = res.Blacklist("1.2.3")
if err != nil {
t.Fatal(err)
}
if res.SelectedVersion.VersionNumber != "1.2.2" {
t.Errorf("selected version should be 1.2.2, not %s", res.SelectedVersion.VersionNumber)
}
if !f123.UpgradeAvailable() {
t.Error("upgrade should be available (flag)")
}
select {
case <-f123.WaitForAvailableUpgrade():
default:
t.Error("upgrade should be available (chan)")
}
t.Logf("resource: %+v", res)
for _, rv := range res.Versions {
t.Logf("version %s: %+v", rv.VersionNumber, rv)
}
}

159
updater/storage.go Normal file

@ -0,0 +1,159 @@
package updater
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)
// ScanStorage scans root within the storage dir and adds found resources to the registry. If an error occurred, it is logged and the last error is returned. Everything that was found despite errors is added to the registry anyway. Leave root empty to scan the full storage dir.
func (reg *ResourceRegistry) ScanStorage(root string) error {
var lastError error
// prep root
if root == "" {
root = reg.storageDir.Path
} else {
var err error
root, err = filepath.Abs(root)
if err != nil {
return err
}
if !strings.HasPrefix(root, reg.storageDir.Path) {
return errors.New("supplied scan root path not within storage")
}
}
// walk fs
_ = filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
lastError = fmt.Errorf("%s: could not read %s: %s", reg.Name, path, err)
log.Warning(lastError.Error())
return nil
}
// get relative path to storage
relativePath, err := filepath.Rel(reg.storageDir.Path, path)
if err != nil {
lastError = fmt.Errorf("%s: could not get relative path of %s: %s", reg.Name, path, err)
log.Warning(lastError.Error())
return nil
}
// ignore files in tmp dir
if strings.HasPrefix(relativePath, reg.tmpDir.Path) {
return nil
}
// convert to identifier and version
relativePath = filepath.ToSlash(relativePath)
identifier, version, ok := GetIdentifierAndVersion(relativePath)
if !ok {
// file does not conform to format
return nil
}
// save
err = reg.AddResource(identifier, version, true, false, false)
if err != nil {
lastError = fmt.Errorf("%s: could not add resource %s v%s: %s", reg.Name, identifier, version, err)
log.Warning(lastError.Error())
}
return nil
})
return lastError
}
// LoadIndexes loads the current release indexes from disk and will fetch a new version if not available and online.
func (reg *ResourceRegistry) LoadIndexes() error {
err := reg.loadIndexFile("stable.json", true, false)
if err != nil {
err = reg.downloadIndex("stable.json", true, false)
if err != nil {
return err
}
}
err = reg.loadIndexFile("beta.json", false, true)
if err != nil {
err = reg.downloadIndex("beta.json", false, true)
if err != nil {
return err
}
}
return nil
}
func (reg *ResourceRegistry) loadIndexFile(name string, stableRelease, betaRelease bool) error {
data, err := ioutil.ReadFile(filepath.Join(reg.storageDir.Path, name))
if err != nil {
return err
}
releases := make(map[string]string)
err = json.Unmarshal(data, &releases)
if err != nil {
return err
}
if len(releases) == 0 {
return fmt.Errorf("%s is empty", name)
}
err = reg.AddResources(releases, false, stableRelease, betaRelease)
if err != nil {
log.Warningf("%s: failed to add resource: %s", reg.Name, err)
}
return nil
}
// CreateSymlinks creates a directory structure with unversioned symlinks to the given updates list.
func (reg *ResourceRegistry) CreateSymlinks(symlinkRoot *utils.DirStructure) error {
err := os.RemoveAll(symlinkRoot.Path)
if err != nil {
return fmt.Errorf("failed to wipe symlink root: %s", err)
}
err = symlinkRoot.Ensure()
if err != nil {
return fmt.Errorf("failed to create symlink root: %s", err)
}
reg.RLock()
defer reg.RUnlock()
for _, res := range reg.resources {
if res.SelectedVersion == nil {
return fmt.Errorf("no selected version available for %s", res.Identifier)
}
targetPath := res.SelectedVersion.storagePath()
linkPath := filepath.Join(symlinkRoot.Path, filepath.FromSlash(res.Identifier))
linkPathDir := filepath.Dir(linkPath)
err = symlinkRoot.EnsureAbsPath(linkPathDir)
if err != nil {
return fmt.Errorf("failed to create dir for link: %s", err)
}
relativeTargetPath, err := filepath.Rel(linkPathDir, targetPath)
if err != nil {
return fmt.Errorf("failed to get relative target path: %s", err)
}
err = os.Symlink(relativeTargetPath, linkPath)
if err != nil {
return fmt.Errorf("failed to link %s: %s", res.Identifier, err)
}
}
return nil
}
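A hypothetical startup sequence tying LoadIndexes, ScanStorage, SelectVersions and CreateSymlinks together (function name and error handling are illustrative only):

package example

import (
    "github.com/safing/portbase/log"
    "github.com/safing/portbase/updater"
)

// loadLocalState loads indexes, registers what is already on disk and selects versions.
func loadLocalState(registry *updater.ResourceRegistry) error {
    if err := registry.LoadIndexes(); err != nil {
        return err
    }
    if err := registry.ScanStorage(""); err != nil {
        // only the last error is returned; partial results are registered anyway
        log.Warningf("updates: scan errors: %s", err)
    }
    registry.SelectVersions()
    return registry.CreateSymlinks(registry.StorageDir().ChildDir("latest", 0755))
}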

68
updater/storage_test.go Normal file

@ -0,0 +1,68 @@
package updater
/*
func testLoadLatestScope(t *testing.T, basePath, filePath, expectedIdentifier, expectedVersion string) {
fullPath := filepath.Join(basePath, filePath)
// create dir
dirPath := filepath.Dir(fullPath)
err := os.MkdirAll(dirPath, 0755)
if err != nil {
t.Fatalf("could not create test dir: %s\n", err)
return
}
// touch file
err = ioutil.WriteFile(fullPath, []byte{}, 0644)
if err != nil {
t.Fatalf("could not create test file: %s\n", err)
return
}
// run loadLatestScope
latest, err := ScanForLatest(basePath, true)
if err != nil {
t.Errorf("could not update latest: %s\n", err)
return
}
for key, val := range latest {
localUpdates[key] = val
}
// test result
version, ok := localUpdates[expectedIdentifier]
if !ok {
t.Errorf("identifier %s not in map", expectedIdentifier)
t.Errorf("current map: %v", localUpdates)
}
if version != expectedVersion {
t.Errorf("unexpected version for %s: %s", filePath, version)
}
}
func TestLoadLatestScope(t *testing.T) {
updatesLock.Lock()
defer updatesLock.Unlock()
tmpDir, err := ioutil.TempDir("", "testing_")
if err != nil {
t.Fatalf("could not create test dir: %s\n", err)
return
}
defer os.RemoveAll(tmpDir)
testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "1.2.3")
testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4b.zip", "all/ui/assets.zip", "1.2.4b")
testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-5.zip", "all/ui/assets.zip", "1.2.5")
testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "1.3.4")
testLoadLatestScope(t, tmpDir, "all/ui/assets_v2-3-4.zip", "all/ui/assets.zip", "2.3.4")
testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "2.3.4")
testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4.zip", "all/ui/assets.zip", "2.3.4")
testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "2.3.4")
testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "1.2.3")
testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v2-1-1", "os_platform/portmaster/portmaster", "2.1.1")
testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "2.1.1")
}
*/

125
updater/updating.go Normal file

@ -0,0 +1,125 @@
package updater
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"github.com/safing/portbase/utils"
"github.com/safing/portbase/log"
)
// UpdateIndexes downloads the current update indexes.
func (reg *ResourceRegistry) UpdateIndexes() error {
err := reg.downloadIndex("stable.json", true, false)
if err != nil {
return err
}
return reg.downloadIndex("beta.json", false, true)
}
func (reg *ResourceRegistry) downloadIndex(name string, stableRelease, betaRelease bool) error {
var err error
var data []byte
// download new index
for tries := 0; tries < 3; tries++ {
data, err = reg.fetchData(name, tries)
if err == nil {
break
}
}
if err != nil {
return fmt.Errorf("failed to download index %s: %s", name, err)
}
// parse
new := make(map[string]string)
err = json.Unmarshal(data, &new)
if err != nil {
return fmt.Errorf("failed to parse index %s: %s", name, err)
}
// check for content
if len(new) == 0 {
return fmt.Errorf("index %s is empty", name)
}
// add resources to registry
_ = reg.AddResources(new, false, stableRelease, betaRelease)
// save index
err = ioutil.WriteFile(filepath.Join(reg.storageDir.Path, name), data, 0644)
if err != nil {
log.Warningf("%s: failed to save updated index %s: %s", reg.Name, name, err)
}
log.Infof("%s: updated index %s", reg.Name, name)
return nil
}
// DownloadUpdates checks if updates are available and downloads updates of used components.
func (reg *ResourceRegistry) DownloadUpdates(ctx context.Context) error {
// create list of downloads
var toUpdate []*ResourceVersion
reg.RLock()
for _, res := range reg.resources {
res.Lock()
// check if we want to download
if res.ActiveVersion != nil || // resource is currently being used
res.available() || // resource was used in the past
utils.StringInSlice(reg.MandatoryUpdates, res.Identifier) { // resource is mandatory
// add all non-available and eligible versions to update queue
for _, rv := range res.Versions {
if !rv.Available && (rv.StableRelease || reg.Beta && rv.BetaRelease) {
toUpdate = append(toUpdate, rv)
}
}
}
res.Unlock()
}
reg.RUnlock()
// nothing to update
if len(toUpdate) == 0 {
log.Infof("%s: everything up to date", reg.Name)
return nil
}
// check download dir
err := reg.tmpDir.Ensure()
if err != nil {
return fmt.Errorf("could not prepare tmp directory for download: %s", err)
}
// download updates
log.Infof("%s: starting to download %d updates", reg.Name, len(toUpdate))
for _, rv := range toUpdate {
for tries := 0; tries < 3; tries++ {
err = reg.fetchFile(rv, tries)
if err == nil {
break
}
}
if err != nil {
log.Warningf("%s: failed to download %s version %s: %s", reg.Name, rv.resource.Identifier, rv.VersionNumber, err)
}
}
log.Infof("%s: finished downloading updates", reg.Name)
// remove tmp folder after we are finished
err = os.RemoveAll(reg.tmpDir.Path)
if err != nil {
log.Tracef("%s: failed to remove tmp dir %s after downloading updates: %s", reg.Name, reg.tmpDir.Path, err)
}
return nil
}
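A periodic update cycle built on the two functions above could look like this (function name and the keep count of 3 are arbitrary):

package example

import (
    "context"

    "github.com/safing/portbase/updater"
)

// updateCycle refreshes the indexes, downloads needed files and purges old versions.
func updateCycle(ctx context.Context, registry *updater.ResourceRegistry) error {
    if err := registry.UpdateIndexes(); err != nil {
        return err
    }
    if err := registry.DownloadUpdates(ctx); err != nil {
        return err
    }
    registry.SelectVersions()
    registry.Purge(3)
    return nil
}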

1
updater/uptool/.gitignore vendored Normal file

@ -0,0 +1 @@
uptool

37
updater/uptool/root.go Normal file

@ -0,0 +1,37 @@
package main
import (
"os"
"path/filepath"
"github.com/safing/portbase/updater"
"github.com/safing/portbase/utils"
"github.com/spf13/cobra"
)
var registry *updater.ResourceRegistry
var rootCmd = &cobra.Command{
Use: "uptool",
Short: "helper tool for the update process",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
return cmd.Usage()
},
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
absPath, err := filepath.Abs(args[0])
if err != nil {
return err
}
registry = &updater.ResourceRegistry{}
return registry.Initialize(utils.NewDirStructure(absPath, 0755))
},
SilenceUsage: true,
}
func main() {
if err := rootCmd.Execute(); err != nil {
os.Exit(1)
}
}

80
updater/uptool/scan.go Normal file

@ -0,0 +1,80 @@
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"github.com/spf13/cobra"
)
func init() {
rootCmd.AddCommand(scanCmd)
}
var scanCmd = &cobra.Command{
Use: "scan",
Short: "Scan the specified directory and print the result",
Args: cobra.ExactArgs(1),
RunE: scan,
}
func scan(cmd *cobra.Command, args []string) error {
err := scanStorage()
if err != nil {
return err
}
// export beta
data, err := json.MarshalIndent(exportSelected(true), "", " ")
if err != nil {
return err
}
// print
fmt.Println("beta:")
fmt.Println(string(data))
// export stable
data, err = json.MarshalIndent(exportSelected(false), "", " ")
if err != nil {
return err
}
// print
fmt.Println("\nstable:")
fmt.Println(string(data))
return nil
}
func scanStorage() error {
files, err := ioutil.ReadDir(registry.StorageDir().Path)
if err != nil {
return err
}
// scan "all" and all "os_platform" dirs
for _, file := range files {
if file.IsDir() && (file.Name() == "all" || strings.Contains(file.Name(), "_")) {
err := registry.ScanStorage(filepath.Join(registry.StorageDir().Path, file.Name()))
if err != nil {
return err
}
}
}
return nil
}
func exportSelected(beta bool) map[string]string {
registry.SetBeta(beta)
registry.SelectVersions()
export := registry.Export()
versions := make(map[string]string)
for _, rv := range export {
versions[rv.Identifier] = rv.SelectedVersion.VersionNumber
}
return versions
}

64
updater/uptool/update.go Normal file

@ -0,0 +1,64 @@
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"github.com/spf13/cobra"
)
func init() {
rootCmd.AddCommand(updateCmd)
}
var updateCmd = &cobra.Command{
Use: "update",
Short: "Update scans the specified directory and updates the index and symlink structure",
Args: cobra.ExactArgs(1),
RunE: update,
}
func update(cmd *cobra.Command, args []string) error {
err := scanStorage()
if err != nil {
return err
}
// export beta
data, err := json.MarshalIndent(exportSelected(true), "", " ")
if err != nil {
return err
}
// print
fmt.Println("beta:")
fmt.Println(string(data))
// write index
err = ioutil.WriteFile(filepath.Join(registry.StorageDir().Dir, "beta.json"), data, 0755)
if err != nil {
return err
}
// export stable
data, err = json.MarshalIndent(exportSelected(false), "", " ")
if err != nil {
return err
}
// print
fmt.Println("\nstable:")
fmt.Println(string(data))
// write index
err = ioutil.WriteFile(filepath.Join(registry.StorageDir().Dir, "stable.json"), data, 0755)
if err != nil {
return err
}
// create symlinks
err = registry.CreateSymlinks(registry.StorageDir().ChildDir("latest", 0755))
if err != nil {
return err
}
fmt.Println("\nstable symlinks created")
return nil
}
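Taken together, the uptool workflow appears to be: run `uptool scan <storage dir>` to preview the selected stable and beta versions, then `uptool update <storage dir>` to write beta.json and stable.json and regenerate the latest symlink tree (the storage directory argument is whatever path the registry is initialized with).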