From b68646c68919489a893b376e492a60f28527d4e2 Mon Sep 17 00:00:00 2001
From: Daniel <dhaavi@users.noreply.github.com>
Date: Wed, 26 Feb 2025 16:52:34 +0100
Subject: [PATCH] Use transaction for PutMany and cursor Query

---
 base/database/storage/sqlite/sqlite.go      | 74 ++++++++++++++++-----
 base/database/storage/sqlite/sqlite_test.go | 18 ++---
 service/core/base/databases.go              |  1 +
 3 files changed, 66 insertions(+), 27 deletions(-)

diff --git a/base/database/storage/sqlite/sqlite.go b/base/database/storage/sqlite/sqlite.go
index ba9b11c7..53823ecd 100644
--- a/base/database/storage/sqlite/sqlite.go
+++ b/base/database/storage/sqlite/sqlite.go
@@ -99,8 +99,15 @@ func (db *SQLite) GetMeta(key string) (*record.Meta, error) {
 
 // Put stores a record in the database.
 func (db *SQLite) Put(r record.Record) (record.Record, error) {
-	r.Lock()
-	defer r.Unlock()
+	return db.putRecord(r, nil)
+}
+
+func (db *SQLite) putRecord(r record.Record, tx *bob.Tx) (record.Record, error) {
+	// Lock record if in a transaction.
+	if tx != nil {
+		r.Lock()
+		defer r.Unlock()
+	}
 
 	// Serialize to JSON.
 	data, err := r.MarshalDataOnly(r, dsd.JSON)
@@ -127,12 +134,19 @@ func (db *SQLite) Put(r record.Record) (record.Record, error) {
 	defer db.lock.Unlock()
 
 	// Simulate upsert with custom selection on conflict.
-	_, err = models.Records.Insert(
+	dbQuery := models.Records.Insert(
 		&setter,
 		im.OnConflict("key").DoUpdate(
 			im.SetExcluded("format", "value", "created", "modified", "expires", "deleted", "secret", "crownjewel"),
 		),
-	).Exec(db.ctx, db.bob)
+	)
+
+	// Execute in transaction or directly.
+	if tx != nil {
+		_, err = dbQuery.Exec(db.ctx, tx)
+	} else {
+		_, err = dbQuery.Exec(db.ctx, db.bob)
+	}
 	if err != nil {
 		return nil, err
 	}
@@ -150,16 +164,39 @@ func (db *SQLite) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error
 	batch := make(chan record.Record, 100)
 	errs := make(chan error, 1)
 
+	tx, err := db.bob.BeginTx(db.ctx, nil)
+	if err != nil {
+		errs <- err
+		return batch, errs
+	}
+
 	// start handler
 	go func() {
-		for r := range batch {
-			_, err := db.Put(r)
-			if err != nil {
-				errs <- err
-				return
+		// Read all put records.
+	writeBatch:
+		for {
+			select {
+			case r := <-batch:
+				if r != nil {
+					// Write record.
+					_, err := db.putRecord(r, &tx)
+					if err != nil {
+						errs <- err
+						break writeBatch
+					}
+				} else {
+					// Finalize transcation.
+					errs <- tx.Commit()
+					return
+				}
+
+			case <-db.ctx.Done():
+				break writeBatch
 			}
 		}
-		errs <- nil
+
+		// Rollback transaction.
+		errs <- tx.Rollback()
 	}()
 
 	return batch, errs
@@ -199,20 +236,25 @@ func (db *SQLite) queryExecutor(queryIter *iterator.Iterator, q *query.Query, lo
 		recordQuery = models.Records.View.Query()
 	}
 
-	// Get all records from query.
-	// TODO: This will load all records into memory. While this is efficient and
-	// will not block others from using the datbase, this might be quite a strain
-	// on the system memory. Monitor and see if this is an issue.
+	// Get cursor to go over all records in the query.
 	db.lock.RLock()
-	records, err := models.RecordsQuery.All(recordQuery, db.ctx, db.bob)
+	cursor, err := models.RecordsQuery.Cursor(recordQuery, db.ctx, db.bob)
 	db.lock.RUnlock()
 	if err != nil {
 		queryIter.Finish(err)
 		return
 	}
+	defer cursor.Close()
 
 recordsLoop:
-	for _, r := range records {
+	for cursor.Next() {
+		// Get next record
+		r, cErr := cursor.Get()
+		if cErr != nil {
+			err = fmt.Errorf("cursor error: %w", cErr)
+			break recordsLoop
+		}
+
 		// Check if key matches.
 		if !q.MatchesKey(r.Key) {
 			continue recordsLoop
diff --git a/base/database/storage/sqlite/sqlite_test.go b/base/database/storage/sqlite/sqlite_test.go
index 0fc1a4a1..188b068d 100644
--- a/base/database/storage/sqlite/sqlite_test.go
+++ b/base/database/storage/sqlite/sqlite_test.go
@@ -115,17 +115,13 @@ func TestSQLite(t *testing.T) {
 	qZ := &TestRecord{}
 	qZ.SetKey("test:z")
 	qZ.CreateMeta()
-	// put
-	_, err = db.Put(qA)
-	if err == nil {
-		_, err = db.Put(qB)
-	}
-	if err == nil {
-		_, err = db.Put(qC)
-	}
-	if err == nil {
-		_, err = db.Put(qZ)
-	}
+	put, errs := db.PutMany(false)
+	put <- qA
+	put <- qB
+	put <- qC
+	put <- qZ
+	close(put)
+	err = <-errs
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/service/core/base/databases.go b/service/core/base/databases.go
index 68e7819b..c0f0f6b0 100644
--- a/service/core/base/databases.go
+++ b/service/core/base/databases.go
@@ -5,6 +5,7 @@ import (
 
 	"github.com/safing/portmaster/base/database"
 	_ "github.com/safing/portmaster/base/database/storage/bbolt"
+	_ "github.com/safing/portmaster/base/database/storage/sqlite"
 	"github.com/safing/portmaster/base/dataroot"
 	"github.com/safing/portmaster/base/utils"
 )