"""Domain models for notebooks, sources, notes and chat sessions, plus search helpers."""

import asyncio
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union

from loguru import logger
from pydantic import BaseModel, Field, field_validator
from surrealdb import RecordID

from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.base import ObjectModel
from open_notebook.domain.models import model_manager
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
from open_notebook.utils import split_text


class Notebook(ObjectModel):
    """Record of the ``notebook`` table with accessors for related records."""

    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    archived: Optional[bool] = False

    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        """Reject names that are empty or whitespace-only.

        Raises:
            InvalidInputError: if ``v.strip()`` is empty.
        """
        if not v.strip():
            raise InvalidInputError("Notebook name cannot be empty")
        return v

    async def get_sources(self) -> List["Source"]:
        """Return sources linked through ``reference`` edges, newest update first.

        The query omits ``full_text`` from the fetched sources.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            rows = await repo_query(
                """
                select * omit source.full_text from (
                    select in as source from reference where out=$id
                    fetch source
                ) order by source.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Source(**row["source"]) for row in rows] if rows else []
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def get_notes(self) -> List["Note"]:
        """Return notes linked through ``artifact`` edges, newest update first.

        The query omits ``content`` and ``embedding`` from the fetched notes.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            rows = await repo_query(
                """
                select * omit note.content, note.embedding from (
                    select in as note from artifact where out=$id
                    fetch note
                ) order by note.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Note(**row["note"]) for row in rows] if rows else []
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def get_chat_sessions(self) -> List["ChatSession"]:
        """Return chat sessions linked through ``refers_to`` edges.

        Each row carries ``chat_session`` as a list (the query result is
        indexed with ``[0]``); rows whose list is empty are skipped so a single
        edge without a session cannot fail the whole listing.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            rows = await repo_query(
                """
                select * from (
                    select
                    <- chat_session as chat_session
                    from refers_to
                    where out=$id
                    fetch chat_session
                )
                order by chat_session.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            if not rows:
                return []
            return [
                ChatSession(**row["chat_session"][0])
                for row in rows
                if row["chat_session"]
            ]
        except Exception as e:
            logger.error(
                f"Error fetching chat sessions for notebook {self.id}: {str(e)}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e) from e


class Asset(BaseModel):
    """Optional file path and/or URL attached to a source."""

    file_path: Optional[str] = None
    url: Optional[str] = None


class SourceEmbedding(ObjectModel):
    """Record of the ``source_embedding`` table."""

    table_name: ClassVar[str] = "source_embedding"
    content: str

    async def get_source(self) -> "Source":
        """Fetch the source referenced by this embedding record.

        Raises:
            DatabaseOperationError: if the query fails or returns no usable row.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**rows[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e


class SourceInsight(ObjectModel):
    """Record of the ``source_insight`` table."""

    table_name: ClassVar[str] = "source_insight"
    insight_type: str
    content: str

    async def get_source(self) -> "Source":
        """Fetch the source referenced by this insight record.

        Raises:
            DatabaseOperationError: if the query fails or returns no usable row.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**rows[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def save_as_note(self, notebook_id: Optional[str] = None) -> Any:
        """Save this insight as a new note and optionally attach it to a notebook.

        Args:
            notebook_id: when truthy, the saved note is added to this notebook.

        Returns:
            The saved ``Note``.
        """
        source = await self.get_source()
        note = Note(
            title=f"{self.insight_type} from source {source.title}",
            content=self.content,
        )
        await note.save()
        if notebook_id:
            await note.add_to_notebook(notebook_id)
        return note


class Source(ObjectModel):
    """Record of the ``source`` table with embedding and insight helpers."""

    table_name: ClassVar[str] = "source"
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    full_text: Optional[str] = None
    command: Optional[Union[str, RecordID]] = Field(
        default=None, description="Link to surreal-commands processing job"
    )

    class Config:
        arbitrary_types_allowed = True

    @field_validator("command", mode="before")
    @classmethod
    def parse_command(cls, value):
        """Parse command field to ensure RecordID format"""
        if isinstance(value, str) and value:
            return ensure_record_id(value)
        return value

    @field_validator("id", mode="before")
    @classmethod
    def parse_id(cls, value):
        """Parse id field to handle both string and RecordID inputs"""
        # A RecordID is always stringified; any other falsy value becomes None.
        if isinstance(value, RecordID) or value:
            return str(value)
        return None

    async def get_status(self) -> Optional[str]:
        """Get the processing status of the associated command

        Returns:
            ``None`` when no command is linked, the reported status otherwise,
            or ``"unknown"`` when the lookup returns nothing or fails.
        """
        if not self.command:
            return None

        try:
            from surreal_commands import get_command_status

            status = await get_command_status(str(self.command))
            return status.status if status else "unknown"
        except Exception as e:
            # Best-effort lookup: a failure is logged, not propagated.
            logger.warning(f"Failed to get command status for {self.command}: {e}")
            return "unknown"

    async def get_processing_progress(self) -> Optional[Dict[str, Any]]:
        """Get detailed processing information for the associated command

        Returns:
            ``None`` when no command is linked, the lookup returns nothing or
            the lookup fails; otherwise a dict with ``status``, ``started_at``,
            ``completed_at``, ``error`` and ``result``.
        """
        if not self.command:
            return None

        try:
            from surreal_commands import get_command_status

            status_result = await get_command_status(str(self.command))
            if not status_result:
                return None

            # Extract execution metadata if available
            result = getattr(status_result, "result", None)
            if isinstance(result, dict):
                execution_metadata = result.get("execution_metadata", {})
            else:
                execution_metadata = {}

            return {
                "status": status_result.status,
                "started_at": execution_metadata.get("started_at"),
                "completed_at": execution_metadata.get("completed_at"),
                "error": getattr(status_result, "error_message", None),
                "result": result,
            }
        except Exception as e:
            # Best-effort lookup: a failure is logged, not propagated.
            logger.warning(f"Failed to get command progress for {self.command}: {e}")
            return None

    async def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Return id, title and dumped insights; ``"long"`` also adds ``full_text``."""
        insights_list = await self.get_insights()
        insights = [insight.model_dump() for insight in insights_list]
        if context_size == "long":
            return dict(
                id=self.id,
                title=self.title,
                insights=insights,
                full_text=self.full_text,
            )
        return dict(id=self.id, title=self.title, insights=insights)

    async def get_embedded_chunks(self) -> int:
        """Count the ``source_embedding`` rows that reference this source.

        Returns:
            The count reported by the query, or 0 when the query returns no rows.

        Raises:
            DatabaseOperationError: if the query fails.
        """
        try:
            result = await repo_query(
                """
                select count() as chunks from source_embedding where source=$id GROUP ALL
                """,
                {"id": ensure_record_id(self.id)},
            )
            if not result:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error fetching chunks count for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(
                f"Failed to count chunks for source: {str(e)}"
            ) from e

    async def get_insights(self) -> List[SourceInsight]:
        """Return every ``source_insight`` row that references this source.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            result = await repo_query(
                """
                SELECT * FROM source_insight WHERE source=$id
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [SourceInsight(**insight) for insight in result]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source") from e

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Relate this source to a notebook through a ``reference`` edge.

        Raises:
            InvalidInputError: if ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("reference", notebook_id)

    async def vectorize(self) -> None:
        """Replace this source's embeddings with fresh ones built from ``full_text``.

        Existing ``source_embedding`` rows are deleted first, then the text is
        split into chunks, the chunks are embedded concurrently and one row per
        chunk is inserted. Returns early (after the deletion) when there is no
        text or the split yields no chunks.

        Raises:
            DatabaseOperationError: on any failure, including a missing
                embedding model when there are chunks to embed.
        """
        logger.info(f"Starting vectorization for source {self.id}")
        embedding_model = await model_manager.get_embedding_model()

        try:
            source_id = ensure_record_id(self.id)

            # DELETE EXISTING EMBEDDINGS FIRST - Makes vectorize() idempotent
            delete_result = await repo_query(
                "DELETE source_embedding WHERE source = $source_id",
                {"source_id": source_id},
            )
            deleted_count = len(delete_result) if delete_result else 0
            if deleted_count > 0:
                logger.info(
                    f"Deleted {deleted_count} existing embeddings for source {self.id}"
                )
            else:
                logger.debug(f"No existing embeddings found for source {self.id}")

            if not self.full_text:
                logger.warning(f"No text to vectorize for source {self.id}")
                return

            chunks = split_text(
                self.full_text,
            )
            chunk_count = len(chunks)
            logger.info(f"Split into {chunk_count} chunks for source {self.id}")

            if chunk_count == 0:
                logger.warning("No chunks created after splitting")
                return

            # Fail once, before fanning out, instead of once per chunk task.
            if embedding_model is None:
                raise ValueError("EMBEDDING_MODEL is not configured")

            # Process chunks concurrently using async gather
            logger.info("Starting concurrent processing of chunks")

            async def process_chunk(
                idx: int, chunk: str
            ) -> Tuple[int, List[float], str]:
                """Embed one chunk and return ``(index, embedding, content)``."""
                logger.debug(f"Processing chunk {idx}/{chunk_count}")
                try:
                    embedding = (await embedding_model.aembed([chunk]))[0]
                    logger.debug(f"Successfully processed chunk {idx}")
                    return (idx, embedding, chunk)
                except Exception as e:
                    logger.error(f"Error processing chunk {idx}: {str(e)}")
                    raise

            # Create tasks for all chunks and process them concurrently
            tasks = [process_chunk(idx, chunk) for idx, chunk in enumerate(chunks)]
            results = await asyncio.gather(*tasks)

            logger.info(f"Parallel processing complete. Got {len(results)} results")

            # Insert results in order (gather returns them in task order)
            for idx, embedding, content in results:
                logger.debug(f"Inserting chunk {idx} into database")
                await repo_query(
                    """
                    CREATE source_embedding CONTENT {
                            "source": $source_id,
                            "order": $order,
                            "content": $content,
                            "embedding": $embedding,
                    };""",
                    {
                        "source_id": source_id,
                        "order": idx,
                        "content": content,
                        "embedding": embedding,
                    },
                )

            logger.info(f"Vectorization complete for source {self.id}")

        except Exception as e:
            logger.error(f"Error vectorizing source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def add_insight(self, insight_type: str, content: str) -> Any:
        """Create a ``source_insight`` row for this source.

        The content is embedded when an embedding model is available; otherwise
        an empty embedding is stored and a warning is logged.

        Raises:
            InvalidInputError: if ``insight_type`` or ``content`` is empty.
            Exception: any error from embedding or the query is logged and
                re-raised unchanged.
        """
        # Validate before touching the model manager so bad input fails fast
        # and does not emit the unrelated "no embedding model" warning.
        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")

        embedding_model = await model_manager.get_embedding_model()
        if not embedding_model:
            logger.warning("No embedding model found. Insight will not be searchable.")

        try:
            if embedding_model:
                embedding = (await embedding_model.aembed([content]))[0]
            else:
                embedding = []
            return await repo_query(
                """
                CREATE source_insight CONTENT {
                        "source": $source_id,
                        "insight_type": $insight_type,
                        "content": $content,
                        "embedding": $embedding,
                };""",
                {
                    "source_id": ensure_record_id(self.id),
                    "insight_type": insight_type,
                    "content": content,
                    "embedding": embedding,
                },
            )
        except Exception as e:
            logger.error(f"Error adding insight to source {self.id}: {str(e)}")
            raise

    def _prepare_save_data(self) -> dict:
        """Override to ensure command field is always RecordID format for database"""
        data = super()._prepare_save_data()

        # Ensure command field is RecordID format if not None
        if data.get("command") is not None:
            data["command"] = ensure_record_id(data["command"])

        return data


class Note(ObjectModel):
    """Record of the ``note`` table."""

    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        """Allow ``None`` but reject empty or whitespace-only content.

        Raises:
            InvalidInputError: if ``v`` is not ``None`` and ``v.strip()`` is empty.
        """
        if v is not None and not v.strip():
            raise InvalidInputError("Note content cannot be empty")
        return v

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Relate this note to a notebook through an ``artifact`` edge.

        Raises:
            InvalidInputError: if ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("artifact", notebook_id)

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Return id, title and content; any size other than ``"long"`` keeps
        only the first 100 characters of the content."""
        if context_size == "long":
            return dict(id=self.id, title=self.title, content=self.content)
        return dict(
            id=self.id,
            title=self.title,
            content=self.content[:100] if self.content else None,
        )

    def needs_embedding(self) -> bool:
        """Always report that notes need an embedding."""
        return True

    def get_embedding_content(self) -> Optional[str]:
        """Return the note content as the text to embed."""
        return self.content


class ChatSession(ObjectModel):
    """Record of the ``chat_session`` table."""

    table_name: ClassVar[str] = "chat_session"
    title: Optional[str] = None
    model_override: Optional[str] = None

    async def relate_to_notebook(self, notebook_id: str) -> Any:
        """Relate this session to a notebook through a ``refers_to`` edge.

        Raises:
            InvalidInputError: if ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("refers_to", notebook_id)

    async def relate_to_source(self, source_id: str) -> Any:
        """Relate this session to a source through a ``refers_to`` edge.

        Raises:
            InvalidInputError: if ``source_id`` is empty.
        """
        if not source_id:
            raise InvalidInputError("Source ID must be provided")
        return await self.relate("refers_to", source_id)


async def text_search(
    keyword: str, results: int, source: bool = True, note: bool = True
):
    """Run the database function ``fn::text_search`` and return its rows.

    Args:
        keyword: search text; must be non-empty.
        results: passed to the database function as ``$results``.
        source: passed to the database function as ``$source``.
        note: passed to the database function as ``$note``.

    Raises:
        InvalidInputError: if ``keyword`` is empty.
        DatabaseOperationError: if the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        search_results = await repo_query(
            """
            select *
            from fn::text_search($keyword, $results, $source, $note)
            """,
            {"keyword": keyword, "results": results, "source": source, "note": note},
        )
        return search_results
    except Exception as e:
        logger.error(f"Error performing text search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e) from e


async def vector_search(
    keyword: str,
    results: int,
    source: bool = True,
    note: bool = True,
    minimum_score=0.2,
):
    """Embed ``keyword`` and run the database function ``fn::vector_search``.

    Args:
        keyword: search text; must be non-empty.
        results: passed to the database function as ``$results``.
        source: passed to the database function as ``$source``.
        note: passed to the database function as ``$note``.
        minimum_score: passed to the database function as ``$minimum_score``.

    Raises:
        InvalidInputError: if ``keyword`` is empty.
        DatabaseOperationError: if no embedding model is configured, or the
            embedding or the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        embedding_model = await model_manager.get_embedding_model()
        if embedding_model is None:
            raise ValueError("EMBEDDING_MODEL is not configured")
        embed = (await embedding_model.aembed([keyword]))[0]
        search_results = await repo_query(
            """
            SELECT * FROM fn::vector_search($embed, $results, $source, $note, $minimum_score);
            """,
            {
                "embed": embed,
                "results": results,
                "source": source,
                "note": note,
                "minimum_score": minimum_score,
            },
        )
        return search_results
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e) from e