tinycoder: Changes to openrecall/database.py

This commit is contained in:
Koen van Eijk 2025-04-23 10:38:41 +02:00
parent 400b8c8577
commit 83c7fcf688

View file

@ -1,50 +1,133 @@
import sqlite3
from collections import namedtuple
from typing import Any, List
import numpy as np
from typing import Any, List, Optional, Tuple
from openrecall.config import db_path
# Define the structure of a database entry using namedtuple
Entry = namedtuple("Entry", ["id", "app", "title", "text", "timestamp", "embedding"])
def create_db() -> None:
with sqlite3.connect(db_path) as conn:
c = conn.cursor()
c.execute(
"""CREATE TABLE IF NOT EXISTS entries
(id INTEGER PRIMARY KEY AUTOINCREMENT, app TEXT, title TEXT, text TEXT, timestamp INTEGER, embedding BLOB)"""
)
conn.commit()
"""
Creates the SQLite database and the 'entries' table if they don't exist.
The table schema includes columns for an auto-incrementing ID, application name,
window title, extracted text, timestamp, and text embedding.
"""
try:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS entries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
app TEXT,
title TEXT,
text TEXT,
timestamp INTEGER UNIQUE,
embedding BLOB
)"""
)
# Add index on timestamp for faster lookups
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_timestamp ON entries (timestamp)"
)
conn.commit()
except sqlite3.Error as e:
print(f"Database error during table creation: {e}")
def get_all_entries() -> List[Entry]:
with sqlite3.connect(db_path) as conn:
c = conn.cursor()
results = c.execute("SELECT * FROM entries").fetchall()
return [Entry(*result) for result in results]
"""
Retrieves all entries from the database.
Returns:
List[Entry]: A list of all entries as Entry namedtuples.
Returns an empty list if the table is empty or an error occurs.
"""
entries: List[Entry] = []
try:
with sqlite3.connect(db_path) as conn:
conn.row_factory = sqlite3.Row # Return rows as dictionary-like objects
cursor = conn.cursor()
cursor.execute("SELECT id, app, title, text, timestamp, embedding FROM entries ORDER BY timestamp DESC")
results = cursor.fetchall()
for row in results:
# Deserialize the embedding blob back into a NumPy array
embedding = np.frombuffer(row["embedding"], dtype=np.float32) # Assuming float32, adjust if needed
entries.append(
Entry(
id=row["id"],
app=row["app"],
title=row["title"],
text=row["text"],
timestamp=row["timestamp"],
embedding=embedding,
)
)
except sqlite3.Error as e:
print(f"Database error while fetching all entries: {e}")
return entries
def get_timestamps() -> List[int]:
with sqlite3.connect(db_path) as conn:
c = conn.cursor()
results = c.execute(
"SELECT timestamp FROM entries ORDER BY timestamp DESC"
).fetchall()
return [result[0] for result in results]
"""
Retrieves all timestamps from the database, ordered descending.
Returns:
List[int]: A list of all timestamps.
Returns an empty list if the table is empty or an error occurs.
"""
timestamps: List[int] = []
try:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
# Use the index for potentially faster retrieval
cursor.execute("SELECT timestamp FROM entries ORDER BY timestamp DESC")
results = cursor.fetchall()
timestamps = [result[0] for result in results]
except sqlite3.Error as e:
print(f"Database error while fetching timestamps: {e}")
return timestamps
def insert_entry(
text: str, timestamp: int, embedding: Any, app: str, title: str
) -> None:
embedding_bytes = embedding.tobytes()
try:
text: str, timestamp: int, embedding: np.ndarray, app: str, title: str
) -> Optional[int]:
"""
Inserts a new entry into the database.
Args:
text (str): The extracted text content.
timestamp (int): The Unix timestamp of the screenshot.
embedding (np.ndarray): The embedding vector for the text.
app (str): The name of the active application.
title (str): The title of the active window.
Returns:
Optional[int]: The ID of the newly inserted row, or None if insertion fails.
Prints an error message to stderr on failure.
"""
embedding_bytes: bytes = embedding.astype(np.float32).tobytes() # Ensure consistent dtype
last_row_id: Optional[int] = None
try:
with sqlite3.connect(db_path) as conn:
c = conn.cursor()
c.execute(
"INSERT INTO entries (text, timestamp, embedding, app, title) VALUES (?, ?, ?, ?, ?)",
cursor = conn.cursor()
cursor.execute(
"""INSERT INTO entries (text, timestamp, embedding, app, title)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(timestamp) DO NOTHING""", # Avoid duplicates based on timestamp
(text, timestamp, embedding_bytes, app, title),
)
conn.commit()
except sqlite3.OperationalError as e:
print("Error inserting entry:", e)
if cursor.rowcount > 0: # Check if insert actually happened
last_row_id = cursor.lastrowid
# else:
# Optionally log that a duplicate timestamp was encountered
# print(f"Skipped inserting entry with duplicate timestamp: {timestamp}")
except sqlite3.Error as e:
# More specific error handling can be added (e.g., IntegrityError for UNIQUE constraint)
print(f"Database error during insertion: {e}")
return last_row_id