mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-28 11:30:00 +00:00
- Add batching to generate_embeddings() (50 texts per batch with per-batch retry) to prevent 413 Payload Too Large errors on large documents - Add 413 error classification rule for user-friendly error messages - Fix misleading "Created 0 embedded chunks" log in process_source_command by removing premature get_embedded_chunks() call (embedding is fire-and-forget) Closes #594
680 lines
24 KiB
Python
680 lines
24 KiB
Python
import asyncio
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union
|
|
|
|
from loguru import logger
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
from surreal_commands import submit_command
|
|
from surrealdb import RecordID
|
|
|
|
from open_notebook.database.repository import ensure_record_id, repo_query
|
|
from open_notebook.domain.base import ObjectModel
|
|
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
|
|
|
|
|
|
class Notebook(ObjectModel):
    """A notebook that groups sources, notes and chat sessions.

    Relationships are graph edges in SurrealDB:
      - ``source -> reference -> notebook``
      - ``note -> artifact -> notebook``
      - ``chat_session -> refers_to -> notebook``
    """

    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    # Soft-archive flag; archived notebooks are retained, not deleted.
    archived: Optional[bool] = False

    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        """Reject names that are empty or whitespace-only."""
        if not v.strip():
            raise InvalidInputError("Notebook name cannot be empty")
        return v

    async def get_sources(self) -> List["Source"]:
        """Return sources linked to this notebook, most recently updated first.

        ``full_text`` is omitted from the projection to keep payloads small.

        Raises:
            DatabaseOperationError: if the query fails.
        """
        try:
            srcs = await repo_query(
                """
                select * omit source.full_text from (
                    select in as source from reference where out=$id
                    fetch source
                ) order by source.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Source(**src["source"]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def get_notes(self) -> List["Note"]:
        """Return notes linked to this notebook, most recently updated first.

        ``content`` and ``embedding`` are omitted to keep payloads small.

        Raises:
            DatabaseOperationError: if the query fails.
        """
        try:
            srcs = await repo_query(
                """
                select * omit note.content, note.embedding from (
                    select in as note from artifact where out=$id
                    fetch note
                ) order by note.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Note(**src["note"]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def get_chat_sessions(self) -> List["ChatSession"]:
        """Return chat sessions linked to this notebook, newest first.

        The reverse traversal ``<- chat_session`` yields an array per row,
        hence the ``[0]`` when unwrapping each result below.

        Raises:
            DatabaseOperationError: if the query fails.
        """
        try:
            srcs = await repo_query(
                """
                select * from (
                    select
                    <- chat_session as chat_session
                    from refers_to
                    where out=$id
                    fetch chat_session
                )
                order by chat_session.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return (
                [ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
            )
        except Exception as e:
            logger.error(
                f"Error fetching chat sessions for notebook {self.id}: {str(e)}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def get_delete_preview(self) -> Dict[str, Any]:
        """
        Get counts of items that would be affected by deleting this notebook.

        Returns a dict with:
        - note_count: Number of notes that will be deleted
        - exclusive_source_count: Sources only in this notebook (can be deleted)
        - shared_source_count: Sources in other notebooks (will be unlinked only)
        """
        try:
            notebook_id = ensure_record_id(self.id)

            # Count notes
            note_result = await repo_query(
                "SELECT count() as count FROM artifact WHERE out = $notebook_id GROUP ALL",
                {"notebook_id": notebook_id},
            )
            note_count = note_result[0]["count"] if note_result else 0

            # Get sources with count of references to OTHER notebooks
            # If assigned_others = 0, source is exclusive to this notebook
            # If assigned_others > 0, source is shared with other notebooks
            source_counts = await repo_query(
                """
                SELECT
                    id,
                    count(->reference[WHERE out != $notebook_id].out) as assigned_others
                FROM (SELECT VALUE <-reference.in AS sources FROM $notebook_id)[0]
                """,
                {"notebook_id": notebook_id},
            )

            exclusive_count = 0
            shared_count = 0
            for src in source_counts:
                if src.get("assigned_others", 0) == 0:
                    exclusive_count += 1
                else:
                    shared_count += 1

            return {
                "note_count": note_count,
                "exclusive_source_count": exclusive_count,
                "shared_source_count": shared_count,
            }
        except Exception as e:
            logger.error(f"Error getting delete preview for notebook {self.id}: {e}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def delete(self, delete_exclusive_sources: bool = False) -> Dict[str, int]:
        """
        Delete notebook with cascade deletion of notes and optional source deletion.

        Order matters here: notes and their artifact edges go first, then
        sources are either deleted (exclusive) or unlinked, then the
        reference edges are removed, and only then the notebook record
        itself.

        Args:
            delete_exclusive_sources: If True, also delete sources that belong
                only to this notebook. Default is False.

        Returns:
            Dict with counts: deleted_notes, deleted_sources, unlinked_sources

        Raises:
            InvalidInputError: if the notebook has no ID.
            DatabaseOperationError: if any database step fails.
        """
        if self.id is None:
            raise InvalidInputError("Cannot delete notebook without an ID")

        try:
            notebook_id = ensure_record_id(self.id)
            deleted_notes = 0
            deleted_sources = 0
            unlinked_sources = 0

            # 1. Get and delete all notes linked to this notebook
            notes = await self.get_notes()
            for note in notes:
                await note.delete()
                deleted_notes += 1
            logger.info(f"Deleted {deleted_notes} notes for notebook {self.id}")

            # Delete artifact relationships
            await repo_query(
                "DELETE artifact WHERE out = $notebook_id",
                {"notebook_id": notebook_id},
            )

            # 2. Handle sources
            if delete_exclusive_sources:
                # Find sources with count of references to OTHER notebooks
                # If assigned_others = 0, source is exclusive to this notebook
                source_counts = await repo_query(
                    """
                    SELECT
                        id,
                        count(->reference[WHERE out != $notebook_id].out) as assigned_others
                    FROM (SELECT VALUE <-reference.in AS sources FROM $notebook_id)[0]
                    """,
                    {"notebook_id": notebook_id},
                )

                for src in source_counts:
                    source_id = src.get("id")
                    if source_id and src.get("assigned_others", 0) == 0:
                        # Exclusive source - delete it
                        try:
                            source = await Source.get(str(source_id))
                            await source.delete()
                            deleted_sources += 1
                        except Exception as e:
                            # Best-effort: a failed source delete must not
                            # abort the notebook deletion.
                            logger.warning(
                                f"Failed to delete exclusive source {source_id}: {e}"
                            )
                    else:
                        unlinked_sources += 1
            else:
                # Just count sources that will be unlinked
                source_result = await repo_query(
                    "SELECT count() as count FROM reference WHERE out = $notebook_id GROUP ALL",
                    {"notebook_id": notebook_id},
                )
                unlinked_sources = source_result[0]["count"] if source_result else 0

            # Delete reference relationships (unlink all sources)
            await repo_query(
                "DELETE reference WHERE out = $notebook_id",
                {"notebook_id": notebook_id},
            )
            logger.info(
                f"Unlinked {unlinked_sources} sources, deleted {deleted_sources} "
                f"exclusive sources for notebook {self.id}"
            )

            # 3. Delete the notebook record itself
            await super().delete()
            logger.info(f"Deleted notebook {self.id}")

            return {
                "deleted_notes": deleted_notes,
                "deleted_sources": deleted_sources,
                "unlinked_sources": unlinked_sources,
            }

        except Exception as e:
            logger.error(f"Error deleting notebook {self.id}: {e}")
            logger.exception(e)
            raise DatabaseOperationError(f"Failed to delete notebook: {e}")
|
|
|
|
|
|
class Asset(BaseModel):
    """Where a source's raw content came from: a local file, a URL, or neither."""

    # Path of an uploaded file; cleaned up by Source.delete() when present.
    file_path: Optional[str] = None
    # Remote origin of the content, if it was fetched from the web.
    url: Optional[str] = None
|
|
|
|
|
|
class SourceEmbedding(ObjectModel):
    """A single embedded chunk of a Source's text."""

    table_name: ClassVar[str] = "source_embedding"
    # The chunk text that was embedded.
    content: str

    async def get_source(self) -> "Source":
        """Fetch the parent Source record this embedding belongs to.

        Raises:
            DatabaseOperationError: if the lookup fails.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            parent = rows[0]["source"]
            return Source(**parent)
        except Exception as e:
            logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)
|
|
|
|
|
|
class SourceInsight(ObjectModel):
    """An AI-generated insight derived from a Source."""

    table_name: ClassVar[str] = "source_insight"
    # Category of the insight (e.g. summary, key points).
    insight_type: str
    # The insight body text.
    content: str

    async def get_source(self) -> "Source":
        """Fetch the parent Source record this insight belongs to.

        Raises:
            DatabaseOperationError: if the lookup fails.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            parent = rows[0]["source"]
            return Source(**parent)
        except Exception as e:
            logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def save_as_note(self, notebook_id: Optional[str] = None) -> Any:
        """Persist this insight as a standalone Note.

        Args:
            notebook_id: when given, the new note is also linked to that
                notebook via an artifact edge.

        Returns:
            The newly created Note.
        """
        source = await self.get_source()
        new_note = Note(
            title=f"{self.insight_type} from source {source.title}",
            content=self.content,
        )
        await new_note.save()
        if notebook_id:
            await new_note.add_to_notebook(notebook_id)
        return new_note
|
|
|
|
|
|
class Source(ObjectModel):
    """A content source (file, URL, or pasted text) attached to notebooks.

    Heavy work (embedding, insight creation) is offloaded to background
    jobs via surreal-commands; ``command`` links to the processing job.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    table_name: ClassVar[str] = "source"
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    full_text: Optional[str] = None
    command: Optional[Union[str, RecordID]] = Field(
        default=None, description="Link to surreal-commands processing job"
    )

    @field_validator("command", mode="before")
    @classmethod
    def parse_command(cls, value):
        """Parse command field to ensure RecordID format"""
        if isinstance(value, str) and value:
            return ensure_record_id(value)
        return value

    @field_validator("id", mode="before")
    @classmethod
    def parse_id(cls, value):
        """Parse id field to handle both string and RecordID inputs"""
        if value is None:
            return None
        if isinstance(value, RecordID):
            return str(value)
        return str(value) if value else None

    async def get_status(self) -> Optional[str]:
        """Get the processing status of the associated command.

        Returns None when no command is linked; "unknown" when the status
        lookup fails or returns nothing (best-effort, never raises).
        """
        if not self.command:
            return None

        try:
            from surreal_commands import get_command_status

            status = await get_command_status(str(self.command))
            return status.status if status else "unknown"
        except Exception as e:
            logger.warning(f"Failed to get command status for {self.command}: {e}")
            return "unknown"

    async def get_processing_progress(self) -> Optional[Dict[str, Any]]:
        """Get detailed processing information for the associated command.

        Returns a dict with status, timestamps, error and raw result, or
        None when no command is linked or the lookup fails (best-effort).
        """
        if not self.command:
            return None

        try:
            from surreal_commands import get_command_status

            status_result = await get_command_status(str(self.command))
            if not status_result:
                return None

            # Extract execution metadata if available
            result = getattr(status_result, "result", None)
            execution_metadata = (
                result.get("execution_metadata", {}) if isinstance(result, dict) else {}
            )

            return {
                "status": status_result.status,
                "started_at": execution_metadata.get("started_at"),
                "completed_at": execution_metadata.get("completed_at"),
                "error": getattr(status_result, "error_message", None),
                "result": result,
            }
        except Exception as e:
            logger.warning(f"Failed to get command progress for {self.command}: {e}")
            return None

    async def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a context dict for LLM consumption.

        "long" includes the full text; "short" includes only id, title and
        insights.
        """
        insights_list = await self.get_insights()
        insights = [insight.model_dump() for insight in insights_list]
        if context_size == "long":
            return dict(
                id=self.id,
                title=self.title,
                insights=insights,
                full_text=self.full_text,
            )
        else:
            return dict(id=self.id, title=self.title, insights=insights)

    async def get_embedded_chunks(self) -> int:
        """Count source_embedding records for this source.

        Raises:
            DatabaseOperationError: if the count query fails.
        """
        try:
            result = await repo_query(
                """
                select count() as chunks from source_embedding where source=$id GROUP ALL
                """,
                {"id": ensure_record_id(self.id)},
            )
            if len(result) == 0:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error fetching chunks count for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")

    async def get_insights(self) -> List[SourceInsight]:
        """Fetch all insights generated for this source.

        Raises:
            DatabaseOperationError: if the query fails.
        """
        try:
            result = await repo_query(
                """
                SELECT * FROM source_insight WHERE source=$id
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [SourceInsight(**insight) for insight in result]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source")

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this source to a notebook via a reference edge."""
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("reference", notebook_id)

    async def vectorize(self) -> str:
        """
        Submit vectorization as a background job using the embed_source command.

        This method leverages the job-based architecture to prevent HTTP connection
        pool exhaustion when processing large documents. The embed_source command:
        1. Detects content type from file path
        2. Chunks text using content-type aware splitter
        3. Generates all embeddings in batches
        4. Bulk inserts source_embedding records

        Returns:
            str: The command/job ID that can be used to track progress via the commands API

        Raises:
            ValueError: If source has no text to vectorize
            DatabaseOperationError: If job submission fails
        """
        logger.info(f"Submitting embed_source job for source {self.id}")

        try:
            if not self.full_text or not self.full_text.strip():
                raise ValueError(f"Source {self.id} has no text to vectorize")

            # Submit the embed_source command
            command_id = submit_command(
                "open_notebook",
                "embed_source",
                {"source_id": str(self.id)},
            )

            command_id_str = str(command_id)
            logger.info(
                f"Embed source job submitted for source {self.id}: "
                f"command_id={command_id_str}"
            )

            return command_id_str

        except ValueError:
            # Re-raise validation errors untouched so callers can distinguish
            # "nothing to embed" from infrastructure failures.
            raise
        except Exception as e:
            logger.error(
                f"Failed to submit embed_source job for source {self.id}: {e}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def add_insight(self, insight_type: str, content: str) -> Optional[str]:
        """
        Submit insight creation as an async command (fire-and-forget).

        Submits a create_insight command that handles database operations with
        automatic retry logic for transaction conflicts. The command also submits
        an embed_insight command for async embedding.

        This method returns immediately after submitting the command - it does NOT
        wait for the insight to be created. Use this for batch operations where
        throughput is more important than immediate confirmation.

        Args:
            insight_type: Type/category of the insight
            content: The insight content text

        Returns:
            command_id for optional tracking, or None if submission failed

        Raises:
            InvalidInputError: If insight_type or content is empty
        """
        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")

        try:
            # Submit create_insight command (fire-and-forget)
            # Command handles retries internally for transaction conflicts
            command_id = submit_command(
                "open_notebook",
                "create_insight",
                {
                    "source_id": str(self.id),
                    "insight_type": insight_type,
                    "content": content,
                },
            )
            logger.info(
                f"Submitted create_insight command {command_id} for source {self.id} "
                f"(type={insight_type})"
            )
            return str(command_id)

        except Exception as e:
            # Fire-and-forget: submission failures are reported via None,
            # not raised.
            logger.error(f"Error submitting create_insight for source {self.id}: {e}")
            return None

    def _prepare_save_data(self) -> dict:
        """Override to ensure command field is always RecordID format for database"""
        data = super()._prepare_save_data()

        # Ensure command field is RecordID format if not None
        if data.get("command") is not None:
            data["command"] = ensure_record_id(data["command"])

        return data

    async def delete(self) -> bool:
        """Delete source and clean up associated file, embeddings, and insights."""
        # Clean up uploaded file if it exists
        if self.asset and self.asset.file_path:
            file_path = Path(self.asset.file_path)
            if file_path.exists():
                try:
                    os.unlink(file_path)
                    logger.info(f"Deleted file for source {self.id}: {file_path}")
                except Exception as e:
                    # File cleanup is best-effort; the DB record must still go.
                    logger.warning(
                        f"Failed to delete file {file_path} for source {self.id}: {e}. "
                        "Continuing with database deletion."
                    )
            else:
                logger.debug(
                    f"File {file_path} not found for source {self.id}, skipping cleanup"
                )

        # Delete associated embeddings and insights to prevent orphaned records
        try:
            source_id = ensure_record_id(self.id)
            await repo_query(
                "DELETE source_embedding WHERE source = $source_id",
                {"source_id": source_id},
            )
            await repo_query(
                "DELETE source_insight WHERE source = $source_id",
                {"source_id": source_id},
            )
            logger.debug(f"Deleted embeddings and insights for source {self.id}")
        except Exception as e:
            logger.warning(
                f"Failed to delete embeddings/insights for source {self.id}: {e}. "
                "Continuing with source deletion."
            )

        # Call parent delete to remove database record
        return await super().delete()
|
|
|
|
|
|
class Note(ObjectModel):
    """A human- or AI-authored note, embedded asynchronously on save."""

    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    # Provenance of the note: written by a person or generated by a model.
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        """Allow None, but reject whitespace-only content."""
        if v is not None and not v.strip():
            raise InvalidInputError("Note content cannot be empty")
        return v

    async def save(self) -> Optional[str]:
        """
        Save the note and submit embedding command.

        Overrides ObjectModel.save() to submit an async embed_note command
        after saving, instead of inline embedding.

        Returns:
            Optional[str]: The command_id if embedding was submitted, None otherwise
        """
        # Call parent save (without embedding)
        await super().save()

        # Submit embedding command (fire-and-forget) if note has content
        if self.id and self.content and self.content.strip():
            command_id = submit_command(
                "open_notebook",
                "embed_note",
                {"note_id": str(self.id)},
            )
            logger.debug(f"Submitted embed_note command {command_id} for {self.id}")
            # Fix: coerce to str so the return matches the Optional[str]
            # annotation and the convention used by Source.vectorize() /
            # Source.add_insight().
            return str(command_id)

        return None

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this note to a notebook via an artifact edge."""
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("artifact", notebook_id)

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a context dict; "short" truncates content to 100 chars."""
        if context_size == "long":
            return dict(id=self.id, title=self.title, content=self.content)
        else:
            return dict(
                id=self.id,
                title=self.title,
                content=self.content[:100] if self.content else None,
            )
|
|
|
|
|
|
class ChatSession(ObjectModel):
    """A chat session, linked to a notebook and optionally to sources."""

    table_name: ClassVar[str] = "chat_session"
    # Fields that may be explicitly persisted as NULL rather than omitted.
    nullable_fields: ClassVar[set[str]] = {"model_override"}
    title: Optional[str] = None
    # Optional per-session model choice overriding the global default.
    model_override: Optional[str] = None

    async def relate_to_notebook(self, notebook_id: str) -> Any:
        """Create a refers_to edge from this session to a notebook."""
        if notebook_id:
            return await self.relate("refers_to", notebook_id)
        raise InvalidInputError("Notebook ID must be provided")

    async def relate_to_source(self, source_id: str) -> Any:
        """Create a refers_to edge from this session to a source."""
        if source_id:
            return await self.relate("refers_to", source_id)
        raise InvalidInputError("Source ID must be provided")
|
|
|
|
|
|
async def text_search(
    keyword: str, results: int, source: bool = True, note: bool = True
):
    """Full-text search over sources and/or notes via fn::text_search.

    Args:
        keyword: the search term; must be non-empty.
        results: maximum number of hits to return.
        source: include sources in the search.
        note: include notes in the search.

    Raises:
        InvalidInputError: if keyword is empty.
        DatabaseOperationError: if the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    params = {
        "keyword": keyword,
        "results": results,
        "source": source,
        "note": note,
    }
    try:
        hits = await repo_query(
            """
            select *
            from fn::text_search($keyword, $results, $source, $note)
            """,
            params,
        )
        return hits
    except Exception as e:
        logger.error(f"Error performing text search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)
|
|
|
|
|
|
async def vector_search(
    keyword: str,
    results: int,
    source: bool = True,
    note: bool = True,
    minimum_score=0.2,
):
    """Semantic search: embed the keyword, then query fn::vector_search.

    Args:
        keyword: the search term; must be non-empty.
        results: maximum number of hits to return.
        source: include sources in the search.
        note: include notes in the search.
        minimum_score: similarity cutoff below which hits are discarded.

    Raises:
        InvalidInputError: if keyword is empty.
        DatabaseOperationError: if embedding or the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        from open_notebook.utils.embedding import generate_embedding

        # Use unified embedding function (handles chunking if query is very long)
        query_vector = await generate_embedding(keyword)
        hits = await repo_query(
            """
            SELECT * FROM fn::vector_search($embed, $results, $source, $note, $minimum_score);
            """,
            {
                "embed": query_vector,
                "results": results,
                "source": source,
                "note": note,
                "minimum_score": minimum_score,
            },
        )
        return hits
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)
|