# Mirror of https://github.com/openrecall/openrecall.git
# Synced 2026-04-28 11:29:50 +00:00
# 189 lines, no EOL, 7.1 KiB, Python
import os
import sqlite3
import sys
import tempfile
import time
import unittest
from unittest.mock import patch

import numpy as np

# Put the parent of this file's directory at the front of sys.path so that
# the `openrecall` package can be imported from here.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Reserve a temporary file to act as the database. Leaving the `with` block
# closes the handle; the file itself survives because delete=False.
with tempfile.NamedTemporaryFile(delete=False) as temp_db_file:
    mock_db_path = temp_db_file.name

# openrecall.config.db_path must already point at the temporary file while
# openrecall.database is being imported, hence the import inside the patch.
with patch('openrecall.config.db_path', mock_db_path):
    import openrecall.database
    from openrecall.database import (
        create_db,
        insert_entry,
        get_all_entries,
        get_timestamps,
        Entry,
    )

    # The database module may hold its own copy of db_path (if it imported
    # the name directly), so point that copy at the temporary file as well.
    openrecall.database.db_path = mock_db_path
|
|
|
|
|
|
class TestDatabase(unittest.TestCase):
    """Tests for the openrecall.database helpers against a temporary SQLite file.

    The database path is redirected to ``mock_db_path`` at module level, so the
    helpers under test (``create_db``, ``insert_entry``, ``get_all_entries``,
    ``get_timestamps``) and the raw ``sqlite3`` connection opened in ``setUp``
    all operate on the same throwaway file.
    """

    @classmethod
    def setUpClass(cls):
        """Record the temporary database path and create the schema once."""
        # The database path is already patched by the module-level setup.
        cls.db_path = mock_db_path
        create_db()

    @classmethod
    def tearDownClass(cls):
        """Delete the temporary database file and undo the sys.path change."""
        # Per-test connections live on the instance and are closed after each
        # test; this only covers a connection stored on the class itself.
        conn = getattr(cls, 'conn', None)
        if conn is not None:
            try:
                conn.close()
            except sqlite3.Error:
                pass  # Best-effort: a failed close must not block file removal.

        try:
            os.remove(cls.db_path)
        except FileNotFoundError:
            pass  # Already gone; nothing left to clean up.

        # Remove exactly the entry the module inserted rather than whatever
        # happens to sit at index 0 by now (other code may have prepended).
        project_root = os.path.abspath(
            os.path.join(os.path.dirname(__file__), '..')
        )
        try:
            sys.path.remove(project_root)
        except ValueError:
            pass  # Entry was already removed.

    def setUp(self):
        """Open a connection to the test database and clear all entries."""
        self.conn = sqlite3.connect(self.db_path)
        # Cleanups still run when the rest of setUp raises, whereas tearDown
        # does not, so the connection cannot leak if the DELETE below fails.
        self.addCleanup(self.conn.close)
        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM entries")
        self.conn.commit()

    def tearDown(self):
        """Close the database connection after each test."""
        if self.conn:
            self.conn.close()

    def test_create_db(self):
        """create_db should have created the entries table and its timestamp index."""
        cursor = self.conn.cursor()

        # Check that the table exists.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='entries'")
        result = cursor.fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(result[0], 'entries')

        # Check that the index exists.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='index' AND name='idx_timestamp'")
        result = cursor.fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(result[0], 'idx_timestamp')

    def test_02_insert_entry(self):
        """insert_entry should return an integer id and persist every field."""
        ts = int(time.time())
        embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32)
        inserted_id = insert_entry("Test text", ts, embedding, "TestApp", "TestTitle")

        self.assertIsNotNone(inserted_id)
        self.assertIsInstance(inserted_id, int)

        # Read the row back through the raw connection to verify what was stored.
        cursor = self.conn.cursor()
        cursor.execute("SELECT * FROM entries WHERE id = ?", (inserted_id,))
        result = cursor.fetchone()
        self.assertIsNotNone(result)

        # Column order expected here: (id, app, title, text, timestamp, embedding_blob)
        self.assertEqual(result[1], "TestApp")
        self.assertEqual(result[2], "TestTitle")
        self.assertEqual(result[3], "Test text")
        self.assertEqual(result[4], ts)
        # The embedding column holds raw bytes; decode them as float32 to compare.
        retrieved_embedding = np.frombuffer(result[5], dtype=np.float32)
        np.testing.assert_array_almost_equal(retrieved_embedding, embedding)

    def test_insert_duplicate_timestamp(self):
        """A second insert with an existing timestamp should be ignored and return None."""
        ts = int(time.time())
        embedding1 = np.array([0.1, 0.2, 0.3], dtype=np.float32)
        embedding2 = np.array([0.4, 0.5, 0.6], dtype=np.float32)

        id1 = insert_entry("First text", ts, embedding1, "App1", "Title1")
        self.assertIsNotNone(id1)

        # Try inserting another entry with the same timestamp.
        id2 = insert_entry("Second text", ts, embedding2, "App2", "Title2")
        self.assertIsNone(id2, "Inserting duplicate timestamp should return None")

        # Only one row may exist for that timestamp...
        cursor = self.conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM entries WHERE timestamp = ?", (ts,))
        count = cursor.fetchone()[0]
        self.assertEqual(count, 1)

        # ...and it must be the first one that was inserted.
        cursor.execute("SELECT text FROM entries WHERE timestamp = ?", (ts,))
        text = cursor.fetchone()[0]
        self.assertEqual(text, "First text")

    def test_get_all_entries_empty(self):
        """get_all_entries should return an empty list when the table is empty."""
        entries = get_all_entries()
        self.assertEqual(entries, [])

    def test_get_all_entries_multiple(self):
        """get_all_entries should return every entry, newest timestamp first."""
        ts1 = int(time.time())
        ts2 = ts1 + 10
        ts3 = ts1 - 10  # Inserted last but oldest, so ordering is really exercised.
        emb1 = np.array([0.1] * 5, dtype=np.float32)
        emb2 = np.array([0.2] * 5, dtype=np.float32)
        emb3 = np.array([0.3] * 5, dtype=np.float32)

        insert_entry("Text 1", ts1, emb1, "App1", "Title1")
        insert_entry("Text 2", ts2, emb2, "App2", "Title2")
        insert_entry("Text 3", ts3, emb3, "App3", "Title3")

        entries = get_all_entries()
        self.assertEqual(len(entries), 3)

        # Entries are expected in descending timestamp order: ts2, ts1, ts3.
        self.assertEqual(entries[0].timestamp, ts2)
        self.assertEqual(entries[0].text, "Text 2")
        self.assertEqual(entries[0].app, "App2")
        self.assertEqual(entries[0].title, "Title2")
        np.testing.assert_array_almost_equal(entries[0].embedding, emb2)
        self.assertIsInstance(entries[0].id, int)

        self.assertEqual(entries[1].timestamp, ts1)
        self.assertEqual(entries[1].text, "Text 1")
        np.testing.assert_array_almost_equal(entries[1].embedding, emb1)

        self.assertEqual(entries[2].timestamp, ts3)
        self.assertEqual(entries[2].text, "Text 3")
        np.testing.assert_array_almost_equal(entries[2].embedding, emb3)

    def test_get_timestamps_empty(self):
        """get_timestamps should return an empty list when the table is empty."""
        timestamps = get_timestamps()
        self.assertEqual(timestamps, [])

    def test_get_timestamps_multiple(self):
        """get_timestamps should return all timestamps in descending order."""
        ts1 = int(time.time())
        ts2 = ts1 + 10
        ts3 = ts1 - 10
        emb = np.array([0.1] * 5, dtype=np.float32)  # Embedding content doesn't matter here

        insert_entry("T1", ts1, emb, "A1", "T1")
        insert_entry("T2", ts2, emb, "A2", "T2")
        insert_entry("T3", ts3, emb, "A3", "T3")

        timestamps = get_timestamps()
        self.assertEqual(len(timestamps), 3)
        # Timestamps are expected newest first.
        self.assertEqual(timestamps, [ts2, ts1, ts3])
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()