mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-29 03:50:04 +00:00
New front-end Launch Chat API Manage Sources Enable re-embedding of all contents Sources can be added without a notebook now Improved settings Enable model selector on all chats Background processing for better experience Dark mode Improved Notes Improved Docs: - Remove all Streamlit references from documentation - Update deployment guides with React frontend setup - Fix Docker environment variables format (SURREAL_URL, SURREAL_PASSWORD) - Update docker image tag from :latest to :v1-latest - Change navigation references (Settings → Models to just Models) - Update development setup to include frontend npm commands - Add MIGRATION.md guide for users upgrading from Streamlit - Update quick-start guide with correct environment variables - Add port 5055 documentation for API access - Update project structure to reflect frontend/ directory - Remove outdated source-chat documentation files
489 lines
17 KiB
Python
489 lines
17 KiB
Python
import asyncio
|
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union
|
|
|
|
from loguru import logger
|
|
from pydantic import BaseModel, Field, field_validator
|
|
from surrealdb import RecordID
|
|
|
|
from open_notebook.database.repository import ensure_record_id, repo_query
|
|
from open_notebook.domain.base import ObjectModel
|
|
from open_notebook.domain.models import model_manager
|
|
from open_notebook.exceptions import DatabaseOperationError, InvalidInputError
|
|
from open_notebook.utils import split_text
|
|
|
|
|
|
class Notebook(ObjectModel):
    """A notebook: the container that groups sources, notes and chat sessions.

    Persisted in the `notebook` table. Related records are reached through
    graph edges: `reference` (sources), `artifact` (notes) and `refers_to`
    (chat sessions).
    """

    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    archived: Optional[bool] = False

    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        """Reject empty or whitespace-only notebook names."""
        if v.strip():
            return v
        raise InvalidInputError("Notebook name cannot be empty")

    async def get_sources(self) -> List["Source"]:
        """Return this notebook's sources, most recently updated first.

        `full_text` is omitted from the query to keep payloads small.

        Raises:
            DatabaseOperationError: if the query or row parsing fails.
        """
        try:
            rows = await repo_query(
                """
                select * omit source.full_text from (
                    select in as source from reference where out=$id
                    fetch source
                ) order by source.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            if not rows:
                return []
            return [Source(**row["source"]) for row in rows]
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def get_notes(self) -> List["Note"]:
        """Return this notebook's notes, most recently updated first.

        `content` and `embedding` are omitted from the query payload.

        Raises:
            DatabaseOperationError: if the query or row parsing fails.
        """
        try:
            rows = await repo_query(
                """
                select * omit note.content, note.embedding from (
                    select in as note from artifact where out=$id
                    fetch note
                ) order by note.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            if not rows:
                return []
            return [Note(**row["note"]) for row in rows]
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def get_chat_sessions(self) -> List["ChatSession"]:
        """Return this notebook's chat sessions, most recently updated first.

        Raises:
            DatabaseOperationError: if the query or row parsing fails.
        """
        try:
            rows = await repo_query(
                """
                select * from (
                    select
                        <- chat_session as chat_session
                    from refers_to
                    where out=$id
                    fetch chat_session
                )
                order by chat_session.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            if not rows:
                return []
            # Each row carries the fetched session as a one-element list.
            return [ChatSession(**row["chat_session"][0]) for row in rows]
        except Exception as e:
            logger.error(
                f"Error fetching chat sessions for notebook {self.id}: {str(e)}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e)
|
|
|
|
|
|
class Asset(BaseModel):
    """Pointer to the raw material behind a source.

    Both fields are optional: an asset may reference a local file, a URL,
    or neither.
    """

    file_path: Optional[str] = None
    url: Optional[str] = None
|
|
|
|
|
|
class SourceEmbedding(ObjectModel):
    """One embedded chunk of a source's full text (table `source_embedding`)."""

    table_name: ClassVar[str] = "source_embedding"
    content: str

    async def get_source(self) -> "Source":
        """Fetch the parent Source record this chunk belongs to.

        Raises:
            DatabaseOperationError: if the lookup or row parsing fails.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**rows[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)
|
|
|
|
|
|
class SourceInsight(ObjectModel):
    """A typed insight attached to a source (table `source_insight`)."""

    table_name: ClassVar[str] = "source_insight"
    insight_type: str
    content: str

    async def get_source(self) -> "Source":
        """Fetch the Source record this insight was derived from.

        Raises:
            DatabaseOperationError: if the lookup or row parsing fails.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**rows[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def save_as_note(self, notebook_id: Optional[str] = None) -> Any:
        """Persist this insight as a Note, optionally attaching it to a notebook.

        Returns the newly created Note.
        """
        parent = await self.get_source()
        note = Note(
            title=f"{self.insight_type} from source {parent.title}",
            content=self.content,
        )
        await note.save()
        if notebook_id:
            await note.add_to_notebook(notebook_id)
        return note
|
|
|
|
|
|
class Source(ObjectModel):
    """A piece of source material attached to notebooks (table `source`).

    Holds the extracted text (`full_text`), an optional Asset pointing at the
    original file/URL, and an optional link to the surreal-commands job that
    processed it. Embeddings and insights are stored in separate tables
    (`source_embedding`, `source_insight`) keyed back to this record.
    """

    table_name: ClassVar[str] = "source"
    # Original file path / URL behind this source, if any.
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    # Extracted text; may be absent while background processing is running.
    full_text: Optional[str] = None
    command: Optional[Union[str, RecordID]] = Field(
        default=None, description="Link to surreal-commands processing job"
    )

    class Config:
        # Required so pydantic accepts the non-pydantic RecordID type.
        arbitrary_types_allowed = True

    @field_validator("command", mode="before")
    @classmethod
    def parse_command(cls, value):
        """Coerce a non-empty string command reference into a RecordID."""
        if isinstance(value, str) and value:
            return ensure_record_id(value)
        return value

    @field_validator("id", mode="before")
    @classmethod
    def parse_id(cls, value):
        """Normalize the id to a string, accepting str or RecordID input."""
        if value is None:
            return None
        if isinstance(value, RecordID):
            return str(value)
        # Any other truthy value is stringified; falsy values collapse to None.
        return str(value) if value else None

    async def get_status(self) -> Optional[str]:
        """Return the processing status of the linked command.

        Returns None when no command is linked, and the string "unknown" when
        the status lookup fails or yields nothing (best-effort, never raises).
        """
        if not self.command:
            return None

        try:
            # Imported lazily so the dependency is only needed when used.
            from surreal_commands import get_command_status

            status = await get_command_status(str(self.command))
            return status.status if status else "unknown"
        except Exception as e:
            logger.warning(f"Failed to get command status for {self.command}: {e}")
            return "unknown"

    async def get_processing_progress(self) -> Optional[Dict[str, Any]]:
        """Return detailed processing info for the linked command, or None.

        Best-effort: returns None when no command is linked, when the status
        lookup returns nothing, or on any error.
        """
        if not self.command:
            return None

        try:
            from surreal_commands import get_command_status

            status_result = await get_command_status(str(self.command))
            if not status_result:
                return None

            # Extract execution metadata if available; tolerate non-dict results.
            result = getattr(status_result, "result", None)
            execution_metadata = result.get("execution_metadata", {}) if isinstance(result, dict) else {}

            return {
                "status": status_result.status,
                "started_at": execution_metadata.get("started_at"),
                "completed_at": execution_metadata.get("completed_at"),
                "error": getattr(status_result, "error_message", None),
                "result": result,
            }
        except Exception as e:
            logger.warning(f"Failed to get command progress for {self.command}: {e}")
            return None

    async def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a dict view of this source for prompt/context assembly.

        "long" includes `full_text`; "short" carries only id, title and
        insights.
        """
        insights_list = await self.get_insights()
        insights = [insight.model_dump() for insight in insights_list]
        if context_size == "long":
            return dict(
                id=self.id,
                title=self.title,
                insights=insights,
                full_text=self.full_text,
            )
        else:
            return dict(id=self.id, title=self.title, insights=insights)

    async def get_embedded_chunks(self) -> int:
        """Count the source_embedding rows stored for this source.

        Raises:
            DatabaseOperationError: if the count query fails.
        """
        try:
            result = await repo_query(
                """
                select count() as chunks from source_embedding where source=$id GROUP ALL
                """,
                {"id": ensure_record_id(self.id)},
            )
            # GROUP ALL yields no rows at all when there are zero matches.
            if len(result) == 0:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error fetching chunks count for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")

    async def get_insights(self) -> List[SourceInsight]:
        """Return all SourceInsight records attached to this source.

        Raises:
            DatabaseOperationError: if the query or row parsing fails.
        """
        try:
            result = await repo_query(
                """
                SELECT * FROM source_insight WHERE source=$id
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [SourceInsight(**insight) for insight in result]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source")

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this source to a notebook via a `reference` edge.

        Raises:
            InvalidInputError: if notebook_id is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("reference", notebook_id)

    async def vectorize(self) -> None:
        """(Re)build the embeddings for this source's full text.

        Idempotent: existing embeddings are deleted first, then `full_text`
        is split into chunks, each chunk is embedded concurrently, and the
        results are inserted in chunk order. No-ops (with a warning) when
        there is no text or splitting yields no chunks.

        Raises:
            DatabaseOperationError: wrapping any failure during the process
                (including a missing embedding model, which surfaces from
                the per-chunk ValueError).
        """
        logger.info(f"Starting vectorization for source {self.id}")
        EMBEDDING_MODEL = await model_manager.get_embedding_model()

        try:
            # DELETE EXISTING EMBEDDINGS FIRST - Makes vectorize() idempotent
            delete_result = await repo_query(
                "DELETE source_embedding WHERE source = $source_id",
                {"source_id": ensure_record_id(self.id)}
            )
            deleted_count = len(delete_result) if delete_result else 0
            if deleted_count > 0:
                logger.info(f"Deleted {deleted_count} existing embeddings for source {self.id}")
            else:
                logger.debug(f"No existing embeddings found for source {self.id}")

            if not self.full_text:
                logger.warning(f"No text to vectorize for source {self.id}")
                return

            chunks = split_text(
                self.full_text,
            )
            chunk_count = len(chunks)
            logger.info(f"Split into {chunk_count} chunks for source {self.id}")

            if chunk_count == 0:
                logger.warning("No chunks created after splitting")
                return

            # Process chunks concurrently using async gather
            logger.info("Starting concurrent processing of chunks")

            async def process_chunk(
                idx: int, chunk: str
            ) -> Tuple[int, List[float], str]:
                # Embed a single chunk; returns (index, embedding, content)
                # so results can be inserted in their original order.
                logger.debug(f"Processing chunk {idx}/{chunk_count}")
                try:
                    if EMBEDDING_MODEL is None:
                        raise ValueError("EMBEDDING_MODEL is not configured")
                    embedding = (await EMBEDDING_MODEL.aembed([chunk]))[0]
                    cleaned_content = chunk
                    logger.debug(f"Successfully processed chunk {idx}")
                    return (idx, embedding, cleaned_content)
                except Exception as e:
                    logger.error(f"Error processing chunk {idx}: {str(e)}")
                    raise

            # Create tasks for all chunks and process them concurrently
            tasks = [process_chunk(idx, chunk) for idx, chunk in enumerate(chunks)]
            results = await asyncio.gather(*tasks)

            logger.info(f"Parallel processing complete. Got {len(results)} results")

            # Insert results in order (they're already ordered by index)
            for idx, embedding, content in results:
                logger.debug(f"Inserting chunk {idx} into database")
                await repo_query(
                    """
                    CREATE source_embedding CONTENT {
                        "source": $source_id,
                        "order": $order,
                        "content": $content,
                        "embedding": $embedding,
                    };""",
                    {
                        "source_id": ensure_record_id(self.id),
                        "order": idx,
                        "content": content,
                        "embedding": embedding,
                    },
                )

            logger.info(f"Vectorization complete for source {self.id}")

        except Exception as e:
            logger.error(f"Error vectorizing source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def add_insight(self, insight_type: str, content: str) -> Any:
        """Create a source_insight row for this source, embedding the content.

        When no embedding model is configured the insight is still stored,
        with an empty embedding (it just won't be searchable).

        Raises:
            InvalidInputError: if insight_type or content is empty.
        """
        EMBEDDING_MODEL = await model_manager.get_embedding_model()
        if not EMBEDDING_MODEL:
            logger.warning("No embedding model found. Insight will not be searchable.")

        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")
        try:
            embedding = (
                (await EMBEDDING_MODEL.aembed([content]))[0] if EMBEDDING_MODEL else []
            )
            return await repo_query(
                """
                CREATE source_insight CONTENT {
                    "source": $source_id,
                    "insight_type": $insight_type,
                    "content": $content,
                    "embedding": $embedding,
                };""",
                {
                    "source_id": ensure_record_id(self.id),
                    "insight_type": insight_type,
                    "content": content,
                    "embedding": embedding,
                },
            )
        except Exception as e:
            logger.error(f"Error adding insight to source {self.id}: {str(e)}")
            # NOTE(review): the original exception is re-raised unwrapped here,
            # unlike the other methods which wrap in DatabaseOperationError.
            raise

    def _prepare_save_data(self) -> dict:
        """Override to ensure command field is always RecordID format for database."""
        data = super()._prepare_save_data()

        # Ensure command field is RecordID format if not None
        if data.get("command") is not None:
            data["command"] = ensure_record_id(data["command"])

        return data
|
|
|
|
|
|
class Note(ObjectModel):
    """A human- or AI-authored note (table `note`), linked to notebooks
    through `artifact` edges."""

    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        """Allow None, but reject whitespace-only content."""
        if v is None:
            return v
        if not v.strip():
            raise InvalidInputError("Note content cannot be empty")
        return v

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this note to a notebook via an `artifact` edge.

        Raises:
            InvalidInputError: if notebook_id is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("artifact", notebook_id)

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a dict view of the note; "short" truncates content to 100 chars."""
        if context_size == "long":
            body = self.content
        else:
            body = self.content[:100] if self.content else None
        return dict(id=self.id, title=self.title, content=body)

    def needs_embedding(self) -> bool:
        """Every note is eligible for embedding."""
        return True

    def get_embedding_content(self) -> Optional[str]:
        """The text submitted for embedding is the raw note content."""
        return self.content
|
|
|
|
|
|
class ChatSession(ObjectModel):
    """A chat conversation (table `chat_session`), linked via `refers_to`
    edges to the notebook and, optionally, to individual sources."""

    table_name: ClassVar[str] = "chat_session"
    title: Optional[str] = None
    # Per-session model choice overriding the default, when set.
    model_override: Optional[str] = None

    async def relate_to_notebook(self, notebook_id: str) -> Any:
        """Attach this session to a notebook via a `refers_to` edge.

        Raises:
            InvalidInputError: if notebook_id is empty.
        """
        if notebook_id:
            return await self.relate("refers_to", notebook_id)
        raise InvalidInputError("Notebook ID must be provided")

    async def relate_to_source(self, source_id: str) -> Any:
        """Attach this session to a source via a `refers_to` edge.

        Raises:
            InvalidInputError: if source_id is empty.
        """
        if source_id:
            return await self.relate("refers_to", source_id)
        raise InvalidInputError("Source ID must be provided")
|
|
|
|
|
|
async def text_search(
    keyword: str, results: int, source: bool = True, note: bool = True
):
    """Keyword (full-text) search across sources and/or notes.

    Args:
        keyword: Text to search for; must be non-empty.
        results: Maximum number of results to return.
        source: Include sources in the search.
        note: Include notes in the search.

    Raises:
        InvalidInputError: if keyword is empty.
        DatabaseOperationError: if the search query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    params = {"keyword": keyword, "results": results, "source": source, "note": note}
    try:
        return await repo_query(
            """
            select *
            from fn::text_search($keyword, $results, $source, $note)
            """,
            params,
        )
    except Exception as e:
        logger.error(f"Error performing text search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)
|
|
|
|
|
|
async def vector_search(
    keyword: str,
    results: int,
    source: bool = True,
    note: bool = True,
    minimum_score: float = 0.2,
):
    """Semantic (embedding-based) search across sources and/or notes.

    The keyword is embedded with the configured embedding model, then passed
    to the database-side fn::vector_search function.

    Args:
        keyword: Text to embed and search for; must be non-empty.
        results: Maximum number of results to return.
        source: Include sources in the search.
        note: Include notes in the search.
        minimum_score: Minimum similarity score for a match.

    Raises:
        InvalidInputError: if keyword is empty.
        DatabaseOperationError: if the query fails — including when no
            embedding model is configured, since all errors inside the
            search are wrapped.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        EMBEDDING_MODEL = await model_manager.get_embedding_model()
        if EMBEDDING_MODEL is None:
            raise ValueError("EMBEDDING_MODEL is not configured")
        embed = (await EMBEDDING_MODEL.aembed([keyword]))[0]
        search_results = await repo_query(
            """
            SELECT * FROM fn::vector_search($embed, $results, $source, $note, $minimum_score);
            """,
            {
                "embed": embed,
                "results": results,
                "source": source,
                "note": note,
                "minimum_score": minimum_score,
            },
        )
        return search_results
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)
|