safing-portmaster/base/database/storage/sqlite/sqlite_test.go
2025-03-04 15:25:44 +01:00

190 lines
3.1 KiB
Go

package sqlite
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/safing/portmaster/base/database/query"
"github.com/safing/portmaster/base/database/record"
"github.com/safing/portmaster/base/database/storage"
)
var (
// Compile time interface checks.
_ storage.Interface = &SQLite{}
_ storage.Batcher = &SQLite{}
_ storage.Purger = &SQLite{}
)
type TestRecord struct { //nolint:maligned
record.Base
sync.Mutex
S string
I int
I8 int8
I16 int16
I32 int32
I64 int64
UI uint
UI8 uint8
UI16 uint16
UI32 uint32
UI64 uint64
F32 float32
F64 float64
B bool
}
func TestSQLite(t *testing.T) {
t.Parallel()
// start
testDir := t.TempDir()
db, err := openSQLite("test", testDir, true)
if err != nil {
t.Fatal(err)
}
defer func() {
// shutdown
err = db.Shutdown()
if err != nil {
t.Fatal(err)
}
}()
a := &TestRecord{
S: "banana",
I: 42,
I8: 42,
I16: 42,
I32: 42,
I64: 42,
UI: 42,
UI8: 42,
UI16: 42,
UI32: 42,
UI64: 42,
F32: 42.42,
F64: 42.42,
B: true,
}
a.SetMeta(&record.Meta{})
a.Meta().Update()
a.SetKey("test:A")
// put record
_, err = db.Put(a)
if err != nil {
t.Fatal(err)
}
// get and compare
r1, err := db.Get("A")
if err != nil {
t.Fatal(err)
}
a1 := &TestRecord{}
err = record.Unwrap(r1, a1)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, a, a1, "struct must match")
// setup query test records
qA := &TestRecord{}
qA.SetKey("test:path/to/A")
qA.UpdateMeta()
qB := &TestRecord{}
qB.SetKey("test:path/to/B")
qB.UpdateMeta()
qC := &TestRecord{}
qC.SetKey("test:path/to/C")
qC.UpdateMeta()
// Set expiry in the past.
qC.Meta().Expires = time.Now().Add(-time.Hour).Unix()
qZ := &TestRecord{}
qZ.SetKey("test:z")
qZ.UpdateMeta()
put, errs := db.PutMany(false)
put <- qA
put <- qB
put <- qC
put <- qZ
close(put)
err = <-errs
if err != nil {
t.Fatal(err)
}
// test query
q := query.New("test:path/to/").MustBeValid()
it, err := db.Query(q, true, true)
if err != nil {
t.Fatal(err)
}
cnt := 0
for range it.Next {
cnt++
}
if it.Err() != nil {
t.Fatal(it.Err())
}
if cnt != 2 {
// Note: One is expired.
t.Fatalf("unexpected query result count: %d", cnt)
}
// delete
err = db.Delete("A")
if err != nil {
t.Fatal(err)
}
// check if its gone
_, err = db.Get("A")
if err == nil {
t.Fatal("should fail")
}
// maintenance
err = db.MaintainRecordStates(t.Context(), time.Now().Add(-time.Minute), true)
if err != nil {
t.Fatal(err)
}
// maintenance
err = db.MaintainRecordStates(t.Context(), time.Now(), false)
if err != nil {
t.Fatal(err)
}
// purging
n, err := db.Purge(t.Context(), query.New("test:path/to/").MustBeValid(), true, true, true)
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("unexpected purge delete count: %d", n)
}
// test query
q = query.New("test").MustBeValid()
it, err = db.Query(q, true, true)
if err != nil {
t.Fatal(err)
}
cnt = 0
for range it.Next {
cnt++
}
if it.Err() != nil {
t.Fatal(it.Err())
}
if cnt != 1 {
t.Fatalf("unexpected query result count: %d", cnt)
}
}