import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple

from loguru import logger
from pydantic import BaseModel, Field, field_validator

from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.base import ObjectModel
from open_notebook.domain.models import model_manager
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
from open_notebook.utils import split_text


class Notebook(ObjectModel):
    """A notebook record, with helpers to load its related records."""

    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    archived: Optional[bool] = False

    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        """Reject names that are empty or contain only whitespace.

        Raises:
            InvalidInputError: If ``v.strip()`` is empty.
        """
        if not v.strip():
            raise InvalidInputError("Notebook name cannot be empty")
        return v

    async def get_sources(self) -> List["Source"]:
        """Return sources linked to this notebook via ``reference`` edges.

        The query omits ``full_text`` and orders rows by ``updated``
        descending.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            srcs = await repo_query(
                """ select * omit source.full_text from ( select in as source from reference where out=$id fetch source ) order by source.updated desc """,
                {"id": ensure_record_id(self.id)},
            )
            return [Source(**src["source"]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def get_notes(self) -> List["Note"]:
        """Return notes linked to this notebook via ``artifact`` edges.

        The query omits ``content`` and ``embedding`` and orders rows by
        ``updated`` descending.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            srcs = await repo_query(
                """ select * omit note.content, note.embedding from ( select in as note from artifact where out=$id fetch note ) order by note.updated desc """,
                {"id": ensure_record_id(self.id)},
            )
            return [Note(**src["note"]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def get_chat_sessions(self) -> List["ChatSession"]:
        """Return chat sessions linked to this notebook via ``refers_to``.

        Rows are ordered by ``updated`` descending.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            srcs = await repo_query(
                """ select * from ( select <- chat_session as chat_session from refers_to where out=$id fetch chat_session ) order by chat_session.updated desc """,
                {"id": ensure_record_id(self.id)},
            )
            # Each row's "chat_session" value is indexed with [0]; presumably
            # the `<-` traversal yields a one-element list -- TODO confirm.
            return (
                [ChatSession(**src["chat_session"][0]) for src in srcs]
                if srcs
                else []
            )
        except Exception as e:
            logger.error(
                f"Error fetching chat sessions for notebook {self.id}: {str(e)}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e) from e


class Asset(BaseModel):
    """Optional file path and/or URL attached to a source."""

    file_path: Optional[str] = None
    url: Optional[str] = None


class SourceEmbedding(ObjectModel):
    """A ``source_embedding`` record holding one piece of source content."""

    table_name: ClassVar[str] = "source_embedding"
    content: str

    async def get_source(self) -> "Source":
        """Fetch the source this embedding record points to.

        Raises:
            DatabaseOperationError: If the query fails or returns no row.
        """
        try:
            src = await repo_query(
                """ select source.* from $id fetch source """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**src[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e


class SourceInsight(ObjectModel):
    """A ``source_insight`` record: typed content derived from a source."""

    table_name: ClassVar[str] = "source_insight"
    insight_type: str
    content: str

    async def get_source(self) -> "Source":
        """Fetch the source this insight points to.

        Raises:
            DatabaseOperationError: If the query fails or returns no row.
        """
        try:
            src = await repo_query(
                """ select source.* from $id fetch source """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**src[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def save_as_note(self, notebook_id: Optional[str] = None) -> Any:
        """Save this insight's content as a new note.

        Args:
            notebook_id: When truthy, the saved note is also added to this
                notebook.

        Returns:
            The saved ``Note``.
        """
        source = await self.get_source()
        note = Note(
            title=f"{self.insight_type} from source {source.title}",
            content=self.content,
        )
        await note.save()
        if notebook_id:
            await note.add_to_notebook(notebook_id)
        return note


class Source(ObjectModel):
    """A ``source`` record with its text, insights and embeddings helpers."""

    table_name: ClassVar[str] = "source"
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    full_text: Optional[str] = None

    async def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a context dict with id, title and dumped insights.

        Args:
            context_size: ``"long"`` additionally includes ``full_text``.
        """
        insights_list = await self.get_insights()
        insights = [insight.model_dump() for insight in insights_list]
        if context_size == "long":
            return dict(
                id=self.id,
                title=self.title,
                insights=insights,
                full_text=self.full_text,
            )
        return dict(id=self.id, title=self.title, insights=insights)

    async def get_embedded_chunks(self) -> int:
        """Count the ``source_embedding`` rows that reference this source.

        Returns:
            The count, or 0 when the query returns no rows.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            result = await repo_query(
                """ select count() as chunks from source_embedding where source=$id GROUP ALL """,
                {"id": ensure_record_id(self.id)},
            )
            # An empty (or missing) result means there is nothing to count.
            if not result:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error fetching chunks count for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(
                f"Failed to count chunks for source: {str(e)}"
            ) from e

    async def get_insights(self) -> List[SourceInsight]:
        """Return all ``source_insight`` rows that reference this source.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            result = await repo_query(
                """ SELECT * FROM source_insight WHERE source=$id """,
                {"id": ensure_record_id(self.id)},
            )
            return [SourceInsight(**insight) for insight in result]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source") from e

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Relate this source to a notebook through a ``reference`` edge.

        Raises:
            InvalidInputError: If ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("reference", notebook_id)

    async def vectorize(self) -> None:
        """Split ``full_text`` into chunks, embed them and store the results.

        Returns early, after logging a warning, when there is no text or
        when splitting yields no chunks. All chunks are embedded
        concurrently, then one ``source_embedding`` row is created per
        chunk in chunk order.

        Raises:
            DatabaseOperationError: If no embedding model is available, or
                if splitting, embedding or inserting fails.
        """
        logger.info(f"Starting vectorization for source {self.id}")
        EMBEDDING_MODEL = await model_manager.get_embedding_model()

        try:
            if not self.full_text:
                logger.warning(f"No text to vectorize for source {self.id}")
                return

            chunks = split_text(
                self.full_text,
            )
            chunk_count = len(chunks)
            logger.info(f"Split into {chunk_count} chunks for source {self.id}")

            if chunk_count == 0:
                logger.warning("No chunks created after splitting")
                return

            # Fail once with a clear message rather than once per chunk with
            # an AttributeError on None. The handler below wraps this.
            if not EMBEDDING_MODEL:
                raise ValueError("No embedding model found")

            # Process chunks concurrently using async gather
            logger.info("Starting concurrent processing of chunks")

            async def process_chunk(
                idx: int, chunk: str
            ) -> Tuple[int, List[float], str]:
                """Embed one chunk and return (index, embedding, content)."""
                logger.debug(f"Processing chunk {idx}/{chunk_count}")
                try:
                    embedding = (await EMBEDDING_MODEL.aembed([chunk]))[0]
                    logger.debug(f"Successfully processed chunk {idx}")
                    return (idx, embedding, chunk)
                except Exception as e:
                    logger.error(f"Error processing chunk {idx}: {str(e)}")
                    raise

            # Create tasks for all chunks and process them concurrently
            tasks = [process_chunk(idx, chunk) for idx, chunk in enumerate(chunks)]
            results = await asyncio.gather(*tasks)
            logger.info(f"Parallel processing complete. Got {len(results)} results")

            # gather() returns results in task order, so rows are inserted
            # in chunk order.
            for idx, embedding, content in results:
                logger.debug(f"Inserting chunk {idx} into database")
                await repo_query(
                    """ CREATE source_embedding CONTENT { "source": $source_id, "order": $order, "content": $content, "embedding": $embedding, };""",
                    {
                        "source_id": ensure_record_id(self.id),
                        "order": idx,
                        "content": content,
                        "embedding": embedding,
                    },
                )

            logger.info(f"Vectorization complete for source {self.id}")

        except Exception as e:
            logger.error(f"Error vectorizing source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def add_insight(self, insight_type: str, content: str) -> Any:
        """Create a ``source_insight`` row for this source.

        When no embedding model is available, a warning is logged and the
        insight is stored with an empty embedding.

        Args:
            insight_type: Non-empty label for the insight.
            content: Non-empty insight text.

        Returns:
            Whatever ``repo_query`` returns for the CREATE statement.

        Raises:
            InvalidInputError: If ``insight_type`` or ``content`` is empty.
        """
        # Validate first so invalid input never triggers a model lookup.
        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")

        EMBEDDING_MODEL = await model_manager.get_embedding_model()
        if not EMBEDDING_MODEL:
            logger.warning("No embedding model found. Insight will not be searchable.")

        try:
            embedding = (
                (await EMBEDDING_MODEL.aembed([content]))[0] if EMBEDDING_MODEL else []
            )
            return await repo_query(
                """ CREATE source_insight CONTENT { "source": $source_id, "insight_type": $insight_type, "content": $content, "embedding": $embedding, };""",
                {
                    "source_id": ensure_record_id(self.id),
                    "insight_type": insight_type,
                    "content": content,
                    "embedding": embedding,
                },
            )
        except Exception as e:
            logger.error(f"Error adding insight to source {self.id}: {str(e)}")
            # Re-raised as-is: unlike sibling methods, this one does not wrap
            # the error in DatabaseOperationError.
            raise


class Note(ObjectModel):
    """A ``note`` record with optional title, type and content."""

    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        """Reject content that is set but empty or whitespace-only.

        ``None`` is accepted.

        Raises:
            InvalidInputError: If ``v`` is not None and ``v.strip()`` is empty.
        """
        if v is not None and not v.strip():
            raise InvalidInputError("Note content cannot be empty")
        return v

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Relate this note to a notebook through an ``artifact`` edge.

        Raises:
            InvalidInputError: If ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("artifact", notebook_id)

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a context dict with id, title and content.

        Args:
            context_size: ``"long"`` returns the full content; otherwise the
                content is cut to its first 100 characters (or None when
                there is no content).
        """
        if context_size == "long":
            return dict(id=self.id, title=self.title, content=self.content)
        return dict(
            id=self.id,
            title=self.title,
            content=self.content[:100] if self.content else None,
        )

    def needs_embedding(self) -> bool:
        """Return True: notes always report that they need an embedding."""
        return True

    def get_embedding_content(self) -> Optional[str]:
        """Return the text to embed, which is the note's content."""
        return self.content


class ChatSession(ObjectModel):
    """A ``chat_session`` record with an optional title."""

    table_name: ClassVar[str] = "chat_session"
    title: Optional[str] = None

    async def relate_to_notebook(self, notebook_id: str) -> Any:
        """Relate this session to a notebook through a ``refers_to`` edge.

        Raises:
            InvalidInputError: If ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("refers_to", notebook_id)


async def text_search(
    keyword: str, results: int, source: bool = True, note: bool = True
):
    """Run the database function ``fn::text_search`` and return its rows.

    Args:
        keyword: Non-empty search text.
        results: Passed through as ``$results`` (presumably the maximum
            number of hits -- confirm against the function definition).
        source: Passed through as ``$source``.
        note: Passed through as ``$note``.

    Raises:
        InvalidInputError: If ``keyword`` is empty.
        DatabaseOperationError: If the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        return await repo_query(
            """ select * from fn::text_search($keyword, $results, $source, $note) """,
            {"keyword": keyword, "results": results, "source": source, "note": note},
        )
    except Exception as e:
        logger.error(f"Error performing text search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e) from e


async def vector_search(
    keyword: str,
    results: int,
    source: bool = True,
    note: bool = True,
    minimum_score: float = 0.2,
):
    """Embed ``keyword`` and run the database function ``fn::vector_search``.

    Args:
        keyword: Non-empty search text; it is embedded before querying.
        results: Passed through as ``$results`` (presumably the maximum
            number of hits -- confirm against the function definition).
        source: Passed through as ``$source``.
        note: Passed through as ``$note``.
        minimum_score: Passed through as ``$minimum_score``.

    Raises:
        InvalidInputError: If ``keyword`` is empty.
        DatabaseOperationError: If no embedding model is available, or if
            embedding or the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        EMBEDDING_MODEL = await model_manager.get_embedding_model()
        # Give a clear message instead of an AttributeError on None; the
        # handler below wraps it in DatabaseOperationError as before.
        if not EMBEDDING_MODEL:
            raise ValueError("No embedding model found")
        embed = (await EMBEDDING_MODEL.aembed([keyword]))[0]
        return await repo_query(
            """ SELECT * FROM fn::vector_search($embed, $results, $source, $note, $minimum_score); """,
            {
                "embed": embed,
                "results": results,
                "source": source,
                "note": note,
                "minimum_score": minimum_score,
            },
        )
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e) from e