From c04213219bc647021415adce5447557cf863973f Mon Sep 17 00:00:00 2001
From: Daniel <dhaavi@users.noreply.github.com>
Date: Fri, 28 Feb 2025 10:21:59 +0100
Subject: [PATCH] Make format and value nullable and improve maintenance and
 purge queries

---
 .../storage/sqlite/migrations/1_initial.sql   |   4 +-
 .../storage/sqlite/models/records.bob.go      |  50 +++----
 base/database/storage/sqlite/sqlite.go        | 133 +++++++++++++++---
 base/database/storage/sqlite/sqlite_test.go   |  24 ++--
 go.mod                                        |   1 +
 go.sum                                        |   2 +
 6 files changed, 155 insertions(+), 59 deletions(-)

diff --git a/base/database/storage/sqlite/migrations/1_initial.sql b/base/database/storage/sqlite/migrations/1_initial.sql
index e0c9ded7..172dbded 100644
--- a/base/database/storage/sqlite/migrations/1_initial.sql
+++ b/base/database/storage/sqlite/migrations/1_initial.sql
@@ -3,8 +3,8 @@
 CREATE TABLE records (
     key TEXT PRIMARY KEY,
 
-    format SMALLINT NOT NULL,
-    value  BLOB NOT NULL,
+    format SMALLINT,
+    value  BLOB,
 
     created    BIGINT NOT NULL,
     modified   BIGINT NOT NULL,
diff --git a/base/database/storage/sqlite/models/records.bob.go b/base/database/storage/sqlite/models/records.bob.go
index 5f90d27a..02561c19 100644
--- a/base/database/storage/sqlite/models/records.bob.go
+++ b/base/database/storage/sqlite/models/records.bob.go
@@ -7,7 +7,9 @@ import (
 	"context"
 	"io"
 
+	"github.com/aarondl/opt/null"
 	"github.com/aarondl/opt/omit"
+	"github.com/aarondl/opt/omitnull"
 	"github.com/stephenafamo/bob"
 	"github.com/stephenafamo/bob/dialect/sqlite"
 	"github.com/stephenafamo/bob/dialect/sqlite/dialect"
@@ -19,15 +21,15 @@ import (
 
 // Record is an object representing the database table.
 type Record struct {
-	Key        string `db:"key,pk" `
-	Format     int16  `db:"format" `
-	Value      []byte `db:"value" `
-	Created    int64  `db:"created" `
-	Modified   int64  `db:"modified" `
-	Expires    int64  `db:"expires" `
-	Deleted    int64  `db:"deleted" `
-	Secret     bool   `db:"secret" `
-	Crownjewel bool   `db:"crownjewel" `
+	Key        string           `db:"key,pk" `
+	Format     null.Val[int16]  `db:"format" `
+	Value      null.Val[[]byte] `db:"value" `
+	Created    int64            `db:"created" `
+	Modified   int64            `db:"modified" `
+	Expires    int64            `db:"expires" `
+	Deleted    int64            `db:"deleted" `
+	Secret     bool             `db:"secret" `
+	Crownjewel bool             `db:"crownjewel" `
 }
 
 // RecordSlice is an alias for a slice of pointers to Record.
@@ -92,8 +94,8 @@ func buildRecordColumns(alias string) recordColumns {
 
 type recordWhere[Q sqlite.Filterable] struct {
 	Key        sqlite.WhereMod[Q, string]
-	Format     sqlite.WhereMod[Q, int16]
-	Value      sqlite.WhereMod[Q, []byte]
+	Format     sqlite.WhereNullMod[Q, int16]
+	Value      sqlite.WhereNullMod[Q, []byte]
 	Created    sqlite.WhereMod[Q, int64]
 	Modified   sqlite.WhereMod[Q, int64]
 	Expires    sqlite.WhereMod[Q, int64]
@@ -109,8 +111,8 @@ func (recordWhere[Q]) AliasedAs(alias string) recordWhere[Q] {
 func buildRecordWhere[Q sqlite.Filterable](cols recordColumns) recordWhere[Q] {
 	return recordWhere[Q]{
 		Key:        sqlite.Where[Q, string](cols.Key),
-		Format:     sqlite.Where[Q, int16](cols.Format),
-		Value:      sqlite.Where[Q, []byte](cols.Value),
+		Format:     sqlite.WhereNull[Q, int16](cols.Format),
+		Value:      sqlite.WhereNull[Q, []byte](cols.Value),
 		Created:    sqlite.Where[Q, int64](cols.Created),
 		Modified:   sqlite.Where[Q, int64](cols.Modified),
 		Expires:    sqlite.Where[Q, int64](cols.Expires),
@@ -124,15 +126,15 @@ func buildRecordWhere[Q sqlite.Filterable](cols recordColumns) recordWhere[Q] {
 // All values are optional, and do not have to be set
 // Generated columns are not included
 type RecordSetter struct {
-	Key        omit.Val[string] `db:"key,pk" `
-	Format     omit.Val[int16]  `db:"format" `
-	Value      omit.Val[[]byte] `db:"value" `
-	Created    omit.Val[int64]  `db:"created" `
-	Modified   omit.Val[int64]  `db:"modified" `
-	Expires    omit.Val[int64]  `db:"expires" `
-	Deleted    omit.Val[int64]  `db:"deleted" `
-	Secret     omit.Val[bool]   `db:"secret" `
-	Crownjewel omit.Val[bool]   `db:"crownjewel" `
+	Key        omit.Val[string]     `db:"key,pk" `
+	Format     omitnull.Val[int16]  `db:"format" `
+	Value      omitnull.Val[[]byte] `db:"value" `
+	Created    omit.Val[int64]      `db:"created" `
+	Modified   omit.Val[int64]      `db:"modified" `
+	Expires    omit.Val[int64]      `db:"expires" `
+	Deleted    omit.Val[int64]      `db:"deleted" `
+	Secret     omit.Val[bool]       `db:"secret" `
+	Crownjewel omit.Val[bool]       `db:"crownjewel" `
 }
 
 func (s RecordSetter) SetColumns() []string {
@@ -181,10 +183,10 @@ func (s RecordSetter) Overwrite(t *Record) {
 		t.Key, _ = s.Key.Get()
 	}
 	if !s.Format.IsUnset() {
-		t.Format, _ = s.Format.Get()
+		t.Format, _ = s.Format.GetNull()
 	}
 	if !s.Value.IsUnset() {
-		t.Value, _ = s.Value.Get()
+		t.Value, _ = s.Value.GetNull()
 	}
 	if !s.Created.IsUnset() {
 		t.Created, _ = s.Created.Get()
diff --git a/base/database/storage/sqlite/sqlite.go b/base/database/storage/sqlite/sqlite.go
index 53823ecd..38f17ed4 100644
--- a/base/database/storage/sqlite/sqlite.go
+++ b/base/database/storage/sqlite/sqlite.go
@@ -10,7 +10,15 @@ import (
 	"time"
 
 	"github.com/aarondl/opt/omit"
+	"github.com/aarondl/opt/omitnull"
 	migrate "github.com/rubenv/sql-migrate"
+	sqldblogger "github.com/simukti/sqldb-logger"
+	"github.com/stephenafamo/bob"
+	"github.com/stephenafamo/bob/dialect/sqlite"
+	"github.com/stephenafamo/bob/dialect/sqlite/im"
+	"github.com/stephenafamo/bob/dialect/sqlite/um"
+	_ "modernc.org/sqlite"
+
 	"github.com/safing/portmaster/base/database/accessor"
 	"github.com/safing/portmaster/base/database/iterator"
 	"github.com/safing/portmaster/base/database/query"
@@ -19,12 +27,6 @@ import (
 	"github.com/safing/portmaster/base/database/storage/sqlite/models"
 	"github.com/safing/portmaster/base/log"
 	"github.com/safing/structures/dsd"
-	"github.com/stephenafamo/bob"
-	"github.com/stephenafamo/bob/dialect/sqlite"
-	"github.com/stephenafamo/bob/dialect/sqlite/im"
-	"github.com/stephenafamo/bob/dialect/sqlite/um"
-
-	_ "modernc.org/sqlite"
 )
 
 // SQLite storage.
@@ -47,14 +49,40 @@ func init() {
 
 // NewSQLite creates a sqlite database.
 func NewSQLite(name, location string) (*SQLite, error) {
+	return openSQLite(name, location, false)
+}
+
+// openSQLite opens the sqlite database; if printStmts is set, all SQL statements are printed.
+func openSQLite(name, location string, printStmts bool) (*SQLite, error) {
 	dbFile := filepath.Join(location, "db.sqlite")
 
 	// Open database file.
+	// Default settings:
+	// _time_format = YYYY-MM-DDTHH:MM:SS.SSS
+	// _txlock = deferred
 	db, err := sql.Open("sqlite", dbFile)
 	if err != nil {
 		return nil, fmt.Errorf("open sqlite: %w", err)
 	}
 
+	// Enable statement printing.
+	if printStmts {
+		db = sqldblogger.OpenDriver(dbFile, db.Driver(), &statementLogger{})
+	}
+
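+	// Set other settings. NOTE(review): Exec runs on one pooled connection; per-connection pragmas may not reach others - confirm.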
+	pragmas := []string{
+		"PRAGMA journal_mode=WAL;",   // Corruption safe write ahead log for txs.
+		"PRAGMA synchronous=NORMAL;", // Best for WAL.
+		"PRAGMA cache_size=-10000;",  // 10MB Cache.
+	}
+	for _, pragma := range pragmas {
+		_, err := db.Exec(pragma)
+		if err != nil {
+			return nil, fmt.Errorf("failed to init sqlite with %s: %w", pragma, err)
+		}
+	}
+
 	// Run migrations on database.
 	n, err := migrate.Exec(db, "sqlite3", getMigrations(), migrate.Up)
 	if err != nil {
@@ -84,7 +112,13 @@ func (db *SQLite) Get(key string) (record.Record, error) {
 	}
 
 	// Return data in wrapper.
-	return record.NewWrapperFromDatabase(db.name, key, getMeta(r), uint8(r.Format), r.Value)
+	return record.NewWrapperFromDatabase(
+		db.name,
+		key,
+		getMeta(r),
+		uint8(r.Format.GetOrZero()),
+		r.Value.GetOrZero(),
+	)
 }
 
 // GetMeta returns the metadata of a database record.
@@ -114,13 +148,20 @@ func (db *SQLite) putRecord(r record.Record, tx *bob.Tx) (record.Record, error)
 	if err != nil {
 		return nil, err
 	}
+	// Prepare for setter.
+	setFormat := omitnull.From(int16(dsd.JSON))
+	setData := omitnull.From(data)
+	if len(data) == 0 {
+		setFormat.Null()
+		setData.Null()
+	}
 
 	// Create structure for insert.
 	m := r.Meta()
 	setter := models.RecordSetter{
 		Key:        omit.From(r.DatabaseKey()),
-		Format:     omit.From(int16(dsd.JSON)),
-		Value:      omit.From(data),
+		Format:     setFormat,
+		Value:      setData,
 		Created:    omit.From(m.Created),
 		Modified:   omit.From(m.Modified),
 		Expires:    omit.From(m.Expires),
@@ -269,7 +310,11 @@ recordsLoop:
 
 		// Check Data.
 		if q.HasWhereCondition() {
-			jsonData := string(r.Value)
+			if r.Format.IsNull() || r.Value.IsNull() {
+				continue recordsLoop
+			}
+
+			jsonData := string(r.Value.GetOrZero())
 			jsonAccess := accessor.NewJSONAccessor(&jsonData)
 			if !q.MatchesAccessor(jsonAccess) {
 				continue recordsLoop
@@ -277,7 +322,13 @@ recordsLoop:
 		}
 
 		// Build database record.
-		matched, _ := record.NewWrapperFromDatabase(db.name, r.Key, m, uint8(r.Format), r.Value)
+		matched, _ := record.NewWrapperFromDatabase(
+			db.name,
+			r.Key,
+			m,
+			uint8(r.Format.GetOrZero()),
+			r.Value.GetOrZero(),
+		)
 
 		select {
 		case <-queryIter.Done:
@@ -301,7 +352,7 @@ recordsLoop:
 
 // Purge deletes all records that match the given query. It returns the number of successful deletes and an error.
 func (db *SQLite) Purge(ctx context.Context, q *query.Query, local, internal, shadowDelete bool) (int, error) {
-	// Optimize for local and internal queries without where clause.
+	// Optimize for local and internal queries without where clause and without shadow delete.
 	if local && internal && !shadowDelete && !q.HasWhereCondition() {
 		db.lock.Lock()
 		defer db.lock.Unlock()
@@ -321,21 +372,49 @@ func (db *SQLite) Purge(ctx context.Context, q *query.Query, local, internal, sh
 		return int(n), err
 	}
 
-	// Otherwise, iterate over all entries and delete matching ones.
+	// Optimize for local and internal queries without where clause, but with shadow delete.
+	if local && internal && shadowDelete && !q.HasWhereCondition() {
+		db.lock.Lock()
+		defer db.lock.Unlock()
 
-	// Create iterator to check all matching records.
-	queryIter := iterator.New()
-	defer queryIter.Cancel()
-	go db.queryExecutor(queryIter, q, local, internal)
+		// First count entries (SQLite does not support affected rows)
+		n, err := models.Records.Query(
+			models.SelectWhere.Records.Key.Like(q.DatabaseKeyPrefix()+"%"),
+		).Count(db.ctx, db.bob)
+		if err != nil || n == 0 {
+			return int(n), err
+		}
 
-	// Delete all matching records.
-	var deleted int
-	for r := range queryIter.Next {
-		db.Delete(r.DatabaseKey())
-		deleted++
+		// Mark purged records as deleted.
+		now := time.Now().Unix()
+		_, err = models.Records.Update(
+			um.SetCol("format").ToArg(nil),
+			um.SetCol("value").ToArg(nil),
+			um.SetCol("deleted").ToArg(now),
+			models.UpdateWhere.Records.Key.Like(q.DatabaseKeyPrefix()+"%"),
+		).Exec(db.ctx, db.bob)
+		return int(n), err
 	}
 
-	return deleted, nil
+	// Otherwise, purging is not implemented yet; see the TODO below.
+	return 0, storage.ErrNotImplemented
+
+	// Create iterator to check all matching records.
+
+	// TODO: This is untested and also needs handling of shadowDelete.
+	// For now: Use only without where condition and with a local and internal db interface.
+	// queryIter := iterator.New()
+	// defer queryIter.Cancel()
+	// go db.queryExecutor(queryIter, q, local, internal)
+
+	// // Delete all matching records.
+	// var deleted int
+	// for r := range queryIter.Next {
+	// 	db.Delete(r.DatabaseKey())
+	// 	deleted++
+	// }
+
+	// return deleted, nil
 }
 
 // ReadOnly returns whether the database is read only.
@@ -360,6 +439,8 @@ func (db *SQLite) MaintainRecordStates(ctx context.Context, purgeDeletedBefore t
 	if shadowDelete {
 		// Mark expired records as deleted.
 		models.Records.Update(
+			um.SetCol("format").ToArg(nil),
+			um.SetCol("value").ToArg(nil),
 			um.SetCol("deleted").ToArg(now),
 			models.UpdateWhere.Records.Deleted.EQ(0),
 			models.UpdateWhere.Records.Expires.GT(0),
@@ -416,3 +497,9 @@ func (db *SQLite) Shutdown() error {
 
 	return db.bob.Close()
 }
+
+type statementLogger struct{}
+
+func (sl statementLogger) Log(ctx context.Context, level sqldblogger.Level, msg string, data map[string]interface{}) {
+	fmt.Printf("SQL: %s --- %+v\n", msg, data)
+}
diff --git a/base/database/storage/sqlite/sqlite_test.go b/base/database/storage/sqlite/sqlite_test.go
index 188b068d..19bb36a0 100644
--- a/base/database/storage/sqlite/sqlite_test.go
+++ b/base/database/storage/sqlite/sqlite_test.go
@@ -7,10 +7,11 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
+
 	"github.com/safing/portmaster/base/database/query"
 	"github.com/safing/portmaster/base/database/record"
 	"github.com/safing/portmaster/base/database/storage"
-	"github.com/stretchr/testify/assert"
 )
 
 var (
@@ -51,7 +52,7 @@ func TestSQLite(t *testing.T) {
 	}()
 
 	// start
-	db, err := NewSQLite("test", testDir)
+	db, err := openSQLite("test", testDir, true)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -105,16 +106,18 @@ func TestSQLite(t *testing.T) {
 	// setup query test records
 	qA := &TestRecord{}
 	qA.SetKey("test:path/to/A")
-	qA.CreateMeta()
+	qA.UpdateMeta()
 	qB := &TestRecord{}
 	qB.SetKey("test:path/to/B")
-	qB.CreateMeta()
+	qB.UpdateMeta()
 	qC := &TestRecord{}
 	qC.SetKey("test:path/to/C")
-	qC.CreateMeta()
+	qC.UpdateMeta()
+	// Set expiry in the past.
+	qC.Meta().Expires = time.Now().Add(-time.Hour).Unix()
 	qZ := &TestRecord{}
 	qZ.SetKey("test:z")
-	qZ.CreateMeta()
+	qZ.UpdateMeta()
 	put, errs := db.PutMany(false)
 	put <- qA
 	put <- qB
@@ -139,7 +142,8 @@ func TestSQLite(t *testing.T) {
 	if it.Err() != nil {
 		t.Fatal(it.Err())
 	}
-	if cnt != 3 {
+	if cnt != 2 {
+		// Note: One is expired.
 		t.Fatalf("unexpected query result count: %d", cnt)
 	}
 
@@ -156,7 +160,7 @@ func TestSQLite(t *testing.T) {
 	}
 
 	// maintenance
-	err = db.MaintainRecordStates(context.TODO(), time.Now(), true)
+	err = db.MaintainRecordStates(context.TODO(), time.Now().Add(-time.Minute), true)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -168,11 +172,11 @@ func TestSQLite(t *testing.T) {
 	}
 
 	// purging
-	n, err := db.Purge(context.TODO(), query.New("test:path/to/").MustBeValid(), true, true, false)
+	n, err := db.Purge(context.TODO(), query.New("test:path/to/").MustBeValid(), true, true, true)
 	if err != nil {
 		t.Fatal(err)
 	}
-	if n != 3 {
+	if n != 2 {
 		t.Fatalf("unexpected purge delete count: %d", n)
 	}
 
diff --git a/go.mod b/go.mod
index bdd7b59a..2ce4d421 100644
--- a/go.mod
+++ b/go.mod
@@ -145,6 +145,7 @@ require (
 	github.com/satori/go.uuid v1.2.0 // indirect
 	github.com/seehuhn/sha256d v1.0.0 // indirect
 	github.com/shopspring/decimal v1.3.1 // indirect
+	github.com/simukti/sqldb-logger v0.0.0-20230108155151-646c1a075551 // indirect
 	github.com/spf13/cast v1.5.0 // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/stephenafamo/scan v0.6.1 // indirect
diff --git a/go.sum b/go.sum
index 528dc251..7b4cce76 100644
--- a/go.sum
+++ b/go.sum
@@ -367,6 +367,8 @@ github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMT
 github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
 github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
 github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
+github.com/simukti/sqldb-logger v0.0.0-20230108155151-646c1a075551 h1:+EXKKt7RC4HyE/iE8zSeFL+7YBL8Z7vpBaEE3c7lCnk=
+github.com/simukti/sqldb-logger v0.0.0-20230108155151-646c1a075551/go.mod h1:ztTX0ctjRZ1wn9OXrzhonvNmv43yjFUXJYJR95JQAJE=
 github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/spf13/afero v0.0.0-20170901052352-ee1bd8ee15a1/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=