Compare commits


No commits in common. "develop" and "v0.12.3" have entirely different histories.

245 changed files with 3001 additions and 6984 deletions

View file

@ -1,11 +0,0 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "daily"

View file

@ -1,40 +0,0 @@
# Configuration for Label Actions - https://github.com/dessant/label-actions
community support:
comment: |
Hey @{issue-author}, thank you for raising this issue with us.
After a first review we noticed that this does not seem to be a technical issue, but rather a configuration issue or general question about how Portmaster works.
Thus, we invite the community to help with configuration and/or answering these questions.
If you are in a hurry or haven't received an answer, a good place to ask is in [our Discord community](https://discord.gg/safing).
If your problem or question has been resolved or answered, please come back and give an update here for other users encountering the same and then close this issue.
If you are a paying subscriber and want this issue to be checked out by Safing, please send us a message [on Discord](https://discord.gg/safing) or [via Email](mailto:support@safing.io) with your username and the link to this issue, so we can prioritize accordingly.
needs debug info:
comment: |
Hey @{issue-author}, thank you for raising this issue with us.
After a first review we noticed that we will require the Debug Info for further investigation. However, you haven't supplied any Debug Info in your report.
Please [collect Debug Info](https://wiki.safing.io/en/FAQ/DebugInfo) from Portmaster _while_ the reported issue is present.
in/compatibility:
comment: |
Hey @{issue-author}, thank you for reporting on a compatibility issue.
We keep a list of compatible software and user provided guides for improving compatibility [in the wiki - please have a look there](https://wiki.safing.io/en/Portmaster/App/Compatibility).
If you can't find your software in the list, then a good starting point is our guide on [How do I make software compatible with Portmaster](https://wiki.safing.io/en/FAQ/MakeSoftwareCompatibleWithPortmaster).
If you have managed to establish compatibility with an application, please share your findings here. This will greatly help other users encountering the same issues.
fixed:
comment: |
This issue has been fixed by the recently referenced commit or PR.
However, the fix is not released yet.
It is expected to go into the [Beta Release Channel](https://wiki.safing.io/en/FAQ/SwitchReleaseChannel) for testing within the next two weeks and will be available for everyone within the next four weeks. While this is the typical timeline we work with, things are subject to change.

View file

@ -15,41 +15,73 @@ jobs:
name: Linter
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v3
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Setup Go
uses: actions/setup-go@v4
- uses: actions/setup-go@v2
with:
go-version: '^1.21'
go-version: '^1.15'
# nektos/act does not have sudo installed, but we need it on GH Actions, so
# try to install it.
- name: Install sudo
run: bash -c "apt-get update || true ; apt-get install sudo || true"
env:
DEBIAN_FRONTEND: noninteractive
- name: Install git and gcc
run: sudo bash -c "apt-get update && apt-get install -y git gcc libc6-dev"
env:
DEBIAN_FRONTEND: noninteractive
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v2
with:
version: v1.29
only-new-issues: true
args: -c ./.golangci.yml
skip-go-installation: true
- name: Get dependencies
run: go mod download
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.52.2
only-new-issues: true
args: -c ./.golangci.yml --timeout 15m
- name: Run go vet
run: go vet ./...
- name: Install golint
run: bash -c "GOBIN=$(pwd) go get -u golang.org/x/lint/golint"
- name: Run golint
run: ./golint -set_exit_status -min_confidence 1.0 ./...
- name: Run gofmt
run: bash -c "test -z $(gofmt -s -l .)"
test:
name: Test
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v3
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Setup Go
uses: actions/setup-go@v4
- uses: actions/setup-go@v2
with:
go-version: '^1.21'
go-version: '^1.15'
# nektos/act does not have sudo installed, but we need it on GH Actions, so
# try to install it.
- name: Install sudo
run: bash -c "apt-get update || true ; apt-get install sudo || true"
env:
DEBIAN_FRONTEND: noninteractive
- name: Install git and gcc
run: sudo bash -c "apt-get update && apt-get install -y git gcc libc6-dev"
env:
DEBIAN_FRONTEND: noninteractive
- name: Get dependencies
run: go mod download
- name: Run tests
- name: Test
run: ./test --test-only

View file

@ -1,26 +0,0 @@
# This workflow responds to first time posters with a greeting message.
# Docs: https://github.com/actions/first-interaction
name: Greet New Users
# This workflow is triggered when a new issue is created.
on:
issues:
types: opened
permissions:
contents: read
issues: write
jobs:
greet:
runs-on: ubuntu-latest
steps:
- uses: actions/first-interaction@v1
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
# Respond to first time issue raisers.
issue-message: |
Greetings and welcome to our community! As this is the first issue you opened here, we wanted to share some useful infos with you:
- 🗣️ Our community on [Discord](https://discord.gg/safing) is super helpful and active. We also have an AI-enabled support bot that knows Portmaster well and can give you immediate help.
- 📖 The [Wiki](https://wiki.safing.io/) answers all common questions and has many important details. If you can't find an answer there, let us know, so we can add anything that's missing.

View file

@ -1,22 +0,0 @@
# This workflow responds with a message when certain labels are added to an issue or PR.
# Docs: https://github.com/dessant/label-actions
name: Label Actions
# This workflow is triggered when a label is added to an issue.
on:
issues:
types: labeled
permissions:
contents: read
issues: write
jobs:
action:
runs-on: ubuntu-latest
steps:
- uses: dessant/label-actions@v3
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
config-path: ".github/label-actions.yml"
process-only: "issues"

View file

@ -1,42 +0,0 @@
# This workflow warns and then closes stale issues and PRs.
# Docs: https://github.com/actions/stale
name: Close Stale Issues
on:
schedule:
- cron: "17 5 * * 1-5" # run at 5:17 (UTC) on Monday to Friday
workflow_dispatch:
permissions:
contents: read
issues: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
# Increase max operations.
# When using GITHUB_TOKEN, the rate limit is 1,000 requests per hour per repository.
operations-per-run: 500
# Handle stale issues
stale-issue-label: 'stale'
# Exemptions
exempt-all-issue-assignees: true
exempt-issue-labels: 'support,dependencies,pinned,security'
# Mark as stale
days-before-issue-stale: 63 # 2 months / 9 weeks
stale-issue-message: |
This issue has been automatically marked as inactive because it has not had activity in the past two months.
If no further activity occurs, this issue will be automatically closed in one week in order to increase our focus on active topics.
# Close
days-before-issue-close: 7 # 1 week
close-issue-message: |
This issue has been automatically closed because it has not had recent activity. Thank you for your contributions.
If the issue has not been resolved, you can [find more information in our Wiki](https://wiki.safing.io/) or [continue the conversation on our Discord](https://discord.gg/safing).
# TODO: Handle stale PRs
days-before-pr-stale: 36500 # 100 years - effectively disabled.

.gitignore
View file

@ -4,5 +4,3 @@ misc
go.mod.*
vendor
go.work
go.work.sum

View file

@ -1,72 +1,19 @@
# Docs:
# https://golangci-lint.run/usage/linters/
linters:
enable-all: true
disable:
- containedctx
- contextcheck
- cyclop
- depguard
- exhaustivestruct
- exhaustruct
- forbidigo
- funlen
- gochecknoglobals
- gochecknoinits
- gocognit
- gocyclo
- goerr113
- gomnd
- ifshort
- interfacebloat
- interfacer
- ireturn
- lll
- musttag
- nestif
- nilnil
- nlreturn
- noctx
- nolintlint
- nonamedreturns
- nosnakecase
- revive
- tagliatelle
- testpackage
- varnamelen
- gochecknoinits
- gochecknoglobals
- funlen
- whitespace
- wrapcheck
- wsl
- gomnd
- goerr113
- testpackage
linters-settings:
revive:
# See https://github.com/mgechev/revive#available-rules for details.
enable-all-rules: true
gci:
# put imports beginning with prefix after 3rd-party packages;
# only support one prefix
# if not set, use goimports.local-prefixes
local-prefixes: github.com/safing
godox:
# report any comments starting with keywords, this is useful for TODO or FIXME comments that
# might be left in the code accidentally and should be resolved before merging
keywords:
- FIXME
gosec:
# To specify a set of rules to explicitly exclude.
# Available rules: https://github.com/securego/gosec#available-rules
excludes:
- G204 # Variables in commands.
- G304 # Variables in file paths.
- G505 # We need crypto/sha1 for non-security stuff. Using `nolint:` triggers another linter.
issues:
exclude-use-default: false
exclude-rules:
- text: "a blank import .*"
linters:
- golint
- text: "ST1000: at least one file in a package should have a package comment.*"
linters:
- stylecheck

View file

@ -37,7 +37,6 @@ type endpointBridgeStorage struct {
storage.InjectBase
}
// EndpointBridgeRequest holds a bridged request API request.
type EndpointBridgeRequest struct {
record.Base
sync.Mutex
@ -49,7 +48,6 @@ type EndpointBridgeRequest struct {
MimeType string
}
// EndpointBridgeResponse holds a bridged request API response.
type EndpointBridgeResponse struct {
record.Base
sync.Mutex

View file

@ -13,9 +13,9 @@ import (
"github.com/tevino/abool"
"github.com/safing/portbase/config"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/log"
"github.com/safing/portbase/rng"
)
@ -84,9 +84,8 @@ type AuthenticatorFunc func(r *http.Request, s *http.Server) (*AuthToken, error)
// later. Functions may be called at any time.
// The Write permission implicitly also includes reading.
type AuthToken struct {
Read Permission
Write Permission
ValidUntil *time.Time
Read Permission
Write Permission
}
type session struct {
@ -134,13 +133,16 @@ func SetAuthenticator(fn AuthenticatorFunc) error {
return nil
}
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler, readMethod bool) *AuthToken {
func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler) *AuthToken {
tracer := log.Tracer(r.Context())
// Check if request is read only.
readRequest := isReadMethod(r.Method)
// Get required permission for target handler.
requiredPermission := PermitSelf
if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok {
if readMethod {
if readRequest {
requiredPermission = authdHandler.ReadPermission(r)
} else {
requiredPermission = authdHandler.WritePermission(r)
@ -148,10 +150,10 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
}
// Check if we need to do any authentication at all.
switch requiredPermission { //nolint:exhaustive
switch requiredPermission {
case NotFound:
// Not found.
tracer.Debug("api: no API endpoint registered for this path")
tracer.Trace("api: authenticated handler reported: not found")
http.Error(w, "Not found.", http.StatusNotFound)
return nil
case NotSupported:
@ -198,7 +200,7 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
// Get effective permission for request.
var requestPermission Permission
if readMethod {
if readRequest {
requestPermission = token.Read
} else {
requestPermission = token.Write
@ -219,10 +221,7 @@ func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler h
if requestPermission < requiredPermission {
// If the token is strictly public, return an authentication request.
if token.Read == PermitAnyone && token.Write == PermitAnyone {
w.Header().Set(
"WWW-Authenticate",
`Bearer realm="Portmaster API" domain="/"`,
)
w.Header().Set("WWW-Authenticate", "Bearer realm=Portmaster API")
http.Error(w, "Authorization required.", http.StatusUnauthorized)
return nil
}
@ -342,12 +341,6 @@ func checkAPIKey(r *http.Request) *AuthToken {
return nil
}
// Abort if the token is expired.
if token.ValidUntil != nil && time.Now().After(*token.ValidUntil) {
log.Tracer(r.Context()).Warningf("api: denying api access from %s using expired token", r.RemoteAddr)
return nil
}
return token
}
@ -362,26 +355,15 @@ func updateAPIKeys(_ context.Context, _ interface{}) error {
delete(apiKeys, k)
}
// whether or not we found expired API keys that should be removed
// from the setting
hasExpiredKeys := false
// a list of valid API keys. Used when hasExpiredKeys is set to true.
// in that case we'll update the setting to only contain validAPIKeys
validAPIKeys := []string{}
// Parse new keys.
for _, key := range configuredAPIKeys() {
u, err := url.Parse(key)
if err != nil {
log.Errorf("api: failed to parse configured API key %s: %s", key, err)
continue
}
if u.Path == "" {
log.Errorf("api: malformed API key %s: missing path section", key)
continue
}
@ -408,40 +390,8 @@ func updateAPIKeys(_ context.Context, _ interface{}) error {
}
token.Write = writePermission
expireStr := q.Get("expires")
if expireStr != "" {
validUntil, err := time.Parse(time.RFC3339, expireStr)
if err != nil {
log.Errorf("api: invalid API key %s: %s", key, err)
continue
}
// continue to the next token if this one is already invalid
if time.Now().After(validUntil) {
// mark the key as expired so we'll remove it from the setting afterwards
hasExpiredKeys = true
continue
}
token.ValidUntil = &validUntil
}
// Save token.
apiKeys[u.Path] = token
validAPIKeys = append(validAPIKeys, key)
}
if hasExpiredKeys {
module.StartLowPriorityMicroTask("api key cleanup", 0, func(ctx context.Context) error {
if err := config.SetConfigOption(CfgAPIKeys, validAPIKeys); err != nil {
log.Errorf("api: failed to remove expired API keys: %s", err)
} else {
log.Infof("api: removed expired API keys from %s", CfgAPIKeys)
}
return nil
})
}
return nil
@ -527,24 +477,12 @@ func deleteSession(sessionKey string) {
delete(sessions, sessionKey)
}
func getEffectiveMethod(r *http.Request) (eMethod string, readMethod bool, ok bool) {
method := r.Method
// Get CORS request method if OPTIONS request.
if r.Method == http.MethodOptions {
method = r.Header.Get("Access-Control-Request-Method")
if method == "" {
return "", false, false
}
}
func isReadMethod(method string) bool {
switch method {
case http.MethodGet, http.MethodHead:
return http.MethodGet, true, true
case http.MethodPost, http.MethodPut, http.MethodDelete:
return method, false, true
case http.MethodGet, http.MethodHead, http.MethodOptions:
return true
default:
return "", false, false
return false
}
}
@ -593,8 +531,6 @@ func (p Permission) Role() string {
return "Admin"
case PermitSelf:
return "Self"
case Dynamic, NotFound, NotSupported:
return "Invalid"
default:
return "Invalid"
}
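
For orientation, the updateAPIKeys hunk above parses each configured key of the form `<key>?read=<perm>&write=<perm>` (optionally with an `expires` timestamp in RFC3339) by treating the entry as a URL. The standalone sketch below mirrors those parsing steps; the key value and expiry are made up for illustration, and the snippet does not map the permission strings to the package's Permission values.

package main

import (
	"fmt"
	"net/url"
	"time"
)

// parseAPIKey mirrors the parsing steps shown in updateAPIKeys above: the
// configured entry is parsed as a URL, the key itself ends up in the path,
// and read/write/expires come from the query string.
func parseAPIKey(entry string) error {
	u, err := url.Parse(entry)
	if err != nil {
		return fmt.Errorf("failed to parse configured API key: %w", err)
	}
	if u.Path == "" {
		return fmt.Errorf("malformed API key: missing path section")
	}

	q := u.Query()
	fmt.Printf("key=%q read=%q write=%q\n", u.Path, q.Get("read"), q.Get("write"))

	// Optional expiry, encoded as RFC3339; expired keys are skipped and, in
	// the variant above that tracks hasExpiredKeys, later removed from the
	// setting.
	if expireStr := q.Get("expires"); expireStr != "" {
		validUntil, err := time.Parse(time.RFC3339, expireStr)
		if err != nil {
			return fmt.Errorf("invalid expires value: %w", err)
		}
		if time.Now().After(validUntil) {
			fmt.Println("key is already expired")
		}
	}
	return nil
}

func main() {
	// Hypothetical example entry, not a real key.
	_ = parseAPIKey("my-secret-key?read=user&write=admin&expires=2025-01-01T00:00:00Z")
}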

View file

@ -9,7 +9,9 @@ import (
"github.com/stretchr/testify/assert"
)
var testToken = new(AuthToken)
var (
testToken = new(AuthToken)
)
func testAuthenticator(r *http.Request, s *http.Server) (*AuthToken, error) {
switch {
@ -63,9 +65,7 @@ func init() {
}
}
func TestPermissions(t *testing.T) {
t.Parallel()
func TestPermissions(t *testing.T) { //nolint:gocognit
testHandler := &mainHandler{
mux: mainMux,
}
@ -99,11 +99,10 @@ func TestPermissions(t *testing.T) {
http.MethodHead,
http.MethodPost,
http.MethodPut,
http.MethodDelete,
} {
// Set request permission for test requests.
_, reading, _ := getEffectiveMethod(&http.Request{Method: method})
reading := isReadMethod(method)
if reading {
testToken.Read = requestPerm
testToken.Write = NotSupported
@ -148,6 +147,7 @@ func TestPermissions(t *testing.T) {
}
if expectSuccess {
// Test for success.
if !assert.HTTPBodyContains(
t,
@ -164,7 +164,9 @@ func TestPermissions(t *testing.T) {
handlerPerm, handlerPerm,
)
}
} else {
// Test for error.
if !assert.HTTPError(t,
testHandler.ServeHTTP,
@ -179,6 +181,7 @@ func TestPermissions(t *testing.T) {
handlerPerm, handlerPerm,
)
}
}
}
}
@ -186,8 +189,6 @@ func TestPermissions(t *testing.T) {
}
func TestPermissionDefinitions(t *testing.T) {
t.Parallel()
if NotSupported != 0 {
t.Fatalf("NotSupported must be zero, was %v", NotSupported)
}

View file

@ -5,9 +5,9 @@ import (
"sync"
"time"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
)
const (

View file

@ -25,4 +25,6 @@ const (
apiSeperator = "|"
)
var apiSeperatorBytes = []byte(apiSeperator)
var (
apiSeperatorBytes = []byte(apiSeperator)
)

View file

@ -4,14 +4,15 @@ import (
"bytes"
"errors"
"github.com/tevino/abool"
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/tevino/abool"
)
// ErrMalformedMessage is returned when a malformed message was encountered.
var ErrMalformedMessage = errors.New("malformed message")
// Client errors.
var (
ErrMalformedMessage = errors.New("malformed message")
)
// Message is an API message.
type Message struct {

View file

@ -4,10 +4,10 @@ import (
"fmt"
"sync"
"github.com/gorilla/websocket"
"github.com/safing/portbase/log"
"github.com/tevino/abool"
"github.com/safing/portbase/log"
"github.com/gorilla/websocket"
)
type wsState struct {
@ -41,7 +41,7 @@ func (c *Client) wsConnect() error {
case <-c.shutdownSignal:
state.Error("")
}
_ = state.wsConn.Close()
state.wsConn.Close()
state.wg.Wait()
return nil

View file

@ -64,8 +64,7 @@ func registerConfig() error {
err = config.Register(&config.Option{
Name: "API Keys",
Key: CfgAPIKeys,
Description: "Define API keys for privileged access to the API. Every entry is a separate API key with respective permissions. Format is `<key>?read=<perm>&write=<perm>`. Permissions are `anyone`, `user` and `admin`, and may be omitted.",
Sensitive: true,
Description: "Define API keys for priviledged access to the API. Every entry is a separate API key with respective permissions. Format is `<key>?read=<perm>&write=<perm>`. Permissions are `anyone`, `user` and `admin`, and may be omitted.",
OptType: config.OptTypeStringArray,
ExpertiseLevel: config.ExpertiseLevelDeveloper,
ReleaseLevel: config.ReleaseLevelStable,

View file

@ -8,18 +8,19 @@ import (
"net/http"
"sync"
"github.com/tidwall/sjson"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/formats/varint"
"github.com/gorilla/websocket"
"github.com/tevino/abool"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/safing/portbase/container"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/iterator"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/formats/varint"
"github.com/safing/portbase/log"
)
@ -44,7 +45,7 @@ var (
func init() {
RegisterHandler("/api/database/v1", WrapInAuthHandler(
startDatabaseWebsocketAPI,
startDatabaseAPI,
// Default to admin read/write permissions until the database gets support
// for api permissions.
dbCompatibilityPermission,
@ -52,8 +53,11 @@ func init() {
))
}
// DatabaseAPI is a generic database API interface.
// DatabaseAPI is a database API instance.
type DatabaseAPI struct {
conn *websocket.Conn
sendQueue chan []byte
queriesLock sync.Mutex
queries map[string]*iterator.Iterator
@ -63,35 +67,14 @@ type DatabaseAPI struct {
shutdownSignal chan struct{}
shuttingDown *abool.AtomicBool
db *database.Interface
sendBytes func(data []byte)
}
// DatabaseWebsocketAPI is a database websocket API interface.
type DatabaseWebsocketAPI struct {
DatabaseAPI
sendQueue chan []byte
conn *websocket.Conn
}
func allowAnyOrigin(r *http.Request) bool {
return true
}
// CreateDatabaseAPI creates a new database interface.
func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI {
return DatabaseAPI{
queries: make(map[string]*iterator.Iterator),
subs: make(map[string]*database.Subscription),
shutdownSignal: make(chan struct{}),
shuttingDown: abool.NewBool(false),
db: database.NewInterface(nil),
sendBytes: sendFunction,
}
}
func startDatabaseAPI(w http.ResponseWriter, r *http.Request) {
func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{
CheckOrigin: allowAnyOrigin,
ReadBufferSize: 1024,
@ -101,104 +84,28 @@ func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) {
if err != nil {
errMsg := fmt.Sprintf("could not upgrade: %s", err)
log.Error(errMsg)
http.Error(w, errMsg, http.StatusBadRequest)
http.Error(w, errMsg, 400)
return
}
newDBAPI := &DatabaseWebsocketAPI{
DatabaseAPI: DatabaseAPI{
queries: make(map[string]*iterator.Iterator),
subs: make(map[string]*database.Subscription),
shutdownSignal: make(chan struct{}),
shuttingDown: abool.NewBool(false),
db: database.NewInterface(nil),
},
sendQueue: make(chan []byte, 100),
conn: wsConn,
new := &DatabaseAPI{
conn: wsConn,
sendQueue: make(chan []byte, 100),
queries: make(map[string]*iterator.Iterator),
subs: make(map[string]*database.Subscription),
shutdownSignal: make(chan struct{}),
shuttingDown: abool.NewBool(false),
db: database.NewInterface(nil),
}
newDBAPI.sendBytes = func(data []byte) {
newDBAPI.sendQueue <- data
}
module.StartWorker("database api handler", newDBAPI.handler)
module.StartWorker("database api writer", newDBAPI.writer)
module.StartWorker("database api handler", new.handler)
module.StartWorker("database api writer", new.writer)
log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI)
}
func (api *DatabaseWebsocketAPI) handler(context.Context) error {
defer func() {
_ = api.shutdown(nil)
}()
func (api *DatabaseAPI) handler(context.Context) error {
for {
_, msg, err := api.conn.ReadMessage()
if err != nil {
return api.shutdown(err)
}
api.Handle(msg)
}
}
func (api *DatabaseWebsocketAPI) writer(ctx context.Context) error {
defer func() {
_ = api.shutdown(nil)
}()
var data []byte
var err error
for {
select {
// prioritize direct writes
case data = <-api.sendQueue:
if len(data) == 0 {
return nil
}
case <-ctx.Done():
return nil
case <-api.shutdownSignal:
return nil
}
// log.Tracef("api: sending %s", string(*msg))
err = api.conn.WriteMessage(websocket.BinaryMessage, data)
if err != nil {
return api.shutdown(err)
}
}
}
func (api *DatabaseWebsocketAPI) shutdown(err error) error {
// Check if we are the first to shut down.
if !api.shuttingDown.SetToIf(false, true) {
return nil
}
// Check the given error.
if err != nil {
if websocket.IsCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseAbnormalClosure,
) {
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
} else {
log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
}
}
// Trigger shutdown.
close(api.shutdownSignal)
_ = api.conn.Close()
return nil
}
// Handle handles a message for the database API.
func (api *DatabaseAPI) Handle(msg []byte) {
// 123|get|<key>
// 123|ok|<key>|<data>
// 123|error|<message>
@ -237,62 +144,120 @@ func (api *DatabaseAPI) Handle(msg []byte) {
// 131|success
// 131|error|<message>
parts := bytes.SplitN(msg, []byte("|"), 3)
for {
// Handle special command "cancel"
if len(parts) == 2 && string(parts[1]) == "cancel" {
// 124|cancel
// 125|cancel
// 127|cancel
go api.handleCancel(parts[0])
return
}
_, msg, err := api.conn.ReadMessage()
if err != nil {
return api.shutdown(err)
}
if len(parts) != 3 {
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
return
}
parts := bytes.SplitN(msg, []byte("|"), 3)
switch string(parts[1]) {
case "get":
// 123|get|<key>
go api.handleGet(parts[0], string(parts[2]))
case "query":
// 124|query|<query>
go api.handleQuery(parts[0], string(parts[2]))
case "sub":
// 125|sub|<query>
go api.handleSub(parts[0], string(parts[2]))
case "qsub":
// 127|qsub|<query>
go api.handleQsub(parts[0], string(parts[2]))
case "create", "update", "insert":
// split key and payload
dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
if len(dataParts) != 2 {
// Handle special command "cancel"
if len(parts) == 2 && string(parts[1]) == "cancel" {
// 124|cancel
// 125|cancel
// 127|cancel
go api.handleCancel(parts[0])
continue
}
if len(parts) != 3 {
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
return
continue
}
switch string(parts[1]) {
case "create":
// 128|create|<key>|<data>
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true)
case "update":
// 129|update|<key>|<data>
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false)
case "insert":
// 130|insert|<key>|<data>
go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1])
case "get":
// 123|get|<key>
go api.handleGet(parts[0], string(parts[2]))
case "query":
// 124|query|<query>
go api.handleQuery(parts[0], string(parts[2]))
case "sub":
// 125|sub|<query>
go api.handleSub(parts[0], string(parts[2]))
case "qsub":
// 127|qsub|<query>
go api.handleQsub(parts[0], string(parts[2]))
case "create", "update", "insert":
// split key and payload
dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
if len(dataParts) != 2 {
api.send(nil, dbMsgTypeError, "bad request: malformed message", nil)
continue
}
switch string(parts[1]) {
case "create":
// 128|create|<key>|<data>
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true)
case "update":
// 129|update|<key>|<data>
go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false)
case "insert":
// 130|insert|<key>|<data>
go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1])
}
case "delete":
// 131|delete|<key>
go api.handleDelete(parts[0], string(parts[2]))
default:
api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
}
case "delete":
// 131|delete|<key>
go api.handleDelete(parts[0], string(parts[2]))
default:
api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil)
}
}
func (api *DatabaseAPI) writer(ctx context.Context) error {
var data []byte
var err error
for {
select {
// prioritize direct writes
case data = <-api.sendQueue:
if len(data) == 0 {
return api.shutdown(nil)
}
case <-ctx.Done():
return api.shutdown(nil)
case <-api.shutdownSignal:
return api.shutdown(nil)
}
// log.Tracef("api: sending %s", string(*msg))
err = api.conn.WriteMessage(websocket.BinaryMessage, data)
if err != nil {
return api.shutdown(err)
}
}
}
func (api *DatabaseAPI) shutdown(err error) error {
// Check if we are the first to shut down.
if !api.shuttingDown.SetToIf(false, true) {
return nil
}
// Check the given error.
if err != nil {
if websocket.IsCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseAbnormalClosure,
) {
log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr())
} else {
log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err)
}
}
// Trigger shutdown.
close(api.shutdownSignal)
api.conn.Close()
return nil
}
func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) {
c := container.New(opID)
c.Append(dbAPISeperatorBytes)
@ -308,7 +273,7 @@ func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data
c.Append(data)
}
api.sendBytes(c.CompileData())
api.sendQueue <- c.CompileData()
}
func (api *DatabaseAPI) handleGet(opID []byte, key string) {
@ -320,7 +285,7 @@ func (api *DatabaseAPI) handleGet(opID []byte, key string) {
r, err := api.db.Get(key)
if err == nil {
data, err = MarshalRecord(r, true)
data, err = marshalRecord(r, true)
}
if err != nil {
api.send(opID, dbMsgTypeError, err.Error(), nil)
@ -373,15 +338,14 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
case <-api.shutdownSignal:
// cancel query and return
it.Cancel()
return false
return
case r := <-it.Next:
// process query feed
if r != nil {
// process record
data, err := MarshalRecord(r, true)
data, err := marshalRecord(r, true)
if err != nil {
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
continue
}
api.send(opID, dbMsgTypeOk, r.Key(), data)
} else {
@ -397,7 +361,7 @@ func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) {
}
}
// func (api *DatabaseWebsocketAPI) runQuery()
// func (api *DatabaseAPI) runQuery()
func (api *DatabaseAPI) handleSub(opID []byte, queryText string) {
// 125|sub|<query>
@ -455,7 +419,7 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
// process sub feed
if r != nil {
// process record
data, err := MarshalRecord(r, true)
data, err := marshalRecord(r, true)
if err != nil {
api.send(opID, dbMsgTypeWarning, err.Error(), nil)
continue
@ -463,12 +427,12 @@ func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) {
// TODO: use upd, new and delete msgTypes
r.Lock()
isDeleted := r.Meta().IsDeleted()
isNew := r.Meta().Created == r.Meta().Modified
new := r.Meta().Created == r.Meta().Modified
r.Unlock()
switch {
case isDeleted:
api.send(opID, dbMsgTypeDel, r.Key(), nil)
case isNew:
case new:
api.send(opID, dbMsgTypeNew, r.Key(), data)
default:
api.send(opID, dbMsgTypeUpd, r.Key(), data)
@ -569,9 +533,9 @@ func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create b
}
// TODO - staged for deletion: remove transition code
// if data[0] != dsd.JSON {
// if data[0] != record.JSON {
// typedData := make([]byte, len(data)+1)
// typedData[0] = dsd.JSON
// typedData[0] = record.JSON
// copy(typedData[1:], data)
// data = typedData
// }
@ -659,20 +623,20 @@ func (api *DatabaseAPI) handleDelete(opID []byte, key string) {
api.send(opID, dbMsgTypeSuccess, emptyString, nil)
}
// MarshalRecord locks and marshals the given record, additionally adding
// marshalRecord locks and marshals the given record, additionally adding
// metadata and returning it as json.
func MarshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
func marshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
r.Lock()
defer r.Unlock()
// Pour record into JSON.
jsonData, err := r.Marshal(r, dsd.JSON)
jsonData, err := r.Marshal(r, record.JSON)
if err != nil {
return nil, err
}
// Remove JSON identifier for manual editing.
jsonData = bytes.TrimPrefix(jsonData, varint.Pack8(dsd.JSON))
jsonData = bytes.TrimPrefix(jsonData, varint.Pack8(record.JSON))
// Add metadata.
jsonData, err = sjson.SetBytes(jsonData, "_meta", r.Meta())
@ -688,7 +652,7 @@ func MarshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) {
// Add JSON identifier again.
if withDSDIdentifier {
formatID := varint.Pack8(dsd.JSON)
formatID := varint.Pack8(record.JSON)
finalData := make([]byte, 0, len(formatID)+len(jsonData))
finalData = append(finalData, formatID...)
finalData = append(finalData, jsonData...)
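
To make the wire format documented in Handle above easier to follow, here is a small self-contained sketch that splits example messages the same way the handler does (bytes.SplitN on `|` with a limit of 3, then a second split for the key/data commands); the operation IDs, keys, and payloads are invented for illustration.

package main

import (
	"bytes"
	"fmt"
)

func main() {
	// Example client messages in the "<opID>|<command>|<payload>" format
	// documented above; IDs and keys are made up for illustration.
	msgs := [][]byte{
		[]byte("123|get|core:my/key"),
		[]byte("128|create|core:my/key|{\"Value\":1}"),
		[]byte("124|cancel"),
	}

	for _, msg := range msgs {
		// Same split as the handler: at most three parts, so the data
		// section of create/update/insert stays intact in parts[2].
		parts := bytes.SplitN(msg, []byte("|"), 3)

		if len(parts) == 2 && string(parts[1]) == "cancel" {
			fmt.Printf("op %s: cancel\n", parts[0])
			continue
		}
		if len(parts) != 3 {
			fmt.Printf("op %s: malformed message\n", parts[0])
			continue
		}

		switch cmd := string(parts[1]); cmd {
		case "create", "update", "insert":
			// Key and data are separated by another "|".
			dataParts := bytes.SplitN(parts[2], []byte("|"), 2)
			fmt.Printf("op %s: %s key=%s data=%s\n", parts[0], cmd, dataParts[0], dataParts[1])
		default:
			fmt.Printf("op %s: %s %s\n", parts[0], cmd, parts[2])
		}
	}
}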

View file

@ -1,10 +1,10 @@
package api
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"sort"
"strconv"
@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
"github.com/safing/portbase/modules"
)
@ -22,49 +21,12 @@ import (
// Endpoint describes an API Endpoint.
// Path and at least one permission are required.
// As is exactly one function.
type Endpoint struct { //nolint:maligned
// Name is the human readable name of the endpoint.
Name string
// Description is the human readable description and documentation of the endpoint.
Description string
// Parameters is the parameter documentation.
Parameters []Parameter `json:",omitempty"`
// Path describes the URL path of the endpoint.
Path string
// MimeType defines the content type of the returned data.
MimeType string
// Read defines the required read permission.
Read Permission `json:",omitempty"`
// ReadMethod sets the required read method for the endpoint.
// Available methods are:
// GET: Returns data only, no action is taken, nothing is changed.
// If omitted, defaults to GET.
//
// This field is currently being introduced and will only warn and not deny
// access if the read method does not match.
ReadMethod string `json:",omitempty"`
// Write defines the required write permission.
Write Permission `json:",omitempty"`
// WriteMethod sets the required write method for the endpoint.
// Available methods are:
// POST: Create a new resource; Change a status; Execute a function
// PUT: Update an existing resource
// DELETE: Remove an existing resource
// If omitted, defaults to POST.
//
// This field is currently being introduced and will only warn and not deny
// access if the write method does not match.
WriteMethod string `json:",omitempty"`
// BelongsTo defines which module this endpoint belongs to.
// The endpoint will not be accessible if the module is not online.
BelongsTo *modules.Module `json:"-"`
type Endpoint struct {
Path string
MimeType string
Read Permission
Write Permission
BelongsTo *modules.Module
// ActionFunc is for simple actions with a return message for the user.
ActionFunc ActionFunc `json:"-"`
@ -81,6 +43,12 @@ type Endpoint struct { //nolint:maligned
// HandlerFunc is the raw http handler.
HandlerFunc http.HandlerFunc `json:"-"`
// Documentation Metadata.
Name string
Description string
Parameters []Parameter
}
// Parameter describes a parameterized variation of an endpoint.
@ -91,41 +59,6 @@ type Parameter struct {
Description string
}
// HTTPStatusProvider is an interface for errors to provide a custom HTTP
// status code.
type HTTPStatusProvider interface {
HTTPStatus() int
}
// HTTPStatusError represents an error with an HTTP status code.
type HTTPStatusError struct {
err error
code int
}
// Error returns the error message.
func (e *HTTPStatusError) Error() string {
return e.err.Error()
}
// Unwrap returns the wrapped error.
func (e *HTTPStatusError) Unwrap() error {
return e.err
}
// HTTPStatus returns the HTTP status code of this error.
func (e *HTTPStatusError) HTTPStatus() int {
return e.code
}
// ErrorWithStatus adds the HTTP status code to the error.
func ErrorWithStatus(err error, code int) error {
return &HTTPStatusError{
err: err,
code: code,
}
}
type (
// ActionFunc is for simple actions with a return message for the user.
ActionFunc func(ar *Request) (msg string, err error)
@ -209,7 +142,7 @@ func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request)
// does not pass the sanity checks.
func RegisterEndpoint(e Endpoint) error {
if err := e.check(); err != nil {
return fmt.Errorf("%w: %w", ErrInvalidEndpoint, err)
return fmt.Errorf("%w: %s", ErrInvalidEndpoint, err)
}
endpointsLock.Lock()
@ -225,18 +158,6 @@ func RegisterEndpoint(e Endpoint) error {
return nil
}
// GetEndpointByPath returns the endpoint registered with the given path.
func GetEndpointByPath(path string) (*Endpoint, error) {
endpointsLock.Lock()
defer endpointsLock.Unlock()
endpoint, ok := endpoints[path]
if !ok {
return nil, fmt.Errorf("no registered endpoint on path: %q", path)
}
return endpoint, nil
}
func (e *Endpoint) check() error {
// Check path.
if strings.TrimSpace(e.Path) == "" {
@ -251,36 +172,6 @@ func (e *Endpoint) check() error {
return errors.New("invalid write permission")
}
// Check methods.
if e.Read != NotSupported {
switch e.ReadMethod {
case http.MethodGet:
// All good.
case "":
// Set to default.
e.ReadMethod = http.MethodGet
default:
return errors.New("invalid read method")
}
} else {
e.ReadMethod = ""
}
if e.Write != NotSupported {
switch e.WriteMethod {
case http.MethodPost,
http.MethodPut,
http.MethodDelete:
// All good.
case "":
// Set to default.
e.WriteMethod = http.MethodPost
default:
return errors.New("invalid write method")
}
} else {
e.WriteMethod = ""
}
// Check functions.
var defaultMimeType string
fnCnt := 0
@ -381,45 +272,14 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Wait for the owning module to be ready.
if !moduleIsReady(e.BelongsTo) {
http.Error(w, "The API endpoint is not ready yet or the its module is not enabled. Reload (F5) to try again.", http.StatusServiceUnavailable)
http.Error(w, "The API endpoint is not ready yet or the its module is not enabled. Please try again later.", http.StatusServiceUnavailable)
return
}
// Return OPTIONS request before starting to handle normal requests.
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
switch r.Method {
case http.MethodHead:
w.WriteHeader(http.StatusOK)
return
}
eMethod, readMethod, ok := getEffectiveMethod(r)
if !ok {
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
return
}
if readMethod {
if eMethod != e.ReadMethod {
log.Tracer(r.Context()).Warningf(
"api: method %q does not match required read method %q%s",
r.Method,
e.ReadMethod,
" - this will be an error and abort the request in the future",
)
}
} else {
if eMethod != e.WriteMethod {
log.Tracer(r.Context()).Warningf(
"api: method %q does not match required write method %q%s",
r.Method,
e.WriteMethod,
" - this will be an error and abort the request in the future",
)
}
}
switch eMethod {
case http.MethodGet, http.MethodDelete:
// Nothing to do for these.
case http.MethodPost, http.MethodPut:
// Read body data.
inputData, ok := readBody(w, r)
@ -427,18 +287,16 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
apiRequest.InputData = inputData
// restore request body for any http.HandlerFunc below
r.Body = io.NopCloser(bytes.NewReader(inputData))
case http.MethodGet:
// Nothing special to do here.
case http.MethodOptions:
w.WriteHeader(http.StatusNoContent)
return
default:
// Defensive.
http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed)
return
}
// Add response headers to request struct so that the endpoint can work with them.
apiRequest.ResponseHeader = w.Header()
// Execute action function and get response data
var responseData []byte
var err error
@ -447,9 +305,6 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case e.ActionFunc != nil:
var msg string
msg, err = e.ActionFunc(apiRequest)
if !strings.HasSuffix(msg, "\n") {
msg += "\n"
}
if err == nil {
responseData = []byte(msg)
}
@ -461,18 +316,14 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var v interface{}
v, err = e.StructFunc(apiRequest)
if err == nil && v != nil {
var mimeType string
responseData, mimeType, _, err = dsd.MimeDump(v, r.Header.Get("Accept"))
if err == nil {
w.Header().Set("Content-Type", mimeType)
}
responseData, err = json.Marshal(v)
}
case e.RecordFunc != nil:
var rec record.Record
rec, err = e.RecordFunc(apiRequest)
if err == nil && r != nil {
responseData, err = MarshalRecord(rec, false)
responseData, err = marshalRecord(rec, false)
}
case e.HandlerFunc != nil:
@ -486,27 +337,12 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check for handler error.
if err != nil {
var statusProvider HTTPStatusProvider
if errors.As(err, &statusProvider) {
http.Error(w, err.Error(), statusProvider.HTTPStatus())
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Return no content if there is none, or if request is HEAD.
if len(responseData) == 0 || r.Method == http.MethodHead {
w.WriteHeader(http.StatusNoContent)
return
}
// Set content type if not yet set.
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
}
// Write response.
w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(responseData)))
w.WriteHeader(http.StatusOK)
_, err = w.Write(responseData)
@ -523,7 +359,7 @@ func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool
}
// Read and close body.
inputData, err := io.ReadAll(r.Body)
inputData, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, "failed to read body"+err.Error(), http.StatusInternalServerError)
return nil, false
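
As a usage sketch of the Endpoint struct and RegisterEndpoint shown above (mirroring how the debug endpoints elsewhere in this diff register themselves), the following hypothetical program registers a simple read-only action endpoint; the path, name, and handler are examples, not part of the codebase.

package main

import (
	"log"

	"github.com/safing/portbase/api"
)

func main() {
	// Hypothetical endpoint; Read: PermitAnyone makes it accessible without
	// authentication, and ActionFunc returns a plain message to the caller.
	err := api.RegisterEndpoint(api.Endpoint{
		Name:        "Hello",
		Description: "Returns a static greeting.",
		Path:        "hello",
		Read:        api.PermitAnyone,
		ActionFunc: func(ar *api.Request) (msg string, err error) {
			return "Hello from the API.", nil
		},
	})
	if err != nil {
		log.Fatalf("failed to register endpoint: %s", err)
	}
}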

View file

@ -2,17 +2,11 @@ package api
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"os"
"runtime/pprof"
"strings"
"time"
"github.com/safing/portbase/info"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/utils/debug"
)
@ -27,16 +21,6 @@ func registerDebugEndpoints() error {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "ready",
Read: PermitAnyone,
ActionFunc: ready,
Name: "Ready",
Description: "Check if Portmaster has completed starting and is ready.",
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/stack",
Read: PermitAnyone,
@ -57,58 +41,6 @@ func registerDebugEndpoints() error {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/cpu",
MimeType: "application/octet-stream",
Read: PermitAnyone,
DataFunc: handleCPUProfile,
Name: "Get CPU Profile",
Description: strings.ReplaceAll(`Gather and return the CPU profile.
This data needs to gathered over a period of time, which is specified using the duration parameter.
You can easily view this data in your browser with this command (with Go installed):
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/cpu"
`, `"`, "`"),
Parameters: []Parameter{{
Method: http.MethodGet,
Field: "duration",
Value: "10s",
Description: "Specify the formatting style. The default is simple markdown formatting.",
}},
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/heap",
MimeType: "application/octet-stream",
Read: PermitAnyone,
DataFunc: handleHeapProfile,
Name: "Get Heap Profile",
Description: strings.ReplaceAll(`Gather and return the heap memory profile.
You can easily view this data in your browser with this command (with Go installed):
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/heap"
`, `"`, "`"),
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/allocs",
MimeType: "application/octet-stream",
Read: PermitAnyone,
DataFunc: handleAllocsProfile,
Name: "Get Allocs Profile",
Description: strings.ReplaceAll(`Gather and return the memory allocation profile.
You can easily view this data in your browser with this command (with Go installed):
"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/allocs"
`, `"`, "`"),
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "debug/info",
Read: PermitAnyone,
@ -130,22 +62,9 @@ You can easily view this data in your browser with this command (with Go install
// ping responds with pong.
func ping(ar *Request) (msg string, err error) {
// TODO: Remove upgrade to "ready" when all UI components have transitioned.
if modules.IsStarting() || modules.IsShuttingDown() {
return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly)
}
return "Pong.", nil
}
// ready checks if Portmaster has completed starting.
func ready(ar *Request) (msg string, err error) {
if modules.IsStarting() || modules.IsShuttingDown() {
return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly)
}
return "Portmaster is ready.", nil
}
// getStack returns the current goroutine stack.
func getStack(_ *Request) (data []byte, err error) {
buf := &bytes.Buffer{}
@ -171,73 +90,6 @@ func printStack(_ *Request) (msg string, err error) {
return "stack printed to stdout", nil
}
// handleCPUProfile returns the CPU profile.
func handleCPUProfile(ar *Request) (data []byte, err error) {
// Parse duration.
duration := 10 * time.Second
if durationOption := ar.Request.URL.Query().Get("duration"); durationOption != "" {
parsedDuration, err := time.ParseDuration(durationOption)
if err != nil {
return nil, fmt.Errorf("failed to parse duration: %w", err)
}
duration = parsedDuration
}
// Indicate download and filename.
ar.ResponseHeader.Set(
"Content-Disposition",
fmt.Sprintf(`attachment; filename="portmaster-cpu-profile_v%s.pprof"`, info.Version()),
)
// Start CPU profiling.
buf := new(bytes.Buffer)
if err := pprof.StartCPUProfile(buf); err != nil {
return nil, fmt.Errorf("failed to start cpu profile: %w", err)
}
// Wait for the specified duration.
select {
case <-time.After(duration):
case <-ar.Context().Done():
pprof.StopCPUProfile()
return nil, context.Canceled
}
// Stop CPU profiling and return data.
pprof.StopCPUProfile()
return buf.Bytes(), nil
}
// handleHeapProfile returns the Heap profile.
func handleHeapProfile(ar *Request) (data []byte, err error) {
// Indicate download and filename.
ar.ResponseHeader.Set(
"Content-Disposition",
fmt.Sprintf(`attachment; filename="portmaster-memory-heap-profile_v%s.pprof"`, info.Version()),
)
buf := new(bytes.Buffer)
if err := pprof.Lookup("heap").WriteTo(buf, 0); err != nil {
return nil, fmt.Errorf("failed to write heap profile: %w", err)
}
return buf.Bytes(), nil
}
// handleAllocsProfile returns the Allocs profile.
func handleAllocsProfile(ar *Request) (data []byte, err error) {
// Indicate download and filename.
ar.ResponseHeader.Set(
"Content-Disposition",
fmt.Sprintf(`attachment; filename="portmaster-memory-allocs-profile_v%s.pprof"`, info.Version()),
)
buf := new(bytes.Buffer)
if err := pprof.Lookup("allocs").WriteTo(buf, 0); err != nil {
return nil, fmt.Errorf("failed to write allocs profile: %w", err)
}
return buf.Bytes(), nil
}
// debugInfo returns the debugging information for support requests.
func debugInfo(ar *Request) (data []byte, err error) {
// Create debug information helper.
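
As a rough usage example for the profiling endpoints added above, the snippet below downloads the CPU profile over HTTP with a custom duration; the listen address is an assumption taken from the `go tool pprof` hint in the endpoint description, and the output filename is arbitrary.

package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
)

func main() {
	// Assumed local API address, taken from the pprof hint above; adjust if
	// your instance listens elsewhere. duration is the query parameter
	// parsed by handleCPUProfile.
	const profileURL = "http://127.0.0.1:817/api/v1/debug/cpu?duration=30s"

	resp, err := http.Get(profileURL)
	if err != nil {
		fmt.Fprintf(os.Stderr, "request failed: %s\n", err)
		os.Exit(1)
	}
	defer resp.Body.Close()

	// The endpoint streams a pprof profile (application/octet-stream);
	// save it to a file for use with `go tool pprof`.
	out, err := os.Create("portmaster-cpu-profile.pprof")
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to create file: %s\n", err)
		os.Exit(1)
	}
	defer out.Close()

	if _, err := io.Copy(out, resp.Body); err != nil {
		fmt.Fprintf(os.Stderr, "failed to save profile: %s\n", err)
		os.Exit(1)
	}
}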

View file

@ -51,6 +51,7 @@ func registerMetaEndpoints() error {
if err := RegisterEndpoint(Endpoint{
Path: "auth/reset",
Read: PermitAnyone,
Write: PermitAnyone,
HandlerFunc: authReset,
Name: "Reset Authenticated Session",
Description: "Resets authentication status internally and in the browser.",
@ -93,10 +94,7 @@ func authBearer(w http.ResponseWriter, r *http.Request) {
}
// Respond with desired authentication header.
w.Header().Set(
"WWW-Authenticate",
`Bearer realm="Portmaster API" domain="/"`,
)
w.Header().Set("WWW-Authenticate", "Bearer realm=Portmaster API")
http.Error(w, "Authorization required.", http.StatusUnauthorized)
}
@ -109,10 +107,7 @@ func authBasic(w http.ResponseWriter, r *http.Request) {
}
// Respond with desired authentication header.
w.Header().Set(
"WWW-Authenticate",
`Basic realm="Portmaster API" domain="/"`,
)
w.Header().Set("WWW-Authenticate", "Basic realm=Portmaster API")
http.Error(w, "Authorization required.", http.StatusUnauthorized)
}
@ -133,7 +128,7 @@ func authReset(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Clear-Site-Data", "*")
// Set HTTP Auth Realm without requesting authorization.
w.Header().Set("WWW-Authenticate", `None realm="Portmaster API"`)
w.Header().Set("WWW-Authenticate", "None realm=Portmaster API")
// Reply with 401 Unauthorized in order to clear HTTP Basic Auth data.
http.Error(w, "Session deleted.", http.StatusUnauthorized)

View file

@ -3,27 +3,15 @@ package api
import (
"errors"
"fmt"
"github.com/safing/portbase/modules"
)
func registerModulesEndpoints() error {
if err := RegisterEndpoint(Endpoint{
Path: "modules/status",
Read: PermitUser,
StructFunc: getStatusfunc,
Name: "Get Module Status",
Description: "Returns status information of all modules.",
}); err != nil {
return err
}
if err := RegisterEndpoint(Endpoint{
Path: "modules/{moduleName:.+}/trigger/{eventName:.+}",
Write: PermitSelf,
ActionFunc: triggerEvent,
Name: "Trigger Event",
Description: "Triggers an event of an internal module.",
Name: "Export Configuration Options",
Description: "Returns a list of all registered configuration options and their metadata. This does not include the current active or default settings.",
}); err != nil {
return err
}
@ -31,14 +19,6 @@ func registerModulesEndpoints() error {
return nil
}
func getStatusfunc(ar *Request) (i interface{}, err error) {
status := modules.GetStatus()
if status == nil {
return nil, errors.New("modules not yet initialized")
}
return status, nil
}
func triggerEvent(ar *Request) (msg string, err error) {
// Get parameters.
moduleName := ar.URLVars["moduleName"]

View file

@ -5,9 +5,8 @@ import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/safing/portbase/database/record"
"github.com/stretchr/testify/assert"
)
const (
@ -22,8 +21,6 @@ type actionTestRecord struct {
}
func TestEndpoints(t *testing.T) {
t.Parallel()
testHandler := &mainHandler{
mux: mainMux,
}
@ -116,8 +113,6 @@ func TestEndpoints(t *testing.T) {
}
func TestActionRegistration(t *testing.T) {
t.Parallel()
assert.Error(t, RegisterEndpoint(Endpoint{}))
assert.Error(t, RegisterEndpoint(Endpoint{

View file

@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/json"
"errors"
"flag"
@ -57,7 +58,7 @@ func prep() error {
}
func start() error {
startServer()
go Serve()
_ = updateAPIKeys(module.Ctx, nil)
err := module.RegisterEventHook("config", "config change", "update API keys", updateAPIKeys)
@ -74,7 +75,10 @@ func start() error {
}
func stop() error {
return stopServer()
if server != nil {
return server.Shutdown(context.Background())
}
return nil
}
func exportEndpointsCmd() error {

View file

@ -2,13 +2,15 @@ package api
import (
"fmt"
"io/ioutil"
"os"
"testing"
// API depends on the database for the database api.
_ "github.com/safing/portbase/database/dbmodule"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/modules"
// API depends on the database for the database api.
_ "github.com/safing/portbase/database/dbmodule"
)
func init() {
@ -20,13 +22,13 @@ func TestMain(m *testing.M) {
module.Enable()
// tmp dir for data root (db & config)
tmpDir, err := os.MkdirTemp("", "portbase-testing-")
tmpDir, err := ioutil.TempDir("", "portbase-testing-")
if err != nil {
fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err)
os.Exit(1)
}
// initialize data dir
err = dataroot.Initialize(tmpDir, 0o0755)
err = dataroot.Initialize(tmpDir, 0755)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err)
os.Exit(1)
@ -51,6 +53,6 @@ func TestMain(m *testing.M) {
fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err)
}
// clean up and exit
_ = os.RemoveAll(tmpDir)
os.RemoveAll(tmpDir)
os.Exit(exitCode)
}

View file

@ -6,7 +6,6 @@ import (
"github.com/safing/portbase/modules"
)
// ModuleHandler specifies the interface for API endpoints that are bound to a module.
type ModuleHandler interface {
BelongsTo() *modules.Module
}

View file

@ -5,7 +5,6 @@ import (
"net/http"
"github.com/gorilla/mux"
"github.com/safing/portbase/log"
)
@ -26,9 +25,6 @@ type Request struct {
// AuthToken is the request-side authentication token assigned.
AuthToken *AuthToken
// ResponseHeader holds the response header.
ResponseHeader http.Header
// HandlerCache can be used by handlers to cache data between handlers within a request.
HandlerCache interface{}
}
@ -36,12 +32,13 @@ type Request struct {
// apiRequestContextKey is a key used for the context key/value storage.
type apiRequestContextKey struct{}
// RequestContextKey is the key used to add the API request to the context.
var RequestContextKey = apiRequestContextKey{}
var (
requestContextKey = apiRequestContextKey{}
)
// GetAPIRequest returns the API Request of the given http request.
func GetAPIRequest(r *http.Request) *Request {
ar, ok := r.Context().Value(RequestContextKey).(*Request)
ar, ok := r.Context().Value(requestContextKey).(*Request)
if ok {
return ar
}
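
For context on how raw handlers interact with this request wrapper: a plain http.HandlerFunc registered through this package can recover the enriched Request via GetAPIRequest. The handler and route below are hypothetical sketches (the {moduleName:.+} pattern is borrowed from the modules endpoints earlier in this diff), not part of the actual code.

package main

import (
	"fmt"
	"net/http"

	"github.com/safing/portbase/api"
)

// helloHandler is a hypothetical raw handler that uses GetAPIRequest to
// access the URL variables and auth token attached by the main handler.
func helloHandler(w http.ResponseWriter, r *http.Request) {
	ar := api.GetAPIRequest(r)
	if ar == nil {
		http.Error(w, "no API request context", http.StatusInternalServerError)
		return
	}

	// URLVars holds the mux variables of the matched route.
	name := ar.URLVars["moduleName"]
	fmt.Fprintf(w, "read permission: %d, module: %s\n", ar.AuthToken.Read, name)
}

func main() {
	// RegisterHandleFunc is shown in the server diff below; the path here
	// is illustrative.
	api.RegisterHandleFunc("/api/v1/hello/{moduleName:.+}", helloHandler)
}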

View file

@ -3,11 +3,8 @@ package api
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"runtime/debug"
"strings"
"sync"
"time"
@ -15,74 +12,40 @@ import (
"github.com/gorilla/mux"
"github.com/safing/portbase/log"
"github.com/safing/portbase/utils"
)
// EnableServer defines if the HTTP server should be started.
var EnableServer = true
var (
// mainMux is the main mux router.
// gorilla mux
mainMux = mux.NewRouter()
// server is the main server.
server = &http.Server{
ReadHeaderTimeout: 10 * time.Second,
}
// main server and lock
server = &http.Server{}
handlerLock sync.RWMutex
allowedDevCORSOrigins = []string{
"127.0.0.1",
"localhost",
}
)
// RegisterHandler registers a handler with the API endpoint.
// RegisterHandler registers a handler with the API endpoint.
func RegisterHandler(path string, handler http.Handler) *mux.Route {
handlerLock.Lock()
defer handlerLock.Unlock()
return mainMux.Handle(path, handler)
}
// RegisterHandleFunc registers a handle function with the API endpoint.
// RegisterHandleFunc registers a handle function with the API endpoint.
func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route {
handlerLock.Lock()
defer handlerLock.Unlock()
return mainMux.HandleFunc(path, handleFunc)
}
func startServer() {
// Check if server is enabled.
if !EnableServer {
return
}
// Configure server.
// Serve starts serving the API endpoint.
func Serve() {
// configure server
server.Addr = listenAddressConfig()
server.Handler = &mainHandler{
// TODO: mainMux should not be modified anymore.
mux: mainMux,
}
// Start server manager.
module.StartServiceWorker("http server manager", 0, serverManager)
}
func stopServer() error {
// Check if server is enabled.
if !EnableServer {
return nil
}
if server.Addr != "" {
return server.Shutdown(context.Background())
}
return nil
}
// Serve starts serving the API endpoint.
func serverManager(_ context.Context) error {
// start serving
log.Infof("api: starting to listen on %s", server.Addr)
backoffDuration := 10 * time.Second
@ -93,7 +56,7 @@ func serverManager(_ context.Context) error {
})
// return on shutdown error
if errors.Is(err, http.ErrServerClosed) {
return nil
return
}
// log error and restart
log.Errorf("api: http endpoint failed: %s - restarting in %s", err, backoffDuration)
@ -118,7 +81,7 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
apiRequest := &Request{
Request: r,
}
ctx = context.WithValue(ctx, RequestContextKey, apiRequest)
ctx = context.WithValue(ctx, requestContextKey, apiRequest)
// Add context back to request.
r = r.WithContext(ctx)
lrw := NewLoggingResponseWriter(w, r)
@ -133,80 +96,6 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
tracer.Submit()
}()
// Add security headers.
w.Header().Set("Referrer-Policy", "same-origin")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "deny")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("X-DNS-Prefetch-Control", "off")
// Add CSP Header in production mode.
if !devMode() {
w.Header().Set(
"Content-Security-Policy",
"default-src 'self'; "+
"connect-src https://*.safing.io 'self'; "+
"style-src 'self' 'unsafe-inline'; "+
"img-src 'self' data: blob:",
)
}
// Check Cross-Origin Requests.
origin := r.Header.Get("Origin")
isPreflighCheck := false
if origin != "" {
// Parse origin URL.
originURL, err := url.Parse(origin)
if err != nil {
tracer.Warningf("api: denied request from %s: failed to parse origin header: %s", r.RemoteAddr, err)
http.Error(lrw, "Invalid Origin.", http.StatusForbidden)
return nil
}
// Check if the Origin matches the Host.
switch {
case originURL.Host == r.Host:
// Origin (with port) matches Host.
case originURL.Hostname() == r.Host:
// Origin (without port) matches Host.
case originURL.Scheme == "chrome-extension":
// Allow access for the browser extension
// TODO(ppacher):
// This currently allows access from any browser extension.
// Can we reduce that to only our browser extension?
// Also, what do we need to support Firefox?
case devMode() &&
utils.StringInSlice(allowedDevCORSOrigins, originURL.Hostname()):
// We are in dev mode and the request is coming from the allowed
// development origins.
default:
// Origin and Host do NOT match!
tracer.Warningf("api: denied request from %s: Origin (`%s`) and Host (`%s`) do not match", r.RemoteAddr, origin, r.Host)
http.Error(lrw, "Cross-Origin Request Denied.", http.StatusForbidden)
return nil
// If the Host header has a port, and the Origin does not, requests will
// also end up here, as we cannot properly check for equality.
}
// Add Cross-Site Headers now as we need them in any case now.
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Expose-Headers", "*")
w.Header().Set("Access-Control-Max-Age", "60")
w.Header().Add("Vary", "Origin")
// If there's an Access-Control-Request-Method header, this is a preflight check.
// In that case, we will just check if the preflight method is allowed and then return
// success here.
if preflighMethod := r.Header.Get("Access-Control-Request-Method"); r.Method == http.MethodOptions && preflighMethod != "" {
isPreflighCheck = true
}
}
// Clean URL.
cleanedRequestPath := cleanRequestPath(r.URL.Path)
@ -228,41 +117,14 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
apiRequest.Route = match.Route
apiRequest.URLVars = match.Vars
}
switch {
case match.MatchErr == nil:
// All good.
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
return nil
default:
tracer.Debug("api: no handler registered for this path")
http.Error(lrw, "Not found.", http.StatusNotFound)
return nil
}
// Be sure that URLVars always is a map.
if apiRequest.URLVars == nil {
apiRequest.URLVars = make(map[string]string)
}
// Check method.
_, readMethod, ok := getEffectiveMethod(r)
if !ok {
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
return nil
}
// At this point we know the method is allowed and there's a handler for the request.
// If this is just a CORS preflight, we'll accept the request with StatusOK now.
// There's no point in trying to authenticate the request, because the browser will
// not send authentication along with a preflight check.
if isPreflighCheck && handler != nil {
lrw.WriteHeader(http.StatusOK)
return nil
}
// Check authentication.
apiRequest.AuthToken = authenticateRequest(lrw, r, handler, readMethod)
apiRequest.AuthToken = authenticateRequest(lrw, r, handler)
if apiRequest.AuthToken == nil {
// Authenticator already replied.
return nil
@ -271,42 +133,38 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error {
// Wait for the owning module to be ready.
if moduleHandler, ok := handler.(ModuleHandler); ok {
if !moduleIsReady(moduleHandler.BelongsTo()) {
http.Error(lrw, "The API endpoint is not ready yet. Reload (F5) to try again.", http.StatusServiceUnavailable)
http.Error(lrw, "The API endpoint is not ready yet. Please try again later.", http.StatusServiceUnavailable)
return nil
}
}
// Check if we have a handler.
if handler == nil {
http.Error(lrw, "Not found.", http.StatusNotFound)
return nil
// Add security headers.
if !devMode() {
w.Header().Set(
"Content-Security-Policy",
"default-src 'self'; "+
"connect-src https://*.safing.io 'self'; "+
"style-src 'self' 'unsafe-inline'; "+
"img-src 'self' data:",
)
w.Header().Set("Referrer-Policy", "no-referrer")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "deny")
w.Header().Set("X-XSS-Protection", "1; mode=block")
w.Header().Set("X-DNS-Prefetch-Control", "off")
} else {
w.Header().Set("Access-Control-Allow-Origin", "*")
}
// Format panics in handler.
defer func() {
if panicValue := recover(); panicValue != nil {
// Report failure via module system.
me := module.NewPanicError("api request", "custom", panicValue)
me.Report()
// Respond with a server error.
if devMode() {
http.Error(
lrw,
fmt.Sprintf(
"Internal Server Error: %s\n\n%s",
panicValue,
debug.Stack(),
),
http.StatusInternalServerError,
)
} else {
http.Error(lrw, "Internal Server Error.", http.StatusInternalServerError)
}
}
}()
// Handle with registered handler.
handler.ServeHTTP(lrw, r)
// Handle request.
switch {
case handler != nil:
handler.ServeHTTP(lrw, r)
case errors.Is(match.MatchErr, mux.ErrMethodMismatch):
http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed)
default: // handler == nil or other error
http.Error(lrw, "Not found.", http.StatusNotFound)
}
return nil
}
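To isolate the subtle part of the cross-origin handling above (Origin with port vs. without port, extensions, dev origins), here is a hedged standalone sketch of the same comparison. The helper name, package name, and parameters are invented for illustration and are not part of the project's code.

```go
package sketch // hypothetical package, for illustration only

import (
	"net/http"
	"net/url"
)

// originAllowed mirrors the Origin-vs-Host checks from the handler above.
func originAllowed(r *http.Request, allowedDevOrigins []string, devMode bool) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		// No Origin header: same-origin or non-browser request.
		return true
	}
	originURL, err := url.Parse(origin)
	if err != nil {
		return false
	}
	switch {
	case originURL.Host == r.Host:
		// Origin (with port) matches Host.
		return true
	case originURL.Hostname() == r.Host:
		// Origin (without port) matches Host.
		return true
	case originURL.Scheme == "chrome-extension":
		// Browser extension access.
		return true
	case devMode:
		// In dev mode, allow the configured development origins.
		for _, allowed := range allowedDevOrigins {
			if originURL.Hostname() == allowed {
				return true
			}
		}
		return false
	default:
		return false
	}
}
```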

View file

@ -1,167 +0,0 @@
package apprise
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sync"
"github.com/safing/portbase/utils"
)
// Notifier sends messages to an Apprise API.
type Notifier struct {
// URL defines the Apprise API endpoint.
URL string
// DefaultType defines the default message type.
DefaultType MsgType
// DefaultTag defines the default message tag.
DefaultTag string
// DefaultFormat defines the default message format.
DefaultFormat MsgFormat
// AllowUntagged defines if untagged messages are allowed,
// which are sent to all configured apprise endpoints.
AllowUntagged bool
client *http.Client
clientLock sync.Mutex
}
// Message represents the message to be sent to the Apprise API.
type Message struct {
// Title is an optional title to go along with the body.
Title string `json:"title,omitempty"`
// Body is the main message content. This is the only required field.
Body string `json:"body"`
// Type defines the message type you want to send as.
// The valid options are info, success, warning, and failure.
// If no type is specified then info is the default value used.
Type MsgType `json:"type,omitempty"`
// Tag is used to notify only those tagged accordingly.
// Use a comma (,) to OR your tags and a space ( ) to AND them.
Tag string `json:"tag,omitempty"`
// Format optionally identifies the text format of the data you're feeding Apprise.
// The valid options are text, markdown, html.
// The default value if nothing is specified is text.
Format MsgFormat `json:"format,omitempty"`
}
// MsgType defines the message type.
type MsgType string
// Message Types.
const (
TypeInfo MsgType = "info"
TypeSuccess MsgType = "success"
TypeWarning MsgType = "warning"
TypeFailure MsgType = "failure"
)
// MsgFormat defines the message format.
type MsgFormat string
// Message Formats.
const (
FormatText MsgFormat = "text"
FormatMarkdown MsgFormat = "markdown"
FormatHTML MsgFormat = "html"
)
type errorResponse struct {
Error string `json:"error"`
}
// Send sends a message to the Apprise API.
func (n *Notifier) Send(ctx context.Context, m *Message) error {
// Check if the message has a body.
if m.Body == "" {
return errors.New("the message must have a body")
}
// Apply notifier defaults.
n.applyDefaults(m)
// Check if the message is tagged.
if m.Tag == "" && !n.AllowUntagged {
return errors.New("the message must have a tag")
}
// Marshal the message to JSON.
payload, err := json.Marshal(m)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
// Create request.
request, err := http.NewRequestWithContext(ctx, http.MethodPost, n.URL, bytes.NewReader(payload))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
request.Header.Set("Content-Type", "application/json")
// Send message to API.
resp, err := n.getClient().Do(request)
if err != nil {
return fmt.Errorf("failed to send message: %w", err)
}
defer resp.Body.Close() //nolint:errcheck,gosec
switch resp.StatusCode {
case http.StatusOK, http.StatusCreated, http.StatusNoContent, http.StatusAccepted:
return nil
default:
// Try to tease body contents.
if body, err := io.ReadAll(resp.Body); err == nil && len(body) > 0 {
// Try to parse json response.
errorResponse := &errorResponse{}
if err := json.Unmarshal(body, errorResponse); err == nil && errorResponse.Error != "" {
return fmt.Errorf("failed to send message: apprise returned %q with an error message: %s", resp.Status, errorResponse.Error)
}
return fmt.Errorf("failed to send message: %s (body teaser: %s)", resp.Status, utils.SafeFirst16Bytes(body))
}
return fmt.Errorf("failed to send message: %s", resp.Status)
}
}
func (n *Notifier) applyDefaults(m *Message) {
if m.Type == "" {
m.Type = n.DefaultType
}
if m.Tag == "" {
m.Tag = n.DefaultTag
}
if m.Format == "" {
m.Format = n.DefaultFormat
}
}
// SetClient sets a custom http client for accessing the Apprise API.
func (n *Notifier) SetClient(client *http.Client) {
n.clientLock.Lock()
defer n.clientLock.Unlock()
n.client = client
}
func (n *Notifier) getClient() *http.Client {
n.clientLock.Lock()
defer n.clientLock.Unlock()
// Create client if needed.
if n.client == nil {
n.client = &http.Client{}
}
return n.client
}
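A minimal, hedged usage sketch of the Notifier shown above (develop side). The endpoint URL, tags, and message content are placeholders, not values from the source, and the import path assumes the package lives under portbase.

```go
package main

import (
	"context"
	"log"

	"github.com/safing/portbase/apprise" // assumed import path
)

func main() {
	n := &apprise.Notifier{
		URL:           "http://localhost:8000/notify/apprise", // placeholder endpoint
		DefaultType:   apprise.TypeInfo,
		DefaultFormat: apprise.FormatMarkdown,
		AllowUntagged: false,
	}

	err := n.Send(context.Background(), &apprise.Message{
		Title: "Backup finished",
		Body:  "Nightly backup completed **successfully**.",
		// A comma ORs tags, a space ANDs them: this targets "admin" or "ops".
		Tag: "admin,ops",
	})
	if err != nil {
		log.Println("apprise:", err)
	}
}
```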

View file

@ -80,7 +80,7 @@ func registerBasicOptions() error {
// Register to hook to update the log level.
if err := module.RegisterEventHook(
"config",
ChangeEvent,
configChangeEvent,
"update log level",
setLogLevel,
); err != nil {
@ -102,8 +102,8 @@ func registerBasicOptions() error {
})
}
func loadLogLevel() error {
return setDefaultConfigOption(CfgLogLevel, log.GetLogLevel().Name(), false)
func loadLogLevel() {
setDefaultConfigOption(CfgLogLevel, log.GetLogLevel().Name(), false)
}
func setLogLevel(ctx context.Context, data interface{}) error {

View file

@ -13,7 +13,9 @@ import (
"github.com/safing/portbase/log"
)
var dbController *database.Controller
var (
dbController *database.Controller
)
// StorageInterface provides a storage.Interface to the configuration manager.
type StorageInterface struct {
@ -65,8 +67,6 @@ func (s *StorageInterface) Put(r record.Record) (record.Record, error) {
value, ok = acc.GetInt("Value")
case OptTypeBool:
value, ok = acc.GetBool("Value")
case optTypeAny:
ok = false
}
if !ok {
return nil, errors.New("received invalid value in \"Value\"")

View file

@ -1,3 +1,5 @@
// Package config ... (linter fix)
//nolint:dupl
package config
import (
@ -51,17 +53,17 @@ func registerExpertiseLevelOption() {
},
PossibleValues: []PossibleValue{
{
Name: "Simple Interface",
Name: "Simple",
Value: ExpertiseLevelNameUser,
Description: "Hide complex settings and information.",
},
{
Name: "Advanced Interface",
Name: "Advanced",
Value: ExpertiseLevelNameExpert,
Description: "Show technical details.",
},
{
Name: "Developer Interface",
Name: "Developer",
Value: ExpertiseLevelNameDeveloper,
Description: "Developer mode. Please be careful!",
},

View file

@ -4,8 +4,10 @@ import "sync"
type safe struct{}
// Concurrent makes concurrency safe get methods available.
var Concurrent = &safe{}
var (
// Concurrent makes concurrency safe get methods available.
Concurrent = &safe{}
)
// GetAsString returns a function that returns the wanted string with high performance.
func (cs *safe) GetAsString(name string, fallback string) StringOption {

View file

@ -2,7 +2,6 @@ package config
import (
"encoding/json"
"fmt"
"testing"
"github.com/safing/portbase/log"
@ -14,11 +13,7 @@ func parseAndReplaceConfig(jsonData string) error {
return err
}
validationErrors, _ := ReplaceConfig(m)
if len(validationErrors) > 0 {
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
}
return nil
return replaceConfig(m)
}
func parseAndReplaceDefaultConfig(jsonData string) error {
@ -27,16 +22,10 @@ func parseAndReplaceDefaultConfig(jsonData string) error {
return err
}
validationErrors, _ := ReplaceDefaultConfig(m)
if len(validationErrors) > 0 {
return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0])
}
return nil
return replaceDefaultConfig(m)
}
func quickRegister(t *testing.T, key string, optType OptionType, defaultValue interface{}) {
t.Helper()
err := Register(&Option{
Name: key,
Key: key,
@ -51,7 +40,7 @@ func quickRegister(t *testing.T, key string, optType OptionType, defaultValue in
}
}
func TestGet(t *testing.T) { //nolint:paralleltest
func TestGet(t *testing.T) { //nolint:gocognit
// reset
options = make(map[string]*Option)
@ -192,7 +181,7 @@ func TestGet(t *testing.T) { //nolint:paralleltest
}
}
func TestReleaseLevel(t *testing.T) { //nolint:paralleltest
func TestReleaseLevel(t *testing.T) {
// reset
options = make(map[string]*Option)
registerReleaseLevelOption()

View file

@ -4,20 +4,17 @@ import (
"encoding/json"
"errors"
"flag"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"github.com/safing/portbase/dataroot"
"github.com/safing/portbase/modules"
"github.com/safing/portbase/utils"
"github.com/safing/portbase/utils/debug"
)
// ChangeEvent is the name of the config change event.
const ChangeEvent = "config change"
const (
configChangeEvent = "config change"
)
var (
module *modules.Module
@ -35,7 +32,7 @@ func SetDataRoot(root *utils.DirStructure) {
func init() {
module = modules.Register("config", prep, start, nil, "database")
module.RegisterEvent(ChangeEvent, true)
module.RegisterEvent(configChangeEvent, true)
flag.BoolVar(&exportConfig, "export-config-options", false, "export configuration registry and exit")
}
@ -57,32 +54,21 @@ func start() error {
configFilePath = filepath.Join(dataRoot.Path, "config.json")
// Load log level from log package after it started.
err := loadLogLevel()
if err != nil {
loadLogLevel()
err := registerAsDatabase()
if err != nil && !os.IsNotExist(err) {
return err
}
err = registerAsDatabase()
if err != nil && !errors.Is(err, fs.ErrNotExist) {
err = loadConfig()
if err != nil && !os.IsNotExist(err) {
return err
}
err = loadConfig(false)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to load config file: %w", err)
}
return nil
}
func exportConfigCmd() error {
// Reset the metrics instance name option, as the default
// is set to the current hostname.
// Config key copied from metrics.CfgOptionInstanceKey.
option, err := GetOption("core/metrics/instance")
if err == nil {
option.DefaultValue = ""
}
data, err := json.MarshalIndent(ExportOptions(), "", " ")
if err != nil {
return err
@ -91,51 +77,3 @@ func exportConfigCmd() error {
_, err = os.Stdout.Write(data)
return err
}
// AddToDebugInfo adds all changed global config options to the given debug.Info.
func AddToDebugInfo(di *debug.Info) {
var lines []string
// Collect all changed settings.
_ = ForEachOption(func(opt *Option) error {
opt.Lock()
defer opt.Unlock()
if opt.ReleaseLevel <= getReleaseLevel() && opt.activeValue != nil {
if opt.Sensitive {
lines = append(lines, fmt.Sprintf("%s: [redacted]", opt.Key))
} else {
lines = append(lines, fmt.Sprintf("%s: %v", opt.Key, opt.activeValue.getData(opt)))
}
}
return nil
})
sort.Strings(lines)
// Add data as section.
di.AddSection(
fmt.Sprintf("Config: %d", len(lines)),
debug.UseCodeSection|debug.AddContentLineBreaks,
lines...,
)
}
// GetActiveConfigValues returns a map with the active config values.
func GetActiveConfigValues() map[string]interface{} {
values := make(map[string]interface{})
// Collect active values from options.
_ = ForEachOption(func(opt *Option) error {
opt.Lock()
defer opt.Unlock()
if opt.ReleaseLevel <= getReleaseLevel() && opt.activeValue != nil {
values[opt.Key] = opt.activeValue.getData(opt)
}
return nil
})
return values
}

View file

@ -3,15 +3,12 @@ package config
import (
"encoding/json"
"fmt"
"reflect"
"regexp"
"sync"
"github.com/mitchellh/copystructure"
"github.com/tidwall/sjson"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
)
// OptionType defines the value type of an option.
@ -66,9 +63,6 @@ type PossibleValue struct {
// Format: <vendor/package>:<scope>:<identifier> //.
type Annotations map[string]interface{}
// MigrationFunc is a function that migrates a config option value.
type MigrationFunc func(option *Option, value any) any
// Well known annotations defined by this package.
const (
// DisplayHintAnnotation provides a hint for the user
@ -99,10 +93,6 @@ const (
// may be extended to hold references to other options in the
// future.
StackableAnnotation = "safing/portbase:options:stackable"
// RestartPendingAnnotation is automatically set on a configuration option
// that requires a restart and has been changed.
// The value must always be a boolean with value "true".
RestartPendingAnnotation = "safing/portbase:options:restart-pending"
// QuickSettingAnnotation can be used to add quick settings to
// a configuration option. A quick setting can support the user
// by switching between pre-configured values.
@ -112,19 +102,6 @@ const (
// requirement. The type of RequiresAnnotation is []ValueRequirement
// or ValueRequirement.
RequiresAnnotation = "safing/portbase:config:requires"
// RequiresFeatureIDAnnotation can be used to mark a setting as only available
// when the user has a certain feature ID in the subscription plan.
// The type is []string or string.
RequiresFeatureIDAnnotation = "safing/portmaster:ui:config:requires-feature"
// SettablePerAppAnnotation can be used to mark a setting as settable per-app and
// is a boolean.
SettablePerAppAnnotation = "safing/portmaster:settable-per-app"
// RequiresUIReloadAnnotation can be used to inform the UI that changing the value
// of the annotated setting requires a full reload of the user interface.
// The value of this annotation does not matter as the sole presence of
// the annotation key is enough. Though, users are advised to set the value
// of this annotation to true.
RequiresUIReloadAnnotation = "safing/portmaster:ui:requires-reload"
)
// QuickSettingsAction defines the action of a quick setting.
@ -175,14 +152,11 @@ const (
// only sense together with the PossibleValues property
// of Option.
DisplayHintOneOf = "one-of"
// DisplayHintOrdered is used to mark a list option as ordered.
// DisplayHintOrdered Used to mark a list option as ordered.
// That is, the order of items is important and a user interface
// is encouraged to provide the user with re-ordering support
// (like drag'n'drop).
DisplayHintOrdered = "ordered"
// DisplayHintFilePicker is used to mark the option as being a file, which
// should give the option to use a file picker to select a local file from disk.
DisplayHintFilePicker = "file-picker"
)
// Option describes a configuration option.
@ -210,9 +184,6 @@ type Option struct {
// Help is considered immutable after the option has
// been created.
Help string
// Sensitive signifies that the configuration values may contain sensitive
// content, such as authentication keys.
Sensitive bool
// OptType defines the type of the option.
// OptType is considered immutable after the option has
// been created.
@ -247,10 +218,6 @@ type Option struct {
// ValidationRegex is considered immutable after the option has
// been created.
ValidationRegex string
// ValidationFunc may contain a function to validate more complex values.
// The error is returned beyond the scope of this package and may be
// displayed to a user.
ValidationFunc func(value interface{}) error `json:"-"`
// PossibleValues may be set to a slice of values that are allowed
// for this configuration setting. Note that PossibleValues makes most
// sense when ExternalOptType is set to HintOneOf
@ -262,9 +229,6 @@ type Option struct {
// Annotations is considered mutable and setting/reading annotation keys
// must be performed while the option is locked.
Annotations Annotations
// Migrations holds migration functions that are given the raw option value
// before any validation is run. The returned value is then used.
Migrations []MigrationFunc `json:"-"`
activeValue *valueCache // runtime value (loaded from config file or set by user)
activeDefaultValue *valueCache // runtime default value (may be set internally)
@ -293,12 +257,6 @@ func (option *Option) SetAnnotation(key string, value interface{}) {
option.Lock()
defer option.Unlock()
option.setAnnotation(key, value)
}
// setAnnotation sets the value of the annotation key, overwriting an
// existing value if required. Does not lock the Option.
func (option *Option) setAnnotation(key string, value interface{}) {
if option.Annotations == nil {
option.Annotations = make(Annotations)
}
@ -317,63 +275,6 @@ func (option *Option) GetAnnotation(key string) (interface{}, bool) {
return val, ok
}
// AnnotationEquals returns whether the annotation of the given key matches the
// given value.
func (option *Option) AnnotationEquals(key string, value any) bool {
option.Lock()
defer option.Unlock()
if option.Annotations == nil {
return false
}
setValue, ok := option.Annotations[key]
if !ok {
return false
}
return reflect.DeepEqual(value, setValue)
}
// copyOrNil returns a copy of the option, or nil if copying failed.
func (option *Option) copyOrNil() *Option {
copied, err := copystructure.Copy(option)
if err != nil {
return nil
}
return copied.(*Option) //nolint:forcetypeassert
}
// IsSetByUser returns whether the option has been set by the user.
func (option *Option) IsSetByUser() bool {
option.Lock()
defer option.Unlock()
return option.activeValue != nil
}
// UserValue returns the value set by the user or nil if the value has not
// been changed from the default.
func (option *Option) UserValue() any {
option.Lock()
defer option.Unlock()
if option.activeValue == nil {
return nil
}
return option.activeValue.getData(option)
}
// ValidateValue checks if the given value is valid for the option.
func (option *Option) ValidateValue(value any) error {
option.Lock()
defer option.Unlock()
value = migrateValue(option, value)
if _, err := validateValue(option, value); err != nil {
return err
}
return nil
}
// Export exports an option to a Record.
func (option *Option) Export() (record.Record, error) {
option.Lock()
@ -402,7 +303,7 @@ func (option *Option) export() (record.Record, error) {
}
}
r, err := record.NewWrapper(fmt.Sprintf("config:%s", option.Key), nil, dsd.JSON, data)
r, err := record.NewWrapper(fmt.Sprintf("config:%s", option.Key), nil, record.JSON, data)
if err != nil {
return nil, err
}
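As a hedged sketch of how the pieces above fit together, the following registers an option with possible values and a display-hint annotation, using only fields and constants shown in this diff. The key, names, and values are invented for illustration.

```go
package main

import (
	"log"

	"github.com/safing/portbase/config"
)

func main() {
	opt := &config.Option{
		Name:            "Theme",         // invented example option
		Key:             "example/theme", // invented key
		Description:     "Select the user interface theme.",
		OptType:         config.OptTypeString,
		DefaultValue:    "light",
		ValidationRegex: "^(light|dark)$",
		PossibleValues: []config.PossibleValue{
			{Name: "Light", Value: "light", Description: "Bright theme."},
			{Name: "Dark", Value: "dark", Description: "Dark theme."},
		},
	}
	// Hint the UI that only one of the possible values may be picked.
	opt.SetAnnotation(config.DisplayHintAnnotation, config.DisplayHintOneOf)

	if err := config.Register(opt); err != nil {
		log.Fatal(err)
	}
}
```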

View file

@ -2,39 +2,25 @@ package config
import (
"encoding/json"
"fmt"
"os"
"io/ioutil"
"path"
"strings"
"sync"
"github.com/safing/portbase/log"
)
var (
configFilePath string
loadedConfigValidationErrors []*ValidationError
loadedConfigValidationErrorsLock sync.Mutex
)
// GetLoadedConfigValidationErrors returns the encountered validation errors
// from the last time loading config from disk.
func GetLoadedConfigValidationErrors() []*ValidationError {
loadedConfigValidationErrorsLock.Lock()
defer loadedConfigValidationErrorsLock.Unlock()
return loadedConfigValidationErrors
}
func loadConfig(requireValidConfig bool) error {
func loadConfig() error {
// check if persistence is configured
if configFilePath == "" {
return nil
}
// read config file
data, err := os.ReadFile(configFilePath)
data, err := ioutil.ReadFile(configFilePath)
if err != nil {
return err
}
@ -45,23 +31,13 @@ func loadConfig(requireValidConfig bool) error {
return err
}
validationErrors, _ := ReplaceConfig(newValues)
if requireValidConfig && len(validationErrors) > 0 {
return fmt.Errorf("encountered %d validation errors during config loading", len(validationErrors))
}
// Save validation errors.
loadedConfigValidationErrorsLock.Lock()
defer loadedConfigValidationErrorsLock.Unlock()
loadedConfigValidationErrors = validationErrors
return nil
return replaceConfig(newValues)
}
// SaveConfig saves the current configuration to file.
// saveConfig saves the current configuration to file.
// It will acquire a read-lock on the global options registry
// lock and must lock each option!
func SaveConfig() error {
func saveConfig() error {
optionsLock.RLock()
defer optionsLock.RUnlock()
@ -93,7 +69,7 @@ func SaveConfig() error {
}
// write file
return os.WriteFile(configFilePath, data, 0o0600)
return ioutil.WriteFile(configFilePath, data, 0600)
}
// JSONToMap parses and flattens a hierarchical json object.

View file

@ -36,7 +36,6 @@ var (
)
func TestJSONMapConversion(t *testing.T) {
t.Parallel()
// convert to json
j, err := MapToJSON(mapData)
@ -68,8 +67,6 @@ func TestJSONMapConversion(t *testing.T) {
}
func TestConfigCleaning(t *testing.T) {
t.Parallel()
// load
configFlat, err := JSONToMap(jsonBytes)
if err != nil {

View file

@ -35,8 +35,6 @@ optionsLoop:
if !ok {
continue
}
// migrate value
configValue = migrateValue(option, configValue)
// validate value
valueCache, err := validateValue(option, configValue)
if err != nil {
@ -57,7 +55,7 @@ optionsLoop:
if firstErr != nil {
if errCnt > 0 {
return perspective, fmt.Errorf("encountered %d errors, first was: %w", errCnt, firstErr)
return perspective, fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
}
return perspective, firstErr
}

View file

@ -88,14 +88,13 @@ func Register(option *Option) error {
if option.ValidationRegex != "" {
option.compiledRegex, err = regexp.Compile(option.ValidationRegex)
if err != nil {
return fmt.Errorf("config: could not compile option.ValidationRegex: %w", err)
return fmt.Errorf("config: could not compile option.ValidationRegex: %s", err)
}
}
var vErr *ValidationError
option.activeFallbackValue, vErr = validateValue(option, option.DefaultValue)
if vErr != nil {
return fmt.Errorf("config: invalid default value: %w", vErr)
option.activeFallbackValue, err = validateValue(option, option.DefaultValue)
if err != nil {
return fmt.Errorf("config: invalid default value: %s", err)
}
optionsLock.Lock()

View file

@ -4,7 +4,7 @@ import (
"testing"
)
func TestRegistry(t *testing.T) { //nolint:paralleltest
func TestRegistry(t *testing.T) {
// reset
options = make(map[string]*Option)
@ -46,4 +46,5 @@ func TestRegistry(t *testing.T) { //nolint:paralleltest
}); err == nil {
t.Error("should fail")
}
}

View file

@ -1,3 +1,5 @@
// Package config ... (linter fix)
//nolint:dupl
package config
import (
@ -10,7 +12,7 @@ import (
// configuration setting.
type ReleaseLevel uint8
// Release Level constants.
// Release Level constants
const (
ReleaseLevelStable ReleaseLevel = 0
ReleaseLevelBeta ReleaseLevel = 1

View file

@ -2,6 +2,7 @@ package config
import (
"errors"
"fmt"
"sync"
"github.com/tevino/abool"
@ -34,126 +35,102 @@ func signalChanges() {
validityFlag = abool.NewBool(true)
validityFlagLock.Unlock()
module.TriggerEvent(ChangeEvent, nil)
module.TriggerEvent(configChangeEvent, nil)
}
// ValidateConfig validates the given configuration and returns all validation
// errors as well as whether the given configuration contains unknown keys.
func ValidateConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool, containsUnknown bool) {
// replaceConfig sets the (prioritized) user defined config.
func replaceConfig(newValues map[string]interface{}) error {
var firstErr error
var errCnt int
// RLock the options because we are not adding or removing
// options from the registration but rather only checking the
// options value which is guarded by the option's lock itself.
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself
optionsLock.RLock()
defer optionsLock.RUnlock()
var checked int
for key, option := range options {
newValue, ok := newValues[key]
option.Lock()
option.activeValue = nil
if ok {
checked++
func() {
option.Lock()
defer option.Unlock()
newValue = migrateValue(option, newValue)
_, err := validateValue(option, newValue)
if err != nil {
validationErrors = append(validationErrors, err)
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeValue = valueCache
} else {
errCnt++
if firstErr == nil {
firstErr = err
}
if option.RequiresRestart {
requiresRestart = true
}
}()
}
}
handleOptionUpdate(option, true)
option.Unlock()
}
return validationErrors, requiresRestart, checked < len(newValues)
signalChanges()
if firstErr != nil {
if errCnt > 0 {
return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
}
return firstErr
}
return nil
}
// ReplaceConfig sets the (prioritized) user defined config.
func ReplaceConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
// replaceDefaultConfig sets the (fallback) default config.
func replaceDefaultConfig(newValues map[string]interface{}) error {
var firstErr error
var errCnt int
// RLock the options because we are not adding or removing
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself.
// options value which is guarded by the option's lock itself
optionsLock.RLock()
defer optionsLock.RUnlock()
for key, option := range options {
newValue, ok := newValues[key]
func() {
option.Lock()
defer option.Unlock()
option.activeValue = nil
if ok {
newValue = migrateValue(option, newValue)
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeValue = valueCache
} else {
validationErrors = append(validationErrors, err)
option.Lock()
option.activeDefaultValue = nil
if ok {
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeDefaultValue = valueCache
} else {
errCnt++
if firstErr == nil {
firstErr = err
}
}
handleOptionUpdate(option, true)
if option.RequiresRestart {
requiresRestart = true
}
}()
}
handleOptionUpdate(option, true)
option.Unlock()
}
signalChanges()
return validationErrors, requiresRestart
}
// ReplaceDefaultConfig sets the (fallback) default config.
func ReplaceDefaultConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) {
// RLock the options because we are not adding or removing
// options from the registration but rather only update the
// options value which is guarded by the option's lock itself.
optionsLock.RLock()
defer optionsLock.RUnlock()
for key, option := range options {
newValue, ok := newValues[key]
func() {
option.Lock()
defer option.Unlock()
option.activeDefaultValue = nil
if ok {
newValue = migrateValue(option, newValue)
valueCache, err := validateValue(option, newValue)
if err == nil {
option.activeDefaultValue = valueCache
} else {
validationErrors = append(validationErrors, err)
}
}
handleOptionUpdate(option, true)
if option.RequiresRestart {
requiresRestart = true
}
}()
if firstErr != nil {
if errCnt > 0 {
return fmt.Errorf("encountered %d errors, first was: %s", errCnt, firstErr)
}
return firstErr
}
signalChanges()
return validationErrors, requiresRestart
return nil
}
// SetConfigOption sets a single value in the (prioritized) user defined config.
func SetConfigOption(key string, value any) error {
func SetConfigOption(key string, value interface{}) error {
return setConfigOption(key, value, true)
}
func setConfigOption(key string, value any, push bool) (err error) {
func setConfigOption(key string, value interface{}, push bool) (err error) {
option, err := GetOption(key)
if err != nil {
return err
@ -163,20 +140,13 @@ func setConfigOption(key string, value any, push bool) (err error) {
if value == nil {
option.activeValue = nil
} else {
value = migrateValue(option, value)
valueCache, vErr := validateValue(option, value)
if vErr == nil {
var valueCache *valueCache
valueCache, err = validateValue(option, value)
if err == nil {
option.activeValue = valueCache
} else {
err = vErr
}
}
// Add the "restart pending" annotation if the settings requires a restart.
if option.RequiresRestart {
option.setAnnotation(RestartPendingAnnotation, true)
}
handleOptionUpdate(option, push)
option.Unlock()
@ -187,7 +157,7 @@ func setConfigOption(key string, value any, push bool) (err error) {
// finalize change, activate triggers
signalChanges()
return SaveConfig()
return saveConfig()
}
// SetDefaultConfigOption sets a single value in the (fallback) default config.
@ -205,20 +175,13 @@ func setDefaultConfigOption(key string, value interface{}, push bool) (err error
if value == nil {
option.activeDefaultValue = nil
} else {
value = migrateValue(option, value)
valueCache, vErr := validateValue(option, value)
if vErr == nil {
var valueCache *valueCache
valueCache, err = validateValue(option, value)
if err == nil {
option.activeDefaultValue = valueCache
} else {
err = vErr
}
}
// Add the "restart pending" annotation if the settings requires a restart.
if option.RequiresRestart {
option.setAnnotation(RestartPendingAnnotation, true)
}
handleOptionUpdate(option, push)
option.Unlock()
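A short, hedged sketch of how the setters above pair with the concurrency-safe getters. It assumes the config module is already running inside the module framework, the invented option key "example/theme" was registered beforehand, and the config and log imports are in place.

```go
// Hypothetical helper; assumes the config module is running and the
// option "example/theme" was registered beforehand.
func applyThemePreference() {
	if err := config.SetConfigOption("example/theme", "dark"); err != nil {
		log.Printf("failed to set theme: %s", err)
		return
	}

	// Concurrent getters return a function that always reflects the
	// currently active value (user value, default, or fallback).
	getTheme := config.Concurrent.GetAsString("example/theme", "light")
	log.Printf("active theme: %s", getTheme())
}
```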

View file

@ -1,9 +1,9 @@
//nolint:goconst
//nolint:goconst,errcheck
package config
import "testing"
func TestLayersGetters(t *testing.T) { //nolint:paralleltest
func TestLayersGetters(t *testing.T) {
// reset
options = make(map[string]*Option)
@ -24,9 +24,9 @@ func TestLayersGetters(t *testing.T) { //nolint:paralleltest
t.Fatal(err)
}
validationErrors, _ := ReplaceConfig(mapData)
if len(validationErrors) > 0 {
t.Fatalf("%d errors, first: %s", len(validationErrors), validationErrors[0].Error())
err = replaceConfig(mapData)
if err != nil {
t.Fatal(err)
}
// Test missing values
@ -77,13 +77,14 @@ func TestLayersGetters(t *testing.T) { //nolint:paralleltest
if notBool() {
t.Error("expected fallback value: false")
}
}
func TestLayersSetters(t *testing.T) { //nolint:paralleltest
func TestLayersSetters(t *testing.T) {
// reset
options = make(map[string]*Option)
_ = Register(&Option{
Register(&Option{
Name: "name",
Key: "monkey",
Description: "description",
@ -93,7 +94,7 @@ func TestLayersSetters(t *testing.T) { //nolint:paralleltest
DefaultValue: "banana",
ValidationRegex: "^(banana|water)$",
})
_ = Register(&Option{
Register(&Option{
Name: "name",
Key: "zebras/zebra",
Description: "description",
@ -103,7 +104,7 @@ func TestLayersSetters(t *testing.T) { //nolint:paralleltest
DefaultValue: []string{"black", "white"},
ValidationRegex: "^[a-z]+$",
})
_ = Register(&Option{
Register(&Option{
Name: "name",
Key: "elephant",
Description: "description",
@ -113,7 +114,7 @@ func TestLayersSetters(t *testing.T) { //nolint:paralleltest
DefaultValue: 2,
ValidationRegex: "",
})
_ = Register(&Option{
Register(&Option{
Name: "name",
Key: "hot",
Description: "description",
@ -190,4 +191,5 @@ func TestLayersSetters(t *testing.T) { //nolint:paralleltest
if err := SetDefaultConfigOption("invalid_delete", nil); err == nil {
t.Error("should fail")
}
}

View file

@ -5,8 +5,6 @@ import (
"fmt"
"math"
"reflect"
"github.com/safing/portbase/log"
)
type valueCache struct {
@ -26,8 +24,6 @@ func (vc *valueCache) getData(opt *Option) interface{} {
return vc.stringVal
case OptTypeStringArray:
return vc.stringArrayVal
case optTypeAny:
return nil
default:
return nil
}
@ -63,177 +59,110 @@ func isAllowedPossibleValue(opt *Option, value interface{}) error {
}
}
return errors.New("value is not allowed")
}
// migrateValue runs all value migrations.
func migrateValue(option *Option, value any) any {
for _, migration := range option.Migrations {
newValue := migration(option, value)
if newValue != value {
log.Debugf("config: migrated %s value from %v to %v", option.Key, value, newValue)
}
value = newValue
}
return value
return fmt.Errorf("value is not allowed")
}
// validateValue ensures that value matches the expected type of option.
// It does not create a copy of the value!
func validateValue(option *Option, value interface{}) (*valueCache, *ValidationError) { //nolint:gocyclo
func validateValue(option *Option, value interface{}) (*valueCache, error) { //nolint:gocyclo
if option.OptType != OptTypeStringArray {
if err := isAllowedPossibleValue(option, value); err != nil {
return nil, &ValidationError{
Option: option.copyOrNil(),
Err: err,
}
return nil, fmt.Errorf("validation of option %s failed for %v: %w", option.Key, value, err)
}
}
var validated *valueCache
reflect.TypeOf(value).ConvertibleTo(reflect.TypeOf(""))
switch v := value.(type) {
case string:
if option.OptType != OptTypeString {
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
if !option.compiledRegex.MatchString(v) {
return nil, invalid(option, "did not match validation regex")
return nil, fmt.Errorf("validation of option %s failed: string \"%s\" did not match validation regex for option", option.Key, v)
}
}
validated = &valueCache{stringVal: v}
return &valueCache{stringVal: v}, nil
case []interface{}:
vConverted := make([]string, len(v))
for pos, entry := range v {
s, ok := entry.(string)
if !ok {
return nil, invalid(option, "entry #%d is not a string", pos+1)
return nil, fmt.Errorf("validation of option %s failed: element %+v at index %d is not a string", option.Key, entry, pos)
}
vConverted[pos] = s
}
// Call validation function again with converted value.
var vErr *ValidationError
validated, vErr = validateValue(option, vConverted)
if vErr != nil {
return nil, vErr
}
// continue to next case
return validateValue(option, vConverted)
case []string:
if option.OptType != OptTypeStringArray {
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
for pos, entry := range v {
if !option.compiledRegex.MatchString(entry) {
return nil, invalid(option, "entry #%d did not match validation regex", pos+1)
return nil, fmt.Errorf("validation of option %s failed: string \"%s\" at index %d did not match validation regex", option.Key, entry, pos)
}
if err := isAllowedPossibleValue(option, entry); err != nil {
return nil, invalid(option, "entry #%d is not allowed", pos+1)
return nil, fmt.Errorf("validation of option %s failed: string %q at index %d is not allowed", option.Key, entry, pos)
}
}
}
validated = &valueCache{stringArrayVal: v}
return &valueCache{stringArrayVal: v}, nil
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64:
// uint64 is omitted, as it does not fit in an int64
if option.OptType != OptTypeInt {
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
if option.compiledRegex != nil {
// we need to use %v here so we handle float and int correctly.
if !option.compiledRegex.MatchString(fmt.Sprintf("%v", v)) {
return nil, invalid(option, "did not match validation regex")
return nil, fmt.Errorf("validation of option %s failed: number \"%d\" did not match validation regex", option.Key, v)
}
}
switch v := value.(type) {
case int:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case int8:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case int16:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case int32:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case int64:
validated = &valueCache{intVal: v}
return &valueCache{intVal: v}, nil
case uint:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case uint8:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case uint16:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case uint32:
validated = &valueCache{intVal: int64(v)}
return &valueCache{intVal: int64(v)}, nil
case float32:
// convert if float has no decimals
if math.Remainder(float64(v), 1) == 0 {
validated = &valueCache{intVal: int64(v)}
} else {
return nil, invalid(option, "failed to convert float32 to int64")
return &valueCache{intVal: int64(v)}, nil
}
return nil, fmt.Errorf("failed to convert float32 to int64 for option %s, got value %+v", option.Key, v)
case float64:
// convert if float has no decimals
if math.Remainder(v, 1) == 0 {
validated = &valueCache{intVal: int64(v)}
} else {
return nil, invalid(option, "failed to convert float64 to int64")
return &valueCache{intVal: int64(v)}, nil
}
return nil, fmt.Errorf("failed to convert float64 to int64 for option %s, got value %+v", option.Key, v)
default:
return nil, invalid(option, "internal error")
return nil, errors.New("internal error")
}
case bool:
if option.OptType != OptTypeBool {
return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v)
return nil, fmt.Errorf("expected type %s for option %s, got type %T", getTypeName(option.OptType), option.Key, v)
}
validated = &valueCache{boolVal: v}
return &valueCache{boolVal: v}, nil
default:
return nil, invalid(option, "invalid option value type: %T", value)
}
// Check if there is an additional function to validate the value.
if option.ValidationFunc != nil {
var err error
switch option.OptType {
case optTypeAny:
err = errors.New("internal error")
case OptTypeString:
err = option.ValidationFunc(validated.stringVal)
case OptTypeStringArray:
err = option.ValidationFunc(validated.stringArrayVal)
case OptTypeInt:
err = option.ValidationFunc(validated.intVal)
case OptTypeBool:
err = option.ValidationFunc(validated.boolVal)
}
if err != nil {
return nil, &ValidationError{
Option: option.copyOrNil(),
Err: err,
}
}
}
return validated, nil
}
// ValidationError error holds details about a config option value validation error.
type ValidationError struct {
Option *Option
Err error
}
// Error returns the formatted error.
func (ve *ValidationError) Error() string {
return fmt.Sprintf("validation of %s failed: %s", ve.Option.Key, ve.Err)
}
// Unwrap returns the wrapped error.
func (ve *ValidationError) Unwrap() error {
return ve.Err
}
func invalid(option *Option, format string, a ...interface{}) *ValidationError {
return &ValidationError{
Option: option.copyOrNil(),
Err: fmt.Errorf(format, a...),
return nil, fmt.Errorf("invalid option value type for option %s: %T", option.Key, value)
}
}
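Since migrations run on the raw value before validation (see migrateValue above), here is a hedged sketch of the develop-side Migrations and ValidationFunc fields in use. The option key and the renamed value are invented for illustration.

```go
package main

import (
	"errors"
	"log"

	"github.com/safing/portbase/config"
)

func main() {
	err := config.Register(&config.Option{
		Name:         "Log Level",        // invented example option
		Key:          "example/logLevel", // invented key
		Description:  "Set the log level.",
		OptType:      config.OptTypeString,
		DefaultValue: "info",
		// Migrations run on the raw value before any validation.
		Migrations: []config.MigrationFunc{
			func(option *config.Option, value any) any {
				if s, ok := value.(string); ok && s == "warning" {
					return "warn" // migrate a renamed level
				}
				return value
			},
		},
		// ValidationFunc runs after the type checks in validateValue.
		ValidationFunc: func(value interface{}) error {
			if s, ok := value.(string); !ok || s == "" {
				return errors.New("log level must be a non-empty string")
			}
			return nil
		},
	})
	if err != nil {
		log.Fatal(err)
	}
}
```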

View file

@ -10,11 +10,9 @@ type ValidityFlag struct {
}
// NewValidityFlag returns a flag that signifies if the configuration has been changed.
// It always starts out as invalid. Refresh to start with the current value.
func NewValidityFlag() *ValidityFlag {
vf := &ValidityFlag{
flag: abool.New(),
}
vf := &ValidityFlag{}
vf.Refresh()
return vf
}

View file

@ -7,7 +7,7 @@ import (
"github.com/safing/portbase/formats/varint"
)
// Container is a []byte slice on steroids, allowing for quick data appending, prepending and fetching.
// Container is a []byte slice on steroids, allowing for quick data appending, prepending and fetching as well as transparent error transportation. (Error transportation requires use of varints for data)
type Container struct {
compartments [][]byte
offset int
@ -127,7 +127,7 @@ func (c *Container) CompileData() []byte {
// Get returns the given amount of bytes. Data MAY be copied and IS consumed.
func (c *Container) Get(n int) ([]byte, error) {
buf := c.Peek(n)
buf := c.gather(n)
if len(buf) < n {
return nil, errors.New("container: not enough data to return")
}
@ -138,24 +138,24 @@ func (c *Container) Get(n int) ([]byte, error) {
// GetAll returns all data. Data MAY be copied and IS consumed.
func (c *Container) GetAll() []byte {
// TODO: Improve.
buf := c.Peek(c.Length())
buf := c.gather(c.Length())
c.skip(len(buf))
return buf
}
// GetAsContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS consumed.
func (c *Container) GetAsContainer(n int) (*Container, error) {
newC := c.PeekContainer(n)
if newC == nil {
new := c.gatherAsContainer(n)
if new == nil {
return nil, errors.New("container: not enough data to return")
}
c.skip(n)
return newC, nil
return new, nil
}
// GetMax returns as much as possible, but the given amount of bytes at maximum. Data MAY be copied and IS consumed.
func (c *Container) GetMax(n int) []byte {
buf := c.Peek(n)
buf := c.gather(n)
c.skip(len(buf))
return buf
}
@ -211,13 +211,17 @@ func (c *Container) renewCompartments() {
}
func (c *Container) carbonCopy() *Container {
newC := &Container{
new := &Container{
compartments: make([][]byte, len(c.compartments)),
offset: c.offset,
err: c.err,
}
copy(newC.compartments, c.compartments)
return newC
for i := 0; i < len(c.compartments); i++ {
new.compartments[i] = c.compartments[i]
}
// TODO: investigate why copy fails to correctly duplicate [][]byte
// copy(new.compartments, c.compartments)
return new
}
func (c *Container) checkOffset() {
@ -226,6 +230,42 @@ func (c *Container) checkOffset() {
}
}
// Error Handling
/*
DEPRECATING... like.... NOW.
// SetError sets an error.
func (c *Container) SetError(err error) {
c.err = err
c.Replace(append([]byte{0x00}, []byte(err.Error())...))
}
// CheckError checks if there is an error in the data. If so, it will parse the error and delete the data.
func (c *Container) CheckError() {
if len(c.compartments[c.offset]) > 0 && c.compartments[c.offset][0] == 0x00 {
c.compartments[c.offset] = c.compartments[c.offset][1:]
c.err = errors.New(string(c.CompileData()))
c.compartments = nil
}
}
// HasError returns whether or not the container is holding an error.
func (c *Container) HasError() bool {
return c.err != nil
}
// Error returns the error.
func (c *Container) Error() error {
return c.err
}
// ErrString returns the error as a string.
func (c *Container) ErrString() string {
return c.err.Error()
}
*/
// Block Handling
// PrependLength prepends the current full length of all bytes in the container.
@ -233,8 +273,7 @@ func (c *Container) PrependLength() {
c.Prepend(varint.Pack64(uint64(c.Length())))
}
// Peek returns the given amount of bytes. Data MAY be copied and IS NOT consumed.
func (c *Container) Peek(n int) []byte {
func (c *Container) gather(n int) []byte {
// Check requested length.
if n <= 0 {
return nil
@ -261,8 +300,7 @@ func (c *Container) Peek(n int) []byte {
return slice[:n]
}
// PeekContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS NOT consumed.
func (c *Container) PeekContainer(n int) (newC *Container) {
func (c *Container) gatherAsContainer(n int) (new *Container) {
// Check requested length.
if n < 0 {
return nil
@ -270,20 +308,20 @@ func (c *Container) PeekContainer(n int) (newC *Container) {
return &Container{}
}
newC = &Container{}
new = &Container{}
for i := c.offset; i < len(c.compartments); i++ {
if n >= len(c.compartments[i]) {
newC.compartments = append(newC.compartments, c.compartments[i])
new.compartments = append(new.compartments, c.compartments[i])
n -= len(c.compartments[i])
} else {
newC.compartments = append(newC.compartments, c.compartments[i][:n])
new.compartments = append(new.compartments, c.compartments[i][:n])
n = 0
}
}
if n > 0 {
return nil
}
return newC
return new
}
func (c *Container) skip(n int) {
@ -325,7 +363,7 @@ func (c *Container) GetNextBlockAsContainer() (*Container, error) {
// GetNextN8 parses and returns a varint of type uint8.
func (c *Container) GetNextN8() (uint8, error) {
buf := c.Peek(2)
buf := c.gather(2)
num, n, err := varint.Unpack8(buf)
if err != nil {
return 0, err
@ -336,7 +374,7 @@ func (c *Container) GetNextN8() (uint8, error) {
// GetNextN16 parses and returns a varint of type uint16.
func (c *Container) GetNextN16() (uint16, error) {
buf := c.Peek(3)
buf := c.gather(3)
num, n, err := varint.Unpack16(buf)
if err != nil {
return 0, err
@ -347,7 +385,7 @@ func (c *Container) GetNextN16() (uint16, error) {
// GetNextN32 parses and returns a varint of type uint32.
func (c *Container) GetNextN32() (uint32, error) {
buf := c.Peek(5)
buf := c.gather(5)
num, n, err := varint.Unpack32(buf)
if err != nil {
return 0, err
@ -358,7 +396,7 @@ func (c *Container) GetNextN32() (uint32, error) {
// GetNextN64 parses and returns a varint of type uint64.
func (c *Container) GetNextN64() (uint64, error) {
buf := c.Peek(10)
buf := c.gather(10)
num, n, err := varint.Unpack64(buf)
if err != nil {
return 0, err
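To make the helpers above concrete, here is a hedged sketch of building and reading back a length-prefixed message with the develop-side API (New, Append, PrependLength, GetNextN64, Get). The import path is assumed.

```go
package main

import (
	"fmt"

	"github.com/safing/portbase/container" // assumed import path
)

func main() {
	// Build a length-prefixed message; appended slices are not copied.
	c := container.New([]byte("hello"))
	c.Append([]byte(", world"))
	c.PrependLength() // prepend the total length as a varint

	// Read it back: the varint length first, then exactly that many bytes.
	length, err := c.GetNextN64()
	if err != nil {
		panic(err)
	}
	payload, err := c.Get(int(length))
	if err != nil {
		panic(err)
	}
	fmt.Printf("%d bytes: %s\n", length, payload)
}
```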

View file

@ -23,7 +23,6 @@ var (
)
func TestContainerDataHandling(t *testing.T) {
t.Parallel()
c1 := New(utils.DuplicateBytes(testData))
c1c := c1.carbonCopy()
@ -66,17 +65,15 @@ func TestContainerDataHandling(t *testing.T) {
}
c8.clean()
c9 := c8.PeekContainer(len(testData))
c9 := c8.gatherAsContainer(len(testData))
c10 := c9.PeekContainer(len(testData) - 1)
c10 := c9.gatherAsContainer(len(testData) - 1)
c10.Append(testData[len(testData)-1:])
compareMany(t, testData, c1.CompileData(), c2.CompileData(), c3.CompileData(), d4, d5, c6.CompileData(), c7.CompileData(), c8.CompileData(), c9.CompileData(), c10.CompileData())
}
func compareMany(t *testing.T, reference []byte, other ...[]byte) {
t.Helper()
for i, cmp := range other {
if !bytes.Equal(reference, cmp) {
t.Errorf("sample %d does not match reference: sample is '%s'", i+1, string(cmp))
@ -85,8 +82,6 @@ func compareMany(t *testing.T, reference []byte, other ...[]byte) {
}
func TestDataFetching(t *testing.T) {
t.Parallel()
c1 := New(utils.DuplicateBytes(testData))
data := c1.GetMax(1)
if string(data[0]) != "T" {
@ -105,8 +100,6 @@ func TestDataFetching(t *testing.T) {
}
func TestBlocks(t *testing.T) {
t.Parallel()
c1 := New(utils.DuplicateBytes(testData))
c1.PrependLength()
@ -144,10 +137,10 @@ func TestBlocks(t *testing.T) {
if n4 != 43 {
t.Errorf("n should be 43, was %d", n4)
}
}
func TestContainerBlockHandling(t *testing.T) {
t.Parallel()
c1 := New(utils.DuplicateBytes(testData))
c1.PrependLength()
@ -192,8 +185,6 @@ func TestContainerBlockHandling(t *testing.T) {
}
func TestContainerMisc(t *testing.T) {
t.Parallel()
c1 := New()
d1 := c1.CompileData()
if len(d1) > 0 {
@ -202,7 +193,5 @@ func TestContainerMisc(t *testing.T) {
}
func TestDeprecated(t *testing.T) {
t.Parallel()
NewContainer(utils.DuplicateBytes(testData))
}

View file

@ -5,22 +5,23 @@
// Byte slices added to the Container are not changed or appended to, so as not to corrupt any other data that may lie before or after the given slice.
// If interested, consider the following example to understand why this is important:
//
// package main
// package main
//
// import (
// "fmt"
// )
// import (
// "fmt"
// )
//
// func main() {
// a := []byte{0, 1,2,3,4,5,6,7,8,9}
// fmt.Printf("a: %+v\n", a)
// fmt.Printf("\nmaking changes...\n(we are not changing a directly)\n\n")
// b := a[2:6]
// c := append(b, 10, 11)
// fmt.Printf("b: %+v\n", b)
// fmt.Printf("c: %+v\n", c)
// fmt.Printf("a: %+v\n", a)
// }
// func main() {
// a := []byte{0, 1,2,3,4,5,6,7,8,9}
// fmt.Printf("a: %+v\n", a)
// fmt.Printf("\nmaking changes...\n(we are not changing a directly)\n\n")
// b := a[2:6]
// c := append(b, 10, 11)
// fmt.Printf("b: %+v\n", b)
// fmt.Printf("c: %+v\n", c)
// fmt.Printf("a: %+v\n", a)
// }
//
// run it here: https://play.golang.org/p/xu1BXT3QYeE
//
package container

View file

@ -27,11 +27,11 @@ func (ja *JSONBytesAccessor) Set(key string, value interface{}) error {
}
}
newJSON, err := sjson.SetBytes(*ja.json, key, value)
new, err := sjson.SetBytes(*ja.json, key, value)
if err != nil {
return err
}
*ja.json = newJSON
*ja.json = new
return nil
}
@ -60,15 +60,15 @@ func (ja *JSONBytesAccessor) GetStringArray(key string) (value []string, ok bool
return nil, false
}
slice := result.Array()
sliceCopy := make([]string, len(slice))
new := make([]string, len(slice))
for i, res := range slice {
if res.Type == gjson.String {
sliceCopy[i] = res.String()
new[i] = res.String()
} else {
return nil, false
}
}
return sliceCopy, true
return new, true
}
// GetInt returns the int found by the given json key and whether it could be successfully extracted.

View file

@ -29,11 +29,11 @@ func (ja *JSONAccessor) Set(key string, value interface{}) error {
}
}
newJSON, err := sjson.Set(*ja.json, key, value)
new, err := sjson.Set(*ja.json, key, value)
if err != nil {
return err
}
*ja.json = newJSON
*ja.json = new
return nil
}
@ -84,15 +84,15 @@ func (ja *JSONAccessor) GetStringArray(key string) (value []string, ok bool) {
return nil, false
}
slice := result.Array()
sliceCopy := make([]string, len(slice))
new := make([]string, len(slice))
for i, res := range slice {
if res.Type == gjson.String {
sliceCopy[i] = res.String()
new[i] = res.String()
} else {
return nil, false
}
}
return sliceCopy, true
return new, true
}
// GetInt returns the int found by the given json key and whether it could be successfully extracted.

View file

@ -37,12 +37,12 @@ func (sa *StructAccessor) Set(key string, value interface{}) error {
}
// handle special cases
switch field.Kind() { // nolint:exhaustive
switch field.Kind() {
// ints
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var newInt int64
switch newVal.Kind() { // nolint:exhaustive
switch newVal.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
newInt = newVal.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@ -58,7 +58,7 @@ func (sa *StructAccessor) Set(key string, value interface{}) error {
// uints
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var newUint uint64
switch newVal.Kind() { // nolint:exhaustive
switch newVal.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
newUint = uint64(newVal.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@ -73,7 +73,7 @@ func (sa *StructAccessor) Set(key string, value interface{}) error {
// floats
case reflect.Float32, reflect.Float64:
switch newVal.Kind() { // nolint:exhaustive
switch newVal.Kind() {
case reflect.Float32, reflect.Float64:
field.SetFloat(newVal.Float())
default:
@ -124,7 +124,7 @@ func (sa *StructAccessor) GetInt(key string) (value int64, ok bool) {
if !field.IsValid() {
return 0, false
}
switch field.Kind() { // nolint:exhaustive
switch field.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return field.Int(), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@ -140,7 +140,7 @@ func (sa *StructAccessor) GetFloat(key string) (value float64, ok bool) {
if !field.IsValid() {
return 0, false
}
switch field.Kind() { // nolint:exhaustive
switch field.Kind() {
case reflect.Float32, reflect.Float64:
return field.Float(), true
default:

View file

@ -13,6 +13,8 @@ type Accessor interface {
GetFloat(key string) (value float64, ok bool)
GetBool(key string) (value bool, ok bool)
Exists(key string) bool
Set(key string, value interface{}) error
Type() string
}
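For clarity, a hedged sketch of the Accessor contract in action, written as if it lived next to the accessors above so no import path has to be guessed; the JSON content is invented and "fmt" is assumed to be imported.

```go
// ExampleJSONAccessor shows the getter/setter round trip on a JSON string.
func ExampleJSONAccessor() {
	data := `{"name":"test","score":7}`
	acc := NewJSONAccessor(&data)

	if name, ok := acc.GetString("name"); ok {
		fmt.Println("name:", name)
	}
	if score, ok := acc.GetInt("score"); ok {
		fmt.Println("score:", score)
	}

	// Set writes the change back into the underlying JSON string.
	if err := acc.Set("score", 10); err != nil {
		fmt.Println("set failed:", err)
	}
	fmt.Println("updated:", data)
}
```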

View file

@ -44,13 +44,11 @@ var (
F64: 42.42,
B: true,
}
testJSONBytes, _ = json.Marshal(testStruct) //nolint:errchkjson
testJSONBytes, _ = json.Marshal(testStruct)
testJSON = string(testJSONBytes)
)
func testGetString(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue string) {
t.Helper()
v, ok := acc.GetString(key)
switch {
case !ok && shouldSucceed:
@ -64,8 +62,6 @@ func testGetString(t *testing.T, acc Accessor, key string, shouldSucceed bool, e
}
func testGetStringArray(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue []string) {
t.Helper()
v, ok := acc.GetStringArray(key)
switch {
case !ok && shouldSucceed:
@ -79,8 +75,6 @@ func testGetStringArray(t *testing.T, acc Accessor, key string, shouldSucceed bo
}
func testGetInt(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue int64) {
t.Helper()
v, ok := acc.GetInt(key)
switch {
case !ok && shouldSucceed:
@ -94,8 +88,6 @@ func testGetInt(t *testing.T, acc Accessor, key string, shouldSucceed bool, expe
}
func testGetFloat(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue float64) {
t.Helper()
v, ok := acc.GetFloat(key)
switch {
case !ok && shouldSucceed:
@ -109,8 +101,6 @@ func testGetFloat(t *testing.T, acc Accessor, key string, shouldSucceed bool, ex
}
func testGetBool(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue bool) {
t.Helper()
v, ok := acc.GetBool(key)
switch {
case !ok && shouldSucceed:
@ -124,8 +114,6 @@ func testGetBool(t *testing.T, acc Accessor, key string, shouldSucceed bool, exp
}
func testExists(t *testing.T, acc Accessor, key string, shouldSucceed bool) {
t.Helper()
ok := acc.Exists(key)
switch {
case !ok && shouldSucceed:
@ -136,8 +124,6 @@ func testExists(t *testing.T, acc Accessor, key string, shouldSucceed bool) {
}
func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueToSet interface{}) {
t.Helper()
err := acc.Set(key, valueToSet)
switch {
case err != nil && shouldSucceed:
@ -148,9 +134,8 @@ func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueTo
}
func TestAccessor(t *testing.T) {
t.Parallel()
// Test interface compliance.
// Test interface compliance
accs := []Accessor{
NewJSONAccessor(&testJSON),
NewJSONBytesAccessor(&testJSONBytes),
@ -288,4 +273,5 @@ func TestAccessor(t *testing.T) {
for _, acc := range accs {
testExists(t, acc, "X", false)
}
}

View file

@ -15,10 +15,12 @@ type Example struct {
Score int
}
var exampleDB = NewInterface(&Options{
Internal: true,
Local: true,
})
var (
exampleDB = NewInterface(&Options{
Internal: true,
Local: true,
})
)
// GetExample gets an Example from the database.
func GetExample(key string) (*Example, error) {
@ -30,20 +32,20 @@ func GetExample(key string) (*Example, error) {
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
newExample := &Example{}
err = record.Unwrap(r, newExample)
new := &Example{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return newExample, nil
return new, nil
}
// or adjust type
newExample, ok := r.(*Example)
new, ok := r.(*Example)
if !ok {
return nil, fmt.Errorf("record not of type *Example, but %T", r)
}
return newExample, nil
return new, nil
}
func (e *Example) Save() error {
@ -56,10 +58,10 @@ func (e *Example) SaveAs(key string) error {
}
func NewExample(key, name string, score int) *Example {
newExample := &Example{
new := &Example{
Name: name,
Score: score,
}
newExample.SetKey(key)
return newExample
new.SetKey(key)
return new
}

View file

@ -14,7 +14,6 @@ import (
// A Controller takes care of all the extra database logic.
type Controller struct {
database *Database
storage storage.Interface
shadowDelete bool
@ -26,9 +25,8 @@ type Controller struct {
}
// newController creates a new controller for a storage.
func newController(database *Database, storageInt storage.Interface, shadowDelete bool) *Controller {
func newController(storageInt storage.Interface, shadowDelete bool) *Controller {
return &Controller{
database: database,
storage: storageInt,
shadowDelete: shadowDelete,
}
@ -78,7 +76,7 @@ func (c *Controller) Get(key string) (record.Record, error) {
return r, nil
}
// GetMeta returns the metadata of the record with the given key.
// Get returns the metadata of the record with the given key.
func (c *Controller) GetMeta(key string) (*record.Meta, error) {
if shuttingDown.IsSet() {
return nil, ErrShuttingDown

View file

@ -8,9 +8,6 @@ import (
"github.com/safing/portbase/database/storage"
)
// StorageTypeInjected is the type of injected databases.
const StorageTypeInjected = "injected"
var (
controllers = make(map[string]*Controller)
controllersLock sync.RWMutex
@ -39,27 +36,22 @@ func getController(name string) (*Controller, error) {
// get db registration
registeredDB, err := getDatabase(name)
if err != nil {
return nil, fmt.Errorf("could not start database %s: %w", name, err)
}
// Check if database is injected.
if registeredDB.StorageType == StorageTypeInjected {
return nil, fmt.Errorf("database storage is not injected")
return nil, fmt.Errorf(`could not start database %s: %s`, name, err)
}
// get location
dbLocation, err := getLocation(name, registeredDB.StorageType)
if err != nil {
return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err)
return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err)
}
// start database
storageInt, err := storage.StartDatabase(name, registeredDB.StorageType, dbLocation)
if err != nil {
return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err)
return nil, fmt.Errorf(`could not start database %s (type %s): %s`, name, registeredDB.StorageType, err)
}
controller = newController(registeredDB, storageInt, registeredDB.ShadowDelete)
controller = newController(storageInt, registeredDB.ShadowDelete)
controllers[name] = controller
return controller, nil
}
@ -84,23 +76,13 @@ func InjectDatabase(name string, storageInt storage.Interface) (*Controller, err
// check if database is registered
registeredDB, ok := registry[name]
if !ok {
return nil, fmt.Errorf("database %q not registered", name)
return nil, fmt.Errorf(`database "%s" not registered`, name)
}
if registeredDB.StorageType != StorageTypeInjected {
return nil, fmt.Errorf("database not of type %q", StorageTypeInjected)
if registeredDB.StorageType != "injected" {
return nil, fmt.Errorf(`database not of type "injected"`)
}
controller := newController(registeredDB, storageInt, false)
controller := newController(storageInt, false)
controllers[name] = controller
return controller, nil
}
// Withdraw withdraws an injected database, but leaves the database registered.
func (c *Controller) Withdraw() {
if c != nil && c.Injected() {
controllersLock.Lock()
defer controllersLock.Unlock()
delete(controllers, c.database.Name)
}
}

View file

@ -4,7 +4,7 @@ import (
"time"
)
// Database holds information about a registered database.
// Database holds information about registered databases
type Database struct {
Name string
Description string

View file

@ -2,8 +2,8 @@ package database
import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"reflect"
@ -11,9 +11,11 @@ import (
"testing"
"time"
"github.com/safing/portbase/database/storage"
q "github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
_ "github.com/safing/portbase/database/storage/badger"
_ "github.com/safing/portbase/database/storage/bbolt"
_ "github.com/safing/portbase/database/storage/fstree"
@ -21,7 +23,7 @@ import (
)
func TestMain(m *testing.M) {
testDir, err := os.MkdirTemp("", "portbase-database-testing-")
testDir, err := ioutil.TempDir("", "portbase-database-testing-")
if err != nil {
panic(err)
}
@ -35,7 +37,7 @@ func TestMain(m *testing.M) {
// Clean up the test directory.
// Do not defer, as we end this function with a os.Exit call.
_ = os.RemoveAll(testDir)
os.RemoveAll(testDir)
os.Exit(exitCode)
}
@ -44,7 +46,7 @@ func makeKey(dbName, key string) string {
return fmt.Sprintf("%s:%s", dbName, key)
}
func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolint:maintidx,thelper
func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolint:gocognit,gocyclo
t.Run(fmt.Sprintf("TestStorage_%s_%v", storageType, shadowDelete), func(t *testing.T) {
dbName := fmt.Sprintf("testing-%s-%v", storageType, shadowDelete)
fmt.Println(dbName)
@ -178,7 +180,7 @@ func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolin
// check status individually
_, err = dbController.storage.Get("A")
if !errors.Is(err, storage.ErrNotFound) {
if err != storage.ErrNotFound {
t.Errorf("A should be deleted and purged, err=%s", err)
}
B1, err := dbController.storage.Get("B")
@ -206,13 +208,13 @@ func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolin
B2, err := dbController.storage.Get("B")
if err == nil {
t.Errorf("B should be deleted and purged, meta: %+v", B2.Meta())
} else if !errors.Is(err, storage.ErrNotFound) {
} else if err != storage.ErrNotFound {
t.Errorf("B should be deleted and purged, err=%s", err)
}
C2, err := dbController.storage.Get("C")
if err == nil {
t.Errorf("C should be deleted and purged, meta: %+v", C2.Meta())
} else if !errors.Is(err, storage.ErrNotFound) {
} else if err != storage.ErrNotFound {
t.Errorf("C should be deleted and purged, err=%s", err)
}
@ -231,11 +233,11 @@ func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolin
if err != nil {
t.Fatal(err)
}
})
}
func TestDatabaseSystem(t *testing.T) { //nolint:tparallel
t.Parallel()
func TestDatabaseSystem(t *testing.T) {
// panic after 10 seconds, to check for locks
finished := make(chan struct{})
@ -280,8 +282,6 @@ func TestDatabaseSystem(t *testing.T) { //nolint:tparallel
}
func countRecords(t *testing.T, db *Interface, query *q.Query) int {
t.Helper()
_, err := query.Check()
if err != nil {
t.Fatal(err)

View file

@ -1,62 +1,63 @@
/*
Package database provides a universal interface for interacting with the database.
# A Lazy Database
A Lazy Database
The database system can handle Go structs as well as serialized data by the dsd package.
While data is in transit within the system, it does not know which form it currently has. Only when it reaches its destination, it must ensure that it is either of a certain type or dump it.
# Record Interface
Record Interface
The database system uses the Record interface to transparently handle all types of structs that get saved in the database. Structs include the Base struct to fulfill most parts of the Record interface.
Boilerplate Code:
type Example struct {
record.Base
sync.Mutex
type Example struct {
record.Base
sync.Mutex
Name string
Score int
}
Name string
Score int
}
var (
db = database.NewInterface(nil)
)
var (
db = database.NewInterface(nil)
)
// GetExample gets an Example from the database.
func GetExample(key string) (*Example, error) {
r, err := db.Get(key)
if err != nil {
return nil, err
}
// GetExample gets an Example from the database.
func GetExample(key string) (*Example, error) {
r, err := db.Get(key)
if err != nil {
return nil, err
}
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &Example{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
// unwrap
if r.IsWrapped() {
// only allocate a new struct, if we need it
new := &Example{}
err = record.Unwrap(r, new)
if err != nil {
return nil, err
}
return new, nil
}
// or adjust type
new, ok := r.(*Example)
if !ok {
return nil, fmt.Errorf("record not of type *Example, but %T", r)
}
return new, nil
}
// or adjust type
new, ok := r.(*Example)
if !ok {
return nil, fmt.Errorf("record not of type *Example, but %T", r)
}
return new, nil
}
func (e *Example) Save() error {
return db.Put(e)
}
func (e *Example) Save() error {
return db.Put(e)
}
func (e *Example) SaveAs(key string) error {
e.SetKey(key)
return db.PutNew(e)
}
func (e *Example) SaveAs(key string) error {
e.SetKey(key)
return db.PutNew(e)
}
*/
package database

View file

@ -4,7 +4,7 @@ import (
"errors"
)
// Errors.
// Errors
var (
ErrNotFound = errors.New("database entry not found")
ErrPermissionDenied = errors.New("access to database record denied")

View file

@ -5,7 +5,8 @@ import (
)
// HookBase implements the Hook interface and provides dummy functions to reduce boilerplate.
type HookBase struct{}
type HookBase struct {
}
// UsesPreGet implements the Hook interface and returns false.
func (b *HookBase) UsesPreGet() bool {

View file

@ -120,19 +120,19 @@ func NewInterface(opts *Options) *Interface {
opts = &Options{}
}
newIface := &Interface{
new := &Interface{
options: opts,
}
if opts.CacheSize > 0 {
cacheBuilder := gcache.New(opts.CacheSize).ARC()
if opts.DelayCachedWrites != "" {
cacheBuilder.EvictedFunc(newIface.cacheEvictHandler)
newIface.writeCache = make(map[string]record.Record, opts.CacheSize/2)
newIface.triggerCacheWrite = make(chan struct{})
cacheBuilder.EvictedFunc(new.cacheEvictHandler)
new.writeCache = make(map[string]record.Record, opts.CacheSize/2)
new.triggerCacheWrite = make(chan struct{})
}
newIface.cache = cacheBuilder.Build()
new.cache = cacheBuilder.Build()
}
return newIface
return new
}
// Exists return whether a record with the given key exists.
@ -157,7 +157,7 @@ func (i *Interface) Get(key string) (record.Record, error) {
return r, err
}
func (i *Interface) getRecord(dbName string, dbKey string, mustBeWriteable bool) (r record.Record, db *Controller, err error) { //nolint:unparam
func (i *Interface) getRecord(dbName string, dbKey string, mustBeWriteable bool) (r record.Record, db *Controller, err error) {
if dbName == "" {
dbName, dbKey = record.ParseKey(dbKey)
}
@ -201,7 +201,7 @@ func (i *Interface) getRecord(dbName string, dbKey string, mustBeWriteable bool)
return r, db, nil
}
func (i *Interface) getMeta(dbName string, dbKey string, mustBeWriteable bool) (m *record.Meta, db *Controller, err error) { //nolint:unparam
func (i *Interface) getMeta(dbName string, dbKey string, mustBeWriteable bool) (m *record.Meta, db *Controller, err error) {
if dbName == "" {
dbName, dbKey = record.ParseKey(dbKey)
}
@ -258,7 +258,7 @@ func (i *Interface) InsertValue(key string, attribute string, value interface{})
err = acc.Set(attribute, value)
if err != nil {
return fmt.Errorf("failed to set value with %s: %w", acc.Type(), err)
return fmt.Errorf("failed to set value with %s: %s", acc.Type(), err)
}
i.options.Apply(r)
@ -271,7 +271,7 @@ func (i *Interface) Put(r record.Record) (err error) {
var db *Controller
if !i.options.HasAllPermissions() {
_, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true)
if err != nil && !errors.Is(err, ErrNotFound) {
if err != nil && err != ErrNotFound {
return err
}
} else {
@ -309,7 +309,7 @@ func (i *Interface) PutNew(r record.Record) (err error) {
var db *Controller
if !i.options.HasAllPermissions() {
_, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true)
if err != nil && !errors.Is(err, ErrNotFound) {
if err != nil && err != ErrNotFound {
return err
}
} else {
@ -344,13 +344,11 @@ func (i *Interface) PutNew(r record.Record) (err error) {
return db.Put(r)
}
// PutMany stores many records in the database.
// Warning: This is nearly a direct database access and omits many things:
// PutMany stores many records in the database. Warning: This is nearly a direct database access and omits many things:
// - Record locking
// - Hooks
// - Subscriptions
// - Caching
// Use with care.
func (i *Interface) PutMany(dbName string) (put func(record.Record) error) {
interfaceBatch := make(chan record.Record, 100)
@ -521,8 +519,6 @@ func (i *Interface) Delete(key string) error {
}
// Query executes the given query on the database.
// Will not see data that is in the write cache, waiting to be written.
// Use with care with caching.
func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) {
_, err := q.Check()
if err != nil {
@ -534,7 +530,7 @@ func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) {
return nil, err
}
// TODO: Finish caching system integration.
// FIXME:
// Flush the cache before we query the database.
// i.FlushCache()

View file

@ -45,7 +45,7 @@ func (i *Interface) DelayedCacheWriter(ctx context.Context) error {
i.flushWriteCache(0)
case <-thresholdWriteTicker.C:
// Often check if the write cache has filled up to a certain degree and
// Often check if the the write cache has filled up to a certain degree and
// flush it to storage before we start evicting to-be-written entries and
// slow down the hot path again.
i.flushWriteCache(percentThreshold)
@ -57,6 +57,7 @@ func (i *Interface) DelayedCacheWriter(ctx context.Context) error {
// of a total crash.
i.flushWriteCache(0)
}
}
}
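
A hedged sketch of wiring up an interface with the delayed write cache served by DelayedCacheWriter above; the cache size, database name, and lifetime context are assumptions:

	db := database.NewInterface(&database.Options{
		Local:             true,
		Internal:          true,
		CacheSize:         256,
		DelayCachedWrites: "core", // assumed: name of the database whose writes may be delayed
	})

	// Delayed writes only reach storage while this writer is running.
	go func() {
		_ = db.DelayedCacheWriter(ctx)
	}()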

View file

@ -8,7 +8,7 @@ import (
"testing"
)
func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper
func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo
b.Run(fmt.Sprintf("CacheWriting_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
// Setup Benchmark.
@ -66,10 +66,11 @@ func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, samp
// End cache writer and wait
cancelCtx()
wg.Wait()
})
}
func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper
func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo
b.Run(fmt.Sprintf("CacheReadWrite_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) {
// Setup Benchmark.
@ -134,6 +135,7 @@ func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sa
// End cache writer and wait
cancelCtx()
wg.Wait()
})
}

View file

@ -5,9 +5,8 @@ import (
"fmt"
"path/filepath"
"github.com/tevino/abool"
"github.com/safing/portbase/utils"
"github.com/tevino/abool"
)
const (
@ -26,7 +25,7 @@ var (
// InitializeWithPath initializes the database at the specified location using a path.
func InitializeWithPath(dirPath string) error {
return Initialize(utils.NewDirStructure(dirPath, 0o0755))
return Initialize(utils.NewDirStructure(dirPath, 0755))
}
// Initialize initializes the database at the specified location using a dir structure.
@ -35,16 +34,16 @@ func Initialize(dirStructureRoot *utils.DirStructure) error {
rootStructure = dirStructureRoot
// ensure root and databases dirs
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0o0700)
databasesStructure = rootStructure.ChildDir(databasesSubDir, 0700)
err := databasesStructure.Ensure()
if err != nil {
return fmt.Errorf("could not create/open database directory (%s): %w", rootStructure.Path, err)
return fmt.Errorf("could not create/open database directory (%s): %s", rootStructure.Path, err)
}
if registryPersistence.IsSet() {
err = loadRegistry()
if err != nil {
return fmt.Errorf("could not load database registry (%s): %w", filepath.Join(rootStructure.Path, registryFileName), err)
return fmt.Errorf("could not load database registry (%s): %s", filepath.Join(rootStructure.Path, registryFileName), err)
}
}
@ -75,11 +74,11 @@ func Shutdown() (err error) {
// getLocation returns the storage location for the given name and type.
func getLocation(name, storageType string) (string, error) {
location := databasesStructure.ChildDir(name, 0o0700).ChildDir(storageType, 0o0700)
location := databasesStructure.ChildDir(name, 0700).ChildDir(storageType, 0700)
// check location
err := location.Ensure()
if err != nil {
return "", fmt.Errorf(`failed to create/check database dir "%s": %w`, location.Path, err)
return "", fmt.Errorf(`failed to create/check database dir "%s": %s`, location.Path, err)
}
return location.Path, nil
}

View file

@ -1,58 +0,0 @@
package migration
import "errors"
// DiagnosticStep describes one migration step in the Diagnostics.
type DiagnosticStep struct {
Version string
Description string
}
// Diagnostics holds a detailed error report about a failed migration.
type Diagnostics struct { //nolint:errname
// Message holds a human readable message of the encountered
// error.
Message string
// Wrapped must be set to the underlying error that was encountered
// while preparing or executing migrations.
Wrapped error
// StartOfMigration is set to the version of the database before
// any migrations are applied.
StartOfMigration string
// LastSuccessfulMigration is set to the version of the database
// which has been applied successfully before the error happened.
LastSuccessfulMigration string
// TargetVersion is set to the version of the database that the
// migration run aimed for. That is, it's the last available version
// added to the registry.
TargetVersion string
// ExecutionPlan is a list of migration steps that were planned to
// be executed.
ExecutionPlan []DiagnosticStep
// FailedMigration is the description of the migration that has
// failed.
FailedMigration string
}
// Error returns a string representation of the migration error.
func (err *Diagnostics) Error() string {
msg := ""
if err.FailedMigration != "" {
msg = err.FailedMigration + ": "
}
if err.Message != "" {
msg += err.Message + ": "
}
msg += err.Wrapped.Error()
return msg
}
// Unwrap returns the actual error that happened when executing
// a migration. It implements the interface required by the stdlib
// errors package to support errors.Is() and errors.As().
func (err *Diagnostics) Unwrap() error {
if u := errors.Unwrap(err.Wrapped); u != nil {
return u
}
return err.Wrapped
}
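
A hedged sketch of how a caller might unpack such a report after a failed migration run; the registry value and the log calls are assumptions:

	err := reg.Migrate(ctx)
	if err != nil {
		var diag *migration.Diagnostics
		if errors.As(err, &diag) {
			log.Errorf("migration %q failed (last good: %s, target: %s): %s",
				diag.FailedMigration, diag.LastSuccessfulMigration, diag.TargetVersion, diag.Wrapped)
		}
		return err
	}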

View file

@ -1,220 +0,0 @@
package migration
import (
"context"
"errors"
"fmt"
"sort"
"sync"
"time"
"github.com/hashicorp/go-version"
"github.com/safing/portbase/database"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/log"
)
// MigrateFunc is called when a migration should be applied to the
// database. It receives the current version (from) and the target
// version (to) of the database and a dedicated interface for
// interacting with data stored in the DB.
// A dedicated log.ContextTracer is added to ctx for each migration
// run.
type MigrateFunc func(ctx context.Context, from, to *version.Version, dbInterface *database.Interface) error
// Migration represents a registered data-migration that should be applied to
// some database. Migrations are stacked on top and executed in order of increasing
// version number (see Version field).
type Migration struct {
// Description provides a short human-readable description of the
// migration.
Description string
// Version should hold the version of the database/subsystem after
// the migration has been applied.
Version string
// MigrateFuc is executed when the migration should be performed.
MigrateFunc MigrateFunc
}
// Registry holds a migration stack.
type Registry struct {
key string
lock sync.Mutex
migrations []Migration
}
// New creates a new migration registry.
// The key should be the name of the database key that is used to store
// the version of the last successfully applied migration.
func New(key string) *Registry {
return &Registry{
key: key,
}
}
// Add adds one or more migrations to reg.
func (reg *Registry) Add(migrations ...Migration) error {
reg.lock.Lock()
defer reg.lock.Unlock()
for _, m := range migrations {
if _, err := version.NewSemver(m.Version); err != nil {
return fmt.Errorf("migration %q: invalid version %s: %w", m.Description, m.Version, err)
}
reg.migrations = append(reg.migrations, m)
}
return nil
}
// Migrate migrates the database by executing all registered
// migration in order of increasing version numbers. The error
// returned, if not nil, is always of type *Diagnostics.
func (reg *Registry) Migrate(ctx context.Context) (err error) {
reg.lock.Lock()
defer reg.lock.Unlock()
start := time.Now()
log.Infof("migration: migration of %s started", reg.key)
defer func() {
if err != nil {
log.Errorf("migration: migration of %s failed after %s: %s", reg.key, time.Since(start), err)
} else {
log.Infof("migration: migration of %s finished after %s", reg.key, time.Since(start))
}
}()
db := database.NewInterface(&database.Options{
Local: true,
Internal: true,
})
startOfMigration, err := reg.getLatestSuccessfulMigration(db)
if err != nil {
return err
}
execPlan, diag, err := reg.getExecutionPlan(startOfMigration)
if err != nil {
return err
}
if len(execPlan) == 0 {
return nil
}
diag.TargetVersion = execPlan[len(execPlan)-1].Version
// finally, apply our migrations
lastAppliedMigration := startOfMigration
for _, m := range execPlan {
target, _ := version.NewSemver(m.Version) // we can safely ignore the error here
migrationCtx, tracer := log.AddTracer(ctx)
if err := m.MigrateFunc(migrationCtx, lastAppliedMigration, target, db); err != nil {
diag.Wrapped = err
diag.FailedMigration = m.Description
tracer.Errorf("migration: migration for %s failed: %s - %s", reg.key, target.String(), m.Description)
tracer.Submit()
return diag
}
lastAppliedMigration = target
diag.LastSuccessfulMigration = lastAppliedMigration.String()
if err := reg.saveLastSuccessfulMigration(db, target); err != nil {
diag.Message = "failed to persist migration status"
diag.Wrapped = err
diag.FailedMigration = m.Description
}
tracer.Infof("migration: applied migration for %s: %s - %s", reg.key, target.String(), m.Description)
tracer.Submit()
}
// all migrations have been applied successfully, we're done here
return nil
}
func (reg *Registry) getLatestSuccessfulMigration(db *database.Interface) (*version.Version, error) {
// find the latest version stored in the database
rec, err := db.Get(reg.key)
if errors.Is(err, database.ErrNotFound) {
return nil, nil
}
if err != nil {
return nil, &Diagnostics{
Message: "failed to query database for migration status",
Wrapped: err,
}
}
// Unwrap the record to get the actual database
r, ok := rec.(*record.Wrapper)
if !ok {
return nil, &Diagnostics{
Wrapped: errors.New("expected wrapped database record"),
}
}
sv, err := version.NewSemver(string(r.Data))
if err != nil {
return nil, &Diagnostics{
Message: "failed to parse version stored in migration status record",
Wrapped: err,
}
}
return sv, nil
}
func (reg *Registry) saveLastSuccessfulMigration(db *database.Interface, ver *version.Version) error {
r := &record.Wrapper{
Data: []byte(ver.String()),
Format: dsd.RAW,
}
r.SetKey(reg.key)
return db.Put(r)
}
func (reg *Registry) getExecutionPlan(startOfMigration *version.Version) ([]Migration, *Diagnostics, error) {
// create a look-up map for migrations indexed by their semver created a
// list of version (sorted by increasing number) that we use as our execution
// plan.
lm := make(map[string]Migration)
versions := make(version.Collection, 0, len(reg.migrations))
for _, m := range reg.migrations {
ver, err := version.NewSemver(m.Version)
if err != nil {
return nil, nil, &Diagnostics{
Message: "failed to parse version of migration",
Wrapped: err,
FailedMigration: m.Description,
}
}
lm[ver.String()] = m // use .String() for a normalized string representation
versions = append(versions, ver)
}
sort.Sort(versions)
diag := new(Diagnostics)
if startOfMigration != nil {
diag.StartOfMigration = startOfMigration.String()
}
// prepare our diagnostics and the execution plan
execPlan := make([]Migration, 0, len(versions))
for _, ver := range versions {
// skip an migration that has already been applied.
if startOfMigration != nil && startOfMigration.GreaterThanOrEqual(ver) {
continue
}
m := lm[ver.String()]
diag.ExecutionPlan = append(diag.ExecutionPlan, DiagnosticStep{
Description: m.Description,
Version: ver.String(),
})
execPlan = append(execPlan, m)
}
return execPlan, diag, nil
}
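
For completeness, a hedged sketch of registering and running a migration against this registry; the storage key, version, and migration body are invented:

	reg := migration.New("core:migrations/status")
	err := reg.Add(migration.Migration{
		Description: "rename legacy config records",
		Version:     "0.2.0",
		MigrateFunc: func(ctx context.Context, from, to *version.Version, db *database.Interface) error {
			// Read, rewrite and delete records via db here.
			return nil
		},
	})
	if err != nil {
		return err
	}
	if err := reg.Migrate(ctx); err != nil {
		return err // always a *Diagnostics, see errors.go above
	}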

View file

@ -15,6 +15,7 @@ type boolCondition struct {
}
func newBoolCondition(key string, operator uint8, value interface{}) *boolCondition {
var parsedValue bool
switch v := value.(type) {

View file

@ -15,6 +15,7 @@ type floatCondition struct {
}
func newFloatCondition(key string, operator uint8, value interface{}) *floatCondition {
var parsedValue float64
switch v := value.(type) {

View file

@ -15,6 +15,7 @@ type intCondition struct {
}
func newIntCondition(key string, operator uint8, value interface{}) *intCondition {
var parsedValue int64
switch v := value.(type) {

View file

@ -15,6 +15,7 @@ type stringSliceCondition struct {
}
func newStringSliceCondition(key string, operator uint8, value interface{}) *stringSliceCondition {
switch v := value.(type) {
case string:
parsedValue := strings.Split(v, ",")
@ -41,6 +42,7 @@ func newStringSliceCondition(key string, operator uint8, value interface{}) *str
operator: errorPresent,
}
}
}
func (c *stringSliceCondition) complies(acc accessor.Accessor) bool {

View file

@ -13,7 +13,7 @@ type Condition interface {
string() string
}
// Operators.
// Operators
const (
Equals uint8 = iota // int
GreaterThan // int

View file

@ -3,8 +3,6 @@ package query
import "testing"
func testSuccess(t *testing.T, c Condition) {
t.Helper()
err := c.check()
if err != nil {
t.Errorf("failed: %s", err)
@ -12,8 +10,6 @@ func testSuccess(t *testing.T, c Condition) {
}
func TestInterfaces(t *testing.T) {
t.Parallel()
testSuccess(t, newIntCondition("banana", Equals, uint(1)))
testSuccess(t, newIntCondition("banana", Equals, uint8(1)))
testSuccess(t, newIntCondition("banana", Equals, uint16(1)))
@ -45,8 +41,6 @@ func TestInterfaces(t *testing.T) {
}
func testCondError(t *testing.T, c Condition) {
t.Helper()
err := c.check()
if err == nil {
t.Error("should fail")
@ -54,8 +48,6 @@ func testCondError(t *testing.T, c Condition) {
}
func TestConditionErrors(t *testing.T) {
t.Parallel()
// test invalid value types
testCondError(t, newBoolCondition("banana", Is, 1))
testCondError(t, newFloatCondition("banana", FloatEquals, true))
@ -76,8 +68,6 @@ func TestConditionErrors(t *testing.T) {
}
func TestWhere(t *testing.T) {
t.Parallel()
c := Where("", 254, nil)
err := c.check()
if err == nil {

View file

@ -3,8 +3,6 @@ package query
import "testing"
func TestGetOpName(t *testing.T) {
t.Parallel()
if getOpName(254) != "[unknown]" {
t.Error("unexpected output")
}

View file

@ -14,7 +14,6 @@ type snippet struct {
}
// ParseQuery parses a plaintext query. Special characters (that must be escaped with a '\') are: `\()` and any whitespaces.
//
//nolint:gocognit
func ParseQuery(query string) (*Query, error) {
snippets, err := extractSnippets(query)
@ -122,6 +121,7 @@ func ParseQuery(query string) (*Query, error) {
}
func extractSnippets(text string) (snippets []*snippet, err error) {
skip := false
start := -1
inParenthesis := false
@ -193,22 +193,21 @@ func extractSnippets(text string) (snippets []*snippet, err error) {
}
return snippets, nil
}
//nolint:gocognit
func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) {
var (
isOr = false
typeSet = false
wrapInNot = false
expectingMore = true
conditions []Condition
)
var isOr = false
var typeSet = false
var wrapInNot = false
var expectingMore = true
var conditions []Condition
for {
if !expectingMore && rootCondition && remainingSnippets() == 0 {
// advance snippetsPos by one, as it will be set back by 1
_, _ = getSnippet()
getSnippet() //nolint:errcheck
if len(conditions) == 1 {
return conditions[0], nil
}
@ -332,19 +331,21 @@ func parseCondition(firstSnippet *snippet, getSnippet func() (*snippet, error))
return Where(firstSnippet.text, operator, value.text), nil
}
var escapeReplacer = regexp.MustCompile(`\\([^\\])`)
var (
escapeReplacer = regexp.MustCompile(`\\([^\\])`)
)
// prepToken removes surrounding parenthesis and escape characters.
func prepToken(text string) string {
return escapeReplacer.ReplaceAllString(strings.Trim(text, "\""), "$1")
}
// escapeString correctly escapes a snippet for printing.
// escapeString correctly escapes a snippet for printing
func escapeString(token string) string {
// check if token contains characters that need to be escaped
if strings.ContainsAny(token, "()\"\\\t\r\n ") {
// put the token in parenthesis and only escape \ and "
return fmt.Sprintf("\"%s\"", strings.ReplaceAll(token, "\"", "\\\""))
return fmt.Sprintf("\"%s\"", strings.Replace(token, "\"", "\\\"", -1))
}
return token
}
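
As a quick illustration of the plaintext syntax accepted by ParseQuery above, a hedged example; the database prefix and field names are invented:

	q, err := query.ParseQuery(`query config: where Score > 50 orderby Name limit 10`)
	if err != nil {
		return err
	}
	// q can now be passed to a database interface, e.g. db.Query(q).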

View file

@ -8,8 +8,6 @@ import (
)
func TestExtractSnippets(t *testing.T) {
t.Parallel()
text1 := `query test: where ( "bananas" > 100 and monkeys.# <= "12")or(coconuts < 10 "and" area > 50) or name sameas Julian or name matches ^King\ `
result1 := []*snippet{
{text: "query", globalPosition: 1},
@ -60,8 +58,6 @@ func TestExtractSnippets(t *testing.T) {
}
func testParsing(t *testing.T, queryText string, expectedResult *Query) {
t.Helper()
_, err := expectedResult.Check()
if err != nil {
t.Errorf("failed to create query: %s", err)
@ -88,8 +84,6 @@ func testParsing(t *testing.T, queryText string, expectedResult *Query) {
}
func TestParseQuery(t *testing.T) {
t.Parallel()
text1 := `query test: where (bananas > 100 and monkeys.# <= 12) or not (coconuts < 10 and area not > 50) or name sameas Julian or name matches "^King " orderby name limit 10 offset 20`
result1 := New("test:").Where(Or(
And(
@ -137,8 +131,6 @@ func TestParseQuery(t *testing.T) {
}
func testParseError(t *testing.T, queryText string, expectedErrorString string) {
t.Helper()
_, err := ParseQuery(queryText)
if err == nil {
t.Errorf("should fail to parse: %s", queryText)
@ -150,8 +142,6 @@ func testParseError(t *testing.T, queryText string, expectedErrorString string)
}
func TestParseErrors(t *testing.T) {
t.Parallel()
// syntax
testParseError(t, `query`, `unexpected end at position 5`)
testParseError(t, `query test: where`, `unexpected end at position 17`)

View file

@ -8,8 +8,9 @@ import (
"github.com/safing/portbase/formats/dsd"
)
// copied from https://github.com/tidwall/gjson/blob/master/gjson_test.go
var testJSON = `{"age":100, "name":{"here":"B\\\"R"},
var (
// copied from https://github.com/tidwall/gjson/blob/master/gjson_test.go
testJSON = `{"age":100, "name":{"here":"B\\\"R"},
"noop":{"what is a wren?":"a bird"},
"happy":true,"immortal":false,
"items":[1,2,3,{"tags":[1,2,3],"points":[[1,2],[3,4]]},4,5,6,7],
@ -45,11 +46,11 @@ var testJSON = `{"age":100, "name":{"here":"B\\\"R"},
"lastly":{"yay":"final"},
"temperature": 120.413
}`
)
func testQuery(t *testing.T, r record.Record, shouldMatch bool, condition Condition) {
t.Helper()
q := New("test:").Where(condition).MustBeValid()
// fmt.Printf("%s\n", q.Print())
matched := q.Matches(r)
@ -62,7 +63,6 @@ func testQuery(t *testing.T, r record.Record, shouldMatch bool, condition Condit
}
func TestQuery(t *testing.T) {
t.Parallel()
// if !gjson.Valid(testJSON) {
// t.Fatal("test json is invalid")
@ -110,4 +110,5 @@ func TestQuery(t *testing.T) {
testQuery(t, r, true, Where("happy", Exists, nil))
testQuery(t, r, true, Where("created", Matches, "^2014-[0-9]{2}-[0-9]{2}T"))
}

View file

@ -44,13 +44,6 @@ func (b *Base) SetKey(key string) {
}
}
// ResetKey resets the database name and key.
// Use with caution!
func (b *Base) ResetKey() {
b.dbName = ""
b.dbKey = ""
}
// Key returns the key of the database record.
// As the key must be set before any usage and can only be set once, this
// function may be used without locking the record.
@ -129,14 +122,14 @@ func (b *Base) MarshalRecord(self Record) ([]byte, error) {
c := container.New([]byte{1})
// meta encoding
metaSection, err := dsd.Dump(b.meta, dsd.GenCode)
metaSection, err := dsd.Dump(b.meta, GenCode)
if err != nil {
return nil, err
}
c.AppendAsBlock(metaSection)
// data
dataSection, err := b.Marshal(self, dsd.JSON)
dataSection, err := b.Marshal(self, JSON)
if err != nil {
return nil, err
}

View file

@ -3,11 +3,11 @@ package record
import "testing"
func TestBaseRecord(t *testing.T) {
t.Parallel()
// check model interface compliance
var m Record
b := &TestRecord{}
m = b
_ = m
}

View file

@ -0,0 +1,15 @@
package record
import (
"github.com/safing/portbase/formats/dsd"
)
// Reimport DSD storage types
const (
AUTO = dsd.AUTO
STRING = dsd.STRING // S
BYTES = dsd.BYTES // X
JSON = dsd.JSON // J
BSON = dsd.BSON // B
GenCode = dsd.GenCode // G
)

View file

@ -24,16 +24,22 @@ import (
"github.com/safing/portbase/container"
"github.com/safing/portbase/formats/dsd"
"github.com/safing/portbase/formats/varint"
// Colfer
// "github.com/safing/portbase/database/model/model"
// XDR
// xdr2 "github.com/davecgh/go-xdr/xdr2"
)
var testMeta = &Meta{
Created: time.Now().Unix(),
Modified: time.Now().Unix(),
Expires: time.Now().Unix(),
Deleted: time.Now().Unix(),
secret: true,
cronjewel: true,
}
var (
testMeta = &Meta{
Created: time.Now().Unix(),
Modified: time.Now().Unix(),
Expires: time.Now().Unix(),
Deleted: time.Now().Unix(),
secret: true,
cronjewel: true,
}
)
func BenchmarkAllocateBytes(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -43,8 +49,8 @@ func BenchmarkAllocateBytes(b *testing.B) {
func BenchmarkAllocateStruct1(b *testing.B) {
for i := 0; i < b.N; i++ {
var newMeta Meta
_ = newMeta
var new Meta
_ = new
}
}
@ -55,6 +61,7 @@ func BenchmarkAllocateStruct2(b *testing.B) {
}
func BenchmarkMetaSerializeContainer(b *testing.B) {
// Start benchmark
for i := 0; i < b.N; i++ {
c := container.New()
@ -73,9 +80,11 @@ func BenchmarkMetaSerializeContainer(b *testing.B) {
c.AppendNumber(0)
}
}
}
func BenchmarkMetaUnserializeContainer(b *testing.B) {
// Setup
c := container.New()
c.AppendNumber(uint64(testMeta.Created))
@ -148,9 +157,11 @@ func BenchmarkMetaUnserializeContainer(b *testing.B) {
return
}
}
}
func BenchmarkMetaSerializeVarInt(b *testing.B) {
// Start benchmark
for i := 0; i < b.N; i++ {
encoded := make([]byte, 33)
@ -186,10 +197,13 @@ func BenchmarkMetaSerializeVarInt(b *testing.B) {
default:
encoded[offset] = 0
}
offset++
}
}
func BenchmarkMetaUnserializeVarInt(b *testing.B) {
// Setup
encoded := make([]byte, 33)
offset := 0
@ -281,9 +295,106 @@ func BenchmarkMetaUnserializeVarInt(b *testing.B) {
return
}
}
}
// func BenchmarkMetaSerializeWithXDR2(b *testing.B) {
//
// // Setup
// var w bytes.Buffer
//
// // Reset timer for precise results
// b.ResetTimer()
//
// // Start benchmark
// for i := 0; i < b.N; i++ {
// w.Reset()
// _, err := xdr2.Marshal(&w, testMeta)
// if err != nil {
// b.Errorf("failed to serialize with xdr2: %s", err)
// return
// }
// }
//
// }
// func BenchmarkMetaUnserializeWithXDR2(b *testing.B) {
//
// // Setup
// var w bytes.Buffer
// _, err := xdr2.Marshal(&w, testMeta)
// if err != nil {
// b.Errorf("failed to serialize with xdr2: %s", err)
// }
// encodedData := w.Bytes()
//
// // Reset timer for precise results
// b.ResetTimer()
//
// // Start benchmark
// for i := 0; i < b.N; i++ {
// var newMeta Meta
// _, err := xdr2.Unmarshal(bytes.NewReader(encodedData), &newMeta)
// if err != nil {
// b.Errorf("failed to unserialize with xdr2: %s", err)
// return
// }
// }
//
// }
// func BenchmarkMetaSerializeWithColfer(b *testing.B) {
//
// testColf := &model.Course{
// Created: time.Now().Unix(),
// Modified: time.Now().Unix(),
// Expires: time.Now().Unix(),
// Deleted: time.Now().Unix(),
// Secret: true,
// Cronjewel: true,
// }
//
// // Setup
// for i := 0; i < b.N; i++ {
// _, err := testColf.MarshalBinary()
// if err != nil {
// b.Errorf("failed to serialize with colfer: %s", err)
// return
// }
// }
//
// }
// func BenchmarkMetaUnserializeWithColfer(b *testing.B) {
//
// testColf := &model.Course{
// Created: time.Now().Unix(),
// Modified: time.Now().Unix(),
// Expires: time.Now().Unix(),
// Deleted: time.Now().Unix(),
// Secret: true,
// Cronjewel: true,
// }
// encodedData, err := testColf.MarshalBinary()
// if err != nil {
// b.Errorf("failed to serialize with colfer: %s", err)
// return
// }
//
// // Setup
// for i := 0; i < b.N; i++ {
// var testUnColf model.Course
// err := testUnColf.UnmarshalBinary(encodedData)
// if err != nil {
// b.Errorf("failed to unserialize with colfer: %s", err)
// return
// }
// }
//
// }
func BenchmarkMetaSerializeWithCodegen(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := testMeta.GenCodeMarshal(nil)
if err != nil {
@ -291,9 +402,11 @@ func BenchmarkMetaSerializeWithCodegen(b *testing.B) {
return
}
}
}
func BenchmarkMetaUnserializeWithCodegen(b *testing.B) {
// Setup
encodedData, err := testMeta.GenCodeMarshal(nil)
if err != nil {
@ -313,21 +426,25 @@ func BenchmarkMetaUnserializeWithCodegen(b *testing.B) {
return
}
}
}
func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := dsd.Dump(testMeta, dsd.JSON)
_, err := dsd.Dump(testMeta, JSON)
if err != nil {
b.Errorf("failed to serialize with DSD/JSON: %s", err)
return
}
}
}
func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) {
// Setup
encodedData, err := dsd.Dump(testMeta, dsd.JSON)
encodedData, err := dsd.Dump(testMeta, JSON)
if err != nil {
b.Errorf("failed to serialize with DSD/JSON: %s", err)
return
@ -345,4 +462,5 @@ func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) {
return
}
}
}

View file

@ -2,9 +2,18 @@ package record
import (
"fmt"
"io"
"time"
"unsafe"
)
// GenCodeSize returns the size of the gencode marshalled byte slice.
var (
_ = unsafe.Sizeof(0)
_ = io.ReadFull
_ = time.Now()
)
// GenCodeSize returns the size of the gencode marshalled byte slice
func (m *Meta) GenCodeSize() (s int) {
s += 34
return
@ -124,16 +133,24 @@ func (m *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) {
i := uint64(0)
{
m.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56)
}
{
m.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56)
}
{
m.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56)
}
{
m.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56)
}
{
m.secret = buf[32] == 1

View file

@ -6,30 +6,30 @@ import (
"time"
)
var genCodeTestMeta = &Meta{
Created: time.Now().Unix(),
Modified: time.Now().Unix(),
Expires: time.Now().Unix(),
Deleted: time.Now().Unix(),
secret: true,
cronjewel: true,
}
var (
genCodeTestMeta = &Meta{
Created: time.Now().Unix(),
Modified: time.Now().Unix(),
Expires: time.Now().Unix(),
Deleted: time.Now().Unix(),
secret: true,
cronjewel: true,
}
)
func TestGenCode(t *testing.T) {
t.Parallel()
encoded, err := genCodeTestMeta.GenCodeMarshal(nil)
if err != nil {
t.Fatal(err)
}
newMeta := &Meta{}
_, err = newMeta.GenCodeUnmarshal(encoded)
new := &Meta{}
_, err = new.GenCodeUnmarshal(encoded)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(genCodeTestMeta, newMeta) {
t.Errorf("objects are not equal, got: %v", newMeta)
if !reflect.DeepEqual(genCodeTestMeta, new) {
t.Errorf("objects are not equal, got: %v", new)
}
}

View file

@ -2,7 +2,7 @@ package record
import "time"
// Meta holds metadata about the record.
// Meta holds
type Meta struct {
Created int64
Modified int64

View file

@ -12,21 +12,17 @@ type Record interface {
DatabaseName() string // test
DatabaseKey() string // config
// Metadata.
Meta() *Meta
SetMeta(meta *Meta)
CreateMeta()
UpdateMeta()
// Serialization.
Marshal(self Record, format uint8) ([]byte, error)
MarshalRecord(self Record) ([]byte, error)
GetAccessor(self Record) accessor.Accessor
// Locking.
Lock()
Unlock()
// Wrapping.
IsWrapped() bool
}

View file

@ -32,21 +32,21 @@ func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) {
metaSection, n, err := varint.GetNextBlock(data[offset:])
if err != nil {
return nil, fmt.Errorf("could not get meta section: %w", err)
return nil, fmt.Errorf("could not get meta section: %s", err)
}
offset += n
newMeta := &Meta{}
_, err = dsd.Load(metaSection, newMeta)
if err != nil {
return nil, fmt.Errorf("could not unmarshal meta section: %w", err)
return nil, fmt.Errorf("could not unmarshal meta section: %s", err)
}
var format uint8 = dsd.RAW
var format uint8 = dsd.NONE
if !newMeta.IsDeleted() {
format, n, err = varint.Unpack8(data[offset:])
if err != nil {
return nil, fmt.Errorf("could not get dsd format: %w", err)
return nil, fmt.Errorf("could not get dsd format: %s", err)
}
offset += n
}
@ -79,7 +79,7 @@ func NewWrapper(key string, meta *Meta, format uint8, data []byte) (*Wrapper, er
}, nil
}
// Marshal marshals the object, without the database key or metadata.
// Marshal marshals the object, without the database key or metadata
func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) {
if w.Meta() == nil {
return nil, errors.New("missing meta")
@ -89,7 +89,7 @@ func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) {
return nil, nil
}
if format != dsd.AUTO && format != w.Format {
if format != AUTO && format != w.Format {
return nil, errors.New("could not dump model, wrapped object format mismatch")
}
@ -112,14 +112,14 @@ func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) {
c := container.New([]byte{1})
// meta
metaSection, err := dsd.Dump(w.meta, dsd.GenCode)
metaSection, err := dsd.Dump(w.meta, GenCode)
if err != nil {
return nil, err
}
c.AppendAsBlock(metaSection)
// data
dataSection, err := w.Marshal(r, dsd.AUTO)
dataSection, err := w.Marshal(r, JSON)
if err != nil {
return nil, err
}
@ -134,26 +134,26 @@ func (w *Wrapper) IsWrapped() bool {
}
// Unwrap unwraps data into a record.
func Unwrap(wrapped, r Record) error {
func Unwrap(wrapped, new Record) error {
wrapper, ok := wrapped.(*Wrapper)
if !ok {
return fmt.Errorf("cannot unwrap %T", wrapped)
}
err := dsd.LoadAsFormat(wrapper.Data, wrapper.Format, r)
_, err := dsd.LoadAsFormat(wrapper.Data, wrapper.Format, new)
if err != nil {
return fmt.Errorf("failed to unwrap %T: %w", r, err)
return fmt.Errorf("failed to unwrap %T: %s", new, err)
}
r.SetKey(wrapped.Key())
r.SetMeta(wrapped.Meta())
new.SetKey(wrapped.Key())
new.SetMeta(wrapped.Meta())
return nil
}
// GetAccessor returns an accessor for this record, if available.
func (w *Wrapper) GetAccessor(self Record) accessor.Accessor {
if w.Format == dsd.JSON && len(w.Data) > 0 {
if w.Format == JSON && len(w.Data) > 0 {
return accessor.NewJSONBytesAccessor(&w.Data)
}
return nil

View file

@ -3,12 +3,9 @@ package record
import (
"bytes"
"testing"
"github.com/safing/portbase/formats/dsd"
)
func TestWrapper(t *testing.T) {
t.Parallel()
// check model interface compliance
var m Record
@ -21,18 +18,18 @@ func TestWrapper(t *testing.T) {
encodedTestData := []byte(`J{"a": "b"}`)
// test wrapper
wrapper, err := NewWrapper("test:a", &Meta{}, dsd.JSON, testData)
wrapper, err := NewWrapper("test:a", &Meta{}, JSON, testData)
if err != nil {
t.Fatal(err)
}
if wrapper.Format != dsd.JSON {
if wrapper.Format != JSON {
t.Error("format mismatch")
}
if !bytes.Equal(testData, wrapper.Data) {
t.Error("data mismatch")
}
encoded, err := wrapper.Marshal(wrapper, dsd.JSON)
encoded, err := wrapper.Marshal(wrapper, JSON)
if err != nil {
t.Fatal(err)
}

View file

@ -4,7 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"io/ioutil"
"os"
"path"
"regexp"
@ -32,7 +32,7 @@ var (
// If the database is already registered, only
// the description and the primary API will be
// updated and the effective object will be returned.
func Register(db *Database) (*Database, error) {
func Register(new *Database) (*Database, error) {
if !initialized.IsSet() {
return nil, errors.New("database not initialized")
}
@ -40,31 +40,31 @@ func Register(db *Database) (*Database, error) {
registryLock.Lock()
defer registryLock.Unlock()
registeredDB, ok := registry[db.Name]
registeredDB, ok := registry[new.Name]
save := false
if ok {
// update database
if registeredDB.Description != db.Description {
registeredDB.Description = db.Description
if registeredDB.Description != new.Description {
registeredDB.Description = new.Description
save = true
}
if registeredDB.ShadowDelete != db.ShadowDelete {
registeredDB.ShadowDelete = db.ShadowDelete
if registeredDB.ShadowDelete != new.ShadowDelete {
registeredDB.ShadowDelete = new.ShadowDelete
save = true
}
} else {
// register new database
if !nameConstraint.MatchString(db.Name) {
if !nameConstraint.MatchString(new.Name) {
return nil, errors.New("database name must only contain alphanumeric and `_-` characters and must be at least 3 characters long")
}
now := time.Now().Round(time.Second)
db.Registered = now
db.LastUpdated = now
db.LastLoaded = time.Time{}
new.Registered = now
new.LastUpdated = now
new.LastLoaded = time.Time{}
registry[db.Name] = db
registry[new.Name] = new
save = true
}
@ -115,23 +115,23 @@ func loadRegistry() error {
// read file
filePath := path.Join(rootStructure.Path, registryFileName)
data, err := os.ReadFile(filePath)
data, err := ioutil.ReadFile(filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
if os.IsNotExist(err) {
return nil
}
return err
}
// parse
databases := make(map[string]*Database)
err = json.Unmarshal(data, &databases)
new := make(map[string]*Database)
err = json.Unmarshal(data, &new)
if err != nil {
return err
}
// set
registry = databases
registry = new
return nil
}
@ -150,7 +150,7 @@ func saveRegistry(lock bool) error {
// write file
// TODO: write atomically (best effort)
filePath := path.Join(rootStructure.Path, registryFileName)
return os.WriteFile(filePath, data, 0o0600)
return ioutil.WriteFile(filePath, data, 0600)
}
func registryWriter() {

View file

@ -30,7 +30,7 @@ func NewBadger(name, location string) (storage.Interface, error) {
opts := badger.DefaultOptions(location)
db, err := badger.Open(opts)
if errors.Is(err, badger.ErrTruncateNeeded) {
if err == badger.ErrTruncateNeeded {
// clean up after crash
log.Warningf("database/storage: truncating corrupted value log of badger database %s: this may cause data loss", name)
opts.Truncate = true
@ -54,7 +54,7 @@ func (b *Badger) Get(key string) (record.Record, error) {
var err error
item, err = txn.Get([]byte(key))
if err != nil {
if errors.Is(err, badger.ErrKeyNotFound) {
if err == badger.ErrKeyNotFound {
return storage.ErrNotFound
}
return err
@ -114,7 +114,7 @@ func (b *Badger) Put(r record.Record) (record.Record, error) {
func (b *Badger) Delete(key string) error {
return b.db.Update(func(txn *badger.Txn) error {
err := txn.Delete([]byte(key))
if err != nil && !errors.Is(err, badger.ErrKeyNotFound) {
if err != nil && err != badger.ErrKeyNotFound {
return err
}
return nil
@ -125,7 +125,7 @@ func (b *Badger) Delete(key string) error {
func (b *Badger) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
_, err := q.Check()
if err != nil {
return nil, fmt.Errorf("invalid query: %w", err)
return nil, fmt.Errorf("invalid query: %s", err)
}
queryIter := iterator.New()
@ -169,17 +169,17 @@ func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, loc
if err != nil {
return err
}
newWrapper, err := record.NewRawWrapper(b.name, r.DatabaseKey(), copiedData)
new, err := record.NewRawWrapper(b.name, r.DatabaseKey(), copiedData)
if err != nil {
return err
}
select {
case <-queryIter.Done:
return nil
case queryIter.Next <- newWrapper:
case queryIter.Next <- new:
default:
select {
case queryIter.Next <- newWrapper:
case queryIter.Next <- new:
case <-queryIter.Done:
return nil
case <-time.After(1 * time.Minute):

View file

@ -1,7 +1,9 @@
//nolint:unparam,maligned
package badger
import (
"context"
"io/ioutil"
"os"
"reflect"
"sync"
@ -18,7 +20,7 @@ var (
_ storage.Maintainer = &Badger{}
)
type TestRecord struct { //nolint:maligned
type TestRecord struct {
record.Base
sync.Mutex
S string
@ -38,15 +40,11 @@ type TestRecord struct { //nolint:maligned
}
func TestBadger(t *testing.T) {
t.Parallel()
testDir, err := os.MkdirTemp("", "testing-")
testDir, err := ioutil.TempDir("", "testing-")
if err != nil {
t.Fatal(err)
}
defer func() {
_ = os.RemoveAll(testDir) // clean up
}()
defer os.RemoveAll(testDir) // clean up
// start
db, err := NewBadger("test", testDir)

View file

@ -16,7 +16,9 @@ import (
"github.com/safing/portbase/database/storage"
)
var bucketName = []byte{0}
var (
bucketName = []byte{0}
)
// BBolt database made pluggable for portbase.
type BBolt struct {
@ -37,10 +39,10 @@ func NewBBolt(name, location string) (storage.Interface, error) {
}
// Open/Create database, retry if there is a timeout.
db, err := bbolt.Open(dbFile, 0o0600, dbOptions)
db, err := bbolt.Open(dbFile, 0600, dbOptions)
for i := 0; i < 5 && err != nil; i++ {
// Try again if there is an error.
db, err = bbolt.Open(dbFile, 0o0600, dbOptions)
db, err = bbolt.Open(dbFile, 0600, dbOptions)
}
if err != nil {
return nil, err
@ -87,6 +89,7 @@ func (b *BBolt) Get(key string) (record.Record, error) {
}
return nil
})
if err != nil {
return nil, err
}
@ -185,7 +188,7 @@ func (b *BBolt) Delete(key string) error {
func (b *BBolt) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
_, err := q.Check()
if err != nil {
return nil, fmt.Errorf("invalid query: %w", err)
return nil, fmt.Errorf("invalid query: %s", err)
}
queryIter := iterator.New()
@ -232,19 +235,19 @@ func (b *BBolt) queryExecutor(queryIter *iterator.Iterator, q *query.Query, loca
duplicate := make([]byte, len(value))
copy(duplicate, value)
newWrapper, err := record.NewRawWrapper(b.name, iterWrapper.DatabaseKey(), duplicate)
new, err := record.NewRawWrapper(b.name, iterWrapper.DatabaseKey(), duplicate)
if err != nil {
return err
}
select {
case <-queryIter.Done:
return nil
case queryIter.Next <- newWrapper:
case queryIter.Next <- new:
default:
select {
case <-queryIter.Done:
return nil
case queryIter.Next <- newWrapper:
case queryIter.Next <- new:
case <-time.After(1 * time.Second):
return errors.New("query timeout")
}

View file

@ -1,7 +1,9 @@
//nolint:unparam,maligned
package bbolt
import (
"context"
"io/ioutil"
"os"
"reflect"
"sync"
@ -20,7 +22,7 @@ var (
_ storage.Purger = &BBolt{}
)
type TestRecord struct { //nolint:maligned
type TestRecord struct {
record.Base
sync.Mutex
S string
@ -40,15 +42,11 @@ type TestRecord struct { //nolint:maligned
}
func TestBBolt(t *testing.T) {
t.Parallel()
testDir, err := os.MkdirTemp("", "testing-")
testDir, err := ioutil.TempDir("", "testing-")
if err != nil {
t.Fatal(err)
}
defer func() {
_ = os.RemoveAll(testDir) // clean up
}()
defer os.RemoveAll(testDir) // clean up
// start
db, err := NewBBolt("test", testDir)

View file

@ -2,7 +2,7 @@ package storage
import "errors"
// Errors for storages.
// Errors for storages
var (
ErrNotFound = errors.New("storage entry not found")
)

View file

@ -8,7 +8,7 @@ import (
"context"
"errors"
"fmt"
"io/fs"
"io/ioutil"
"os"
"path/filepath"
"runtime"
@ -23,8 +23,8 @@ import (
)
const (
defaultFileMode = os.FileMode(0o0644)
defaultDirMode = os.FileMode(0o0755)
defaultFileMode = os.FileMode(int(0644))
defaultDirMode = os.FileMode(int(0755))
onWindows = runtime.GOOS == "windows"
)
@ -42,18 +42,18 @@ func init() {
func NewFSTree(name, location string) (storage.Interface, error) {
basePath, err := filepath.Abs(location)
if err != nil {
return nil, fmt.Errorf("fstree: failed to validate path %s: %w", location, err)
return nil, fmt.Errorf("fstree: failed to validate path %s: %s", location, err)
}
file, err := os.Stat(basePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
if os.IsNotExist(err) {
err = os.MkdirAll(basePath, defaultDirMode)
if err != nil {
return nil, fmt.Errorf("fstree: failed to create directory %s: %w", basePath, err)
return nil, fmt.Errorf("fstree: failed to create directory %s: %s", basePath, err)
}
} else {
return nil, fmt.Errorf("fstree: failed to stat path %s: %w", basePath, err)
return nil, fmt.Errorf("fstree: failed to stat path %s: %s", basePath, err)
}
} else {
if !file.IsDir() {
@ -88,12 +88,12 @@ func (fst *FSTree) Get(key string) (record.Record, error) {
return nil, err
}
data, err := os.ReadFile(dstPath)
data, err := ioutil.ReadFile(dstPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
if os.IsNotExist(err) {
return nil, storage.ErrNotFound
}
return nil, fmt.Errorf("fstree: failed to read file %s: %w", dstPath, err)
return nil, fmt.Errorf("fstree: failed to read file %s: %s", dstPath, err)
}
r, err := record.NewRawWrapper(fst.name, key, data)
@ -132,11 +132,11 @@ func (fst *FSTree) Put(r record.Record) (record.Record, error) {
// create dir and try again
err = os.MkdirAll(filepath.Dir(dstPath), defaultDirMode)
if err != nil {
return nil, fmt.Errorf("fstree: failed to create directory %s: %w", filepath.Dir(dstPath), err)
return nil, fmt.Errorf("fstree: failed to create directory %s: %s", filepath.Dir(dstPath), err)
}
err = writeFile(dstPath, data, defaultFileMode)
if err != nil {
return nil, fmt.Errorf("fstree: could not write file %s: %w", dstPath, err)
return nil, fmt.Errorf("fstree: could not write file %s: %s", dstPath, err)
}
}
@ -153,7 +153,7 @@ func (fst *FSTree) Delete(key string) error {
// remove entry
err = os.Remove(dstPath)
if err != nil {
return fmt.Errorf("fstree: could not delete %s: %w", dstPath, err)
return fmt.Errorf("fstree: could not delete %s: %s", dstPath, err)
}
return nil
@ -163,7 +163,7 @@ func (fst *FSTree) Delete(key string) error {
func (fst *FSTree) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
_, err := q.Check()
if err != nil {
return nil, fmt.Errorf("invalid query: %w", err)
return nil, fmt.Errorf("invalid query: %s", err)
}
walkPrefix, err := fst.buildFilePath(q.DatabaseKeyPrefix(), false)
@ -177,10 +177,10 @@ func (fst *FSTree) Query(q *query.Query, local, internal bool) (*iterator.Iterat
walkRoot = walkPrefix
case err == nil:
walkRoot = filepath.Dir(walkPrefix)
case errors.Is(err, fs.ErrNotExist):
case os.IsNotExist(err):
walkRoot = filepath.Dir(walkPrefix)
default: // err != nil
return nil, fmt.Errorf("fstree: could not stat query root %s: %w", walkPrefix, err)
return nil, fmt.Errorf("fstree: could not stat query root %s: %s", walkPrefix, err)
}
queryIter := iterator.New()
@ -191,8 +191,10 @@ func (fst *FSTree) Query(q *query.Query, local, internal bool) (*iterator.Iterat
func (fst *FSTree) queryExecutor(walkRoot string, queryIter *iterator.Iterator, q *query.Query, local, internal bool) {
err := filepath.Walk(walkRoot, func(path string, info os.FileInfo, err error) error {
// check for error
if err != nil {
return fmt.Errorf("fstree: error in walking fs: %w", err)
return fmt.Errorf("fstree: error in walking fs: %s", err)
}
if info.IsDir() {
@ -210,22 +212,22 @@ func (fst *FSTree) queryExecutor(walkRoot string, queryIter *iterator.Iterator,
}
// read file
data, err := os.ReadFile(path)
data, err := ioutil.ReadFile(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("fstree: failed to read file %s: %w", path, err)
return fmt.Errorf("fstree: failed to read file %s: %s", path, err)
}
// parse
key, err := filepath.Rel(fst.basePath, path)
if err != nil {
return fmt.Errorf("fstree: failed to extract key from filepath %s: %w", path, err)
return fmt.Errorf("fstree: failed to extract key from filepath %s: %s", path, err)
}
r, err := record.NewRawWrapper(fst.name, key, data)
if err != nil {
return fmt.Errorf("fstree: failed to load file %s: %w", path, err)
return fmt.Errorf("fstree: failed to load file %s: %s", path, err)
}
if !r.Meta().CheckValidity() {
@ -275,7 +277,7 @@ func (fst *FSTree) Shutdown() error {
return nil
}
// writeFile mirrors os.WriteFile, replacing an existing file with the same
// writeFile mirrors ioutil.WriteFile, replacing an existing file with the same
// name atomically. This is not atomic on Windows, but still an improvement.
// TODO: Replace with github.com/google/renamio.WriteFile as soon as it is fixed on Windows.
// TODO: This has become a wont-fix. Explore other options.
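
A minimal sketch of the temp-file-and-rename pattern the comment above alludes to, using only the standard library; this is not the actual fstree implementation and the helper name is invented:

	func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
		tmp, err := os.CreateTemp(filepath.Dir(path), ".tmp-*")
		if err != nil {
			return err
		}
		defer os.Remove(tmp.Name()) // best-effort cleanup if the rename never happens

		if _, err := tmp.Write(data); err != nil {
			tmp.Close()
			return err
		}
		if err := tmp.Chmod(perm); err != nil {
			tmp.Close()
			return err
		}
		if err := tmp.Close(); err != nil {
			return err
		}

		// On POSIX systems the rename replaces the destination atomically;
		// on Windows it can fail if the destination exists, matching the caveat above.
		return os.Rename(tmp.Name(), path)
	}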

View file

@ -2,5 +2,7 @@ package fstree
import "github.com/safing/portbase/database/storage"
// Compile time interface checks.
var _ storage.Interface = &FSTree{}
var (
// Compile time interface checks.
_ storage.Interface = &FSTree{}
)

View file

@ -113,7 +113,7 @@ func (hm *HashMap) Delete(key string) error {
func (hm *HashMap) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) {
_, err := q.Check()
if err != nil {
return nil, fmt.Errorf("invalid query: %w", err)
return nil, fmt.Errorf("invalid query: %s", err)
}
queryIter := iterator.New()

View file

@ -1,3 +1,4 @@
//nolint:unparam,maligned
package hashmap
import (
@ -5,9 +6,10 @@ import (
"sync"
"testing"
"github.com/safing/portbase/database/storage"
"github.com/safing/portbase/database/query"
"github.com/safing/portbase/database/record"
"github.com/safing/portbase/database/storage"
)
var (
@ -16,7 +18,7 @@ var (
_ storage.Batcher = &HashMap{}
)
type TestRecord struct { //nolint:maligned
type TestRecord struct {
record.Base
sync.Mutex
S string
@ -36,8 +38,6 @@ type TestRecord struct { //nolint:maligned
}
func TestHashMap(t *testing.T) {
t.Parallel()
// start
db, err := NewHashMap("test", "")
if err != nil {

View file

@ -10,13 +10,15 @@ import (
"github.com/safing/portbase/database/record"
)
// ErrNotImplemented is returned when a function is not implemented by a storage.
var ErrNotImplemented = errors.New("not implemented")
var (
// ErrNotImplemented is returned when a function is not implemented by a storage.
ErrNotImplemented = errors.New("not implemented")
)
// InjectBase is a dummy base structure to reduce boilerplate code for injected storage interfaces.
type InjectBase struct{}
// Compile time interface check.
// Compile time interface check
var _ Interface = &InjectBase{}
// Get returns a database record.

View file

@ -26,7 +26,7 @@ type Interface interface {
MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error
}
// MetaHandler defines the database storage API for backends that support optimized fetching of only the metadata.
// Maintainer defines the database storage API for backends that support optimized fetching of only the metadata.
type MetaHandler interface {
GetMeta(key string) (*record.Meta, error)
}

View file

@ -17,7 +17,7 @@ type Sinkhole struct {
}
var (
// Compile time interface checks.
// Compile time interface check
_ storage.Interface = &Sinkhole{}
_ storage.Maintainer = &Sinkhole{}
_ storage.Batcher = &Sinkhole{}
@ -62,7 +62,7 @@ func (s *Sinkhole) PutMany(shadowDelete bool) (chan<- record.Record, <-chan erro
// start handler
go func() {
for range batch {
// discard everything
// nom, nom, nom
}
errs <- nil
}()

Some files were not shown because too many files have changed in this diff Show more