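"""Domain models for Open Notebook.

Defines the Notebook, Source, Note, and ChatSession models, plus the text,
vector, and hybrid search helpers that query SurrealDB via repo_query.
"""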
from typing import Any, ClassVar, Dict, List, Literal, Optional

from loguru import logger
from pydantic import BaseModel, Field, field_validator

from open_notebook.database.repository import (
    repo_query,
)
from open_notebook.domain.base import ObjectModel
from open_notebook.domain.models import model_manager
from open_notebook.exceptions import (
    DatabaseOperationError,
    InvalidInputError,
)
from open_notebook.utils import split_text, surreal_clean

class Notebook(ObjectModel):
    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    archived: Optional[bool] = False

    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        if not v.strip():
            raise InvalidInputError("Notebook name cannot be empty")
        return v

    @property
    def sources(self) -> List["Source"]:
        try:
            srcs = repo_query(f"""
                select * OMIT full_text from (
                    select
                    <- source as source
                    from reference
                    where out={self.id}
                    fetch source
                )
                order by source.updated desc
            """)
            return [Source(**src["source"][0]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    @property
    def notes(self) -> List["Note"]:
        try:
            srcs = repo_query(f"""
                select * OMIT content from (
                    select
                    <- note as note
                    from artifact
                    where out={self.id}
                    fetch note
                )
                order by updated desc
            """)
            return [Note(**src["note"][0]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    @property
    def chat_sessions(self) -> List["ChatSession"]:
        try:
            srcs = repo_query(f"""
                select * from (
                    select
                    <- chat_session as chat_session
                    from refers_to
                    where out={self.id}
                    fetch chat_session
                )
                order by chat_session.updated desc
            """)
            return (
                [ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
            )
        except Exception as e:
            logger.error(f"Error fetching chat sessions for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

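# Illustrative usage (a minimal sketch; assumes a reachable SurrealDB behind
# repo_query, and the field values below are hypothetical):
#
#     nb = Notebook(name="Research", description="Papers on retrieval")
#     nb.save()                  # persistence comes from ObjectModel
#     for src in nb.sources:     # traverses the `reference` edge table
#         print(src.title)
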
class Asset(BaseModel):
    file_path: Optional[str] = None
    url: Optional[str] = None

class SourceEmbedding(ObjectModel):
    table_name: ClassVar[str] = "source_embedding"
    content: str

    @property
    def source(self) -> "Source":
        try:
            src = repo_query(f"""
                select source.* from {self.id} fetch source
            """)
            return Source(**src[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

class SourceInsight(ObjectModel):
    table_name: ClassVar[str] = "source_insight"
    insight_type: str
    content: str

    @property
    def source(self) -> "Source":
        try:
            src = repo_query(f"""
                select source.* from {self.id} fetch source
            """)
            return Source(**src[0]["source"])
        except Exception as e:
            logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

class Source(ObjectModel):
    table_name: ClassVar[str] = "source"
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    full_text: Optional[str] = None

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        if context_size == "long":
            return dict(
                id=self.id,
                title=self.title,
                insights=[insight.model_dump() for insight in self.insights],
                full_text=self.full_text,
            )
        else:
            return dict(id=self.id, title=self.title, insights=self.insights)

    @property
    def embedded_chunks(self) -> int:
        try:
            result = repo_query(
                f"""
                select count() as chunks from source_embedding where source={self.id} GROUP ALL
                """
            )
            if len(result) == 0:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error counting embedded chunks for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")

    @property
    def insights(self) -> List[SourceInsight]:
        try:
            result = repo_query(
                f"""
                SELECT * FROM source_insight WHERE source={self.id}
                """
            )
            return [SourceInsight(**insight) for insight in result]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source")

    def add_to_notebook(self, notebook_id: str) -> Any:
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return self.relate("reference", notebook_id)

    def vectorize(self) -> None:
        EMBEDDING_MODEL = model_manager.embedding_model

        try:
            if not self.full_text:
                return
            chunks = split_text(
                self.full_text,
            )
            logger.debug(f"Split into {len(chunks)} chunks")

            # future: we can increase the batch size after surreal launches their new SDK
            for i, chunk in enumerate(chunks):
                repo_query(
                    f"""
                    CREATE source_embedding CONTENT {{
                        "source": {self.id},
                        "order": {i},
                        "content": $content,
                        "embedding": {EMBEDDING_MODEL.embed(chunk)},
                    }};""",
                    {"content": surreal_clean(chunk)},
                )
        except Exception as e:
            logger.error(f"Error vectorizing source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    # @classmethod
    # def search(cls, query: str) -> List[Dict[str, Any]]:
    #     if not query:
    #         raise InvalidInputError("Search query cannot be empty")
    #     try:
    #         result = repo_query(
    #             """
    #             SELECT * omit full_text
    #             FROM source
    #             WHERE string::lowercase(title) CONTAINS $query or title @@ $query
    #             OR string::lowercase(summary) CONTAINS $query or summary @@ $query
    #             OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
    #             """,
    #             {"query": query},
    #         )
    #         return result
    #     except Exception as e:
    #         logger.error(f"Error searching sources: {str(e)}")
    #         logger.exception(e)
    #         raise DatabaseOperationError("Failed to search sources")

    def add_insight(self, insight_type: str, content: str) -> Any:
        EMBEDDING_MODEL = model_manager.embedding_model

        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")
        try:
            embedding = EMBEDDING_MODEL.embed(content)
            return repo_query(
                f"""
                CREATE source_insight CONTENT {{
                    "source": {self.id},
                    "insight_type": '{insight_type}',
                    "content": $content,
                    "embedding": {embedding},
                }};""",
                {"content": surreal_clean(content)},
            )
        except Exception as e:
            logger.error(f"Error adding insight to source {self.id}: {str(e)}")
            raise DatabaseOperationError(e)

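# Illustrative ingestion flow (a minimal sketch; assumes an embedding model is
# configured on model_manager, and the IDs and strings below are hypothetical):
#
#     source = Source(title="Attention Is All You Need", full_text=text)
#     source.save()
#     source.add_to_notebook("notebook:abc123")  # creates a `reference` edge
#     source.vectorize()                         # one source_embedding row per chunk
#     source.add_insight("summary", "Transformers replace recurrence with attention.")
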
class Note(ObjectModel):
    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        if v is not None and not v.strip():
            raise InvalidInputError("Note content cannot be empty")
        return v

    def add_to_notebook(self, notebook_id: str) -> Any:
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return self.relate("artifact", notebook_id)

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        if context_size == "long":
            return dict(id=self.id, title=self.title, content=self.content)
        else:
            return dict(
                id=self.id,
                title=self.title,
                content=self.content[:100] if self.content else None,
            )

    def needs_embedding(self) -> bool:
        return True

    def get_embedding_content(self) -> Optional[str]:
        return self.content

class ChatSession(ObjectModel):
    table_name: ClassVar[str] = "chat_session"
    title: Optional[str] = None

    def relate_to_notebook(self, notebook_id: str) -> Any:
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return self.relate("refers_to", notebook_id)

def text_search(keyword: str, results: int, source: bool = True, note: bool = True):
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        # Avoid shadowing the `results` parameter, which is bound into the query below.
        search_results = repo_query(
            """
            SELECT * FROM fn::text_search($keyword, $results, $source, $note);
            """,
            {"keyword": keyword, "results": results, "source": source, "note": note},
        )
        return search_results
    except Exception as e:
        logger.error(f"Error performing text search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)

def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        EMBEDDING_MODEL = model_manager.embedding_model
        embed = EMBEDDING_MODEL.embed(keyword)
        search_results = repo_query(
            """
            SELECT * FROM fn::vector_search($embed, $results, $source, $note, 0.15);
            """,
            {"embed": embed, "results": results, "source": source, "note": note},
        )
        return search_results
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)

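# hybrid_search merges keyword and vector hits under two fairness constraints:
# at most max_chunks_per_doc chunks are kept per parent document, and each
# query is guaranteed min_results_per_query slots before the remaining budget
# is filled with the globally best-scoring leftovers.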
def hybrid_search(
    keyword_search: List[str],
    embed_search: List[str],
    results: int = 50,
    source: bool = True,
    note: bool = True,
    max_chunks_per_doc: int = 3,
    min_results_per_query: int = 3,
) -> Dict[str, List[Dict]]:
    if not keyword_search and not embed_search:
        raise InvalidInputError("At least one search term required")

    # Process keyword searches
    all_keyword_results = {}  # Dictionary to store results per keyword
    for keyword in keyword_search:
        try:
            search_results = text_search(keyword, results, source, note)
            # Sort results by relevance
            sorted_results = sorted(
                search_results, key=lambda x: x.get("relevance", 0), reverse=True
            )
            # Group by parent_id and limit chunks per document
            seen_parent_ids = {}
            filtered_results = []
            for result in sorted_results:
                parent_id = result["parent_id"]
                if parent_id not in seen_parent_ids:
                    seen_parent_ids[parent_id] = 1
                    filtered_results.append(result)
                elif seen_parent_ids[parent_id] < max_chunks_per_doc:
                    seen_parent_ids[parent_id] += 1
                    filtered_results.append(result)
            all_keyword_results[keyword] = filtered_results
        except Exception as e:
            logger.warning(f"Error in keyword search for term '{keyword}': {str(e)}")
            continue

    # Ensure minimum results from each keyword query
    keyword_results = []
    remaining_slots = results

    # First pass: add minimum results from each query
    for keyword, query_results in all_keyword_results.items():
        keyword_results.extend(query_results[:min_results_per_query])
        remaining_slots -= min(len(query_results), min_results_per_query)

    # Second pass: fill remaining slots with best results
    all_remaining = []
    for keyword, query_results in all_keyword_results.items():
        all_remaining.extend(query_results[min_results_per_query:])

    # Sort remaining by relevance and add until we hit the limit
    all_remaining = sorted(
        all_remaining, key=lambda x: x.get("relevance", 0), reverse=True
    )
    seen_ids = {r["id"] for r in keyword_results}
    for result in all_remaining:
        if remaining_slots <= 0:
            break
        if result["id"] not in seen_ids:
            keyword_results.append(result)
            seen_ids.add(result["id"])
            remaining_slots -= 1

    # Process vector searches with the same approach
    all_vector_results = {}  # Dictionary to store results per embedding
    for embed in embed_search:
        try:
            search_results = vector_search(embed, results, source, note)
            # Sort results by similarity
            sorted_results = sorted(
                search_results, key=lambda x: x.get("similarity", 0), reverse=True
            )
            # Group by parent_id and limit chunks per document
            seen_parent_ids = {}
            filtered_results = []
            for result in sorted_results:
                parent_id = result["parent_id"]
                if parent_id not in seen_parent_ids:
                    seen_parent_ids[parent_id] = 1
                    filtered_results.append(result)
                elif seen_parent_ids[parent_id] < max_chunks_per_doc:
                    seen_parent_ids[parent_id] += 1
                    filtered_results.append(result)
            all_vector_results[embed] = filtered_results
        except Exception as e:
            logger.warning(f"Error in vector search for term '{embed}': {str(e)}")
            continue

    # Ensure minimum results from each vector query
    vector_results = []
    remaining_slots = results

    # First pass: add minimum results from each query
    for embed, query_results in all_vector_results.items():
        vector_results.extend(query_results[:min_results_per_query])
        remaining_slots -= min(len(query_results), min_results_per_query)

    # Second pass: fill remaining slots with best results
    all_remaining = []
    for embed, query_results in all_vector_results.items():
        all_remaining.extend(query_results[min_results_per_query:])

    # Sort remaining by similarity and add until we hit the limit
    all_remaining = sorted(
        all_remaining, key=lambda x: x.get("similarity", 0), reverse=True
    )
    seen_ids = {r["id"] for r in vector_results}
    for result in all_remaining:
        if remaining_slots <= 0:
            break
        if result["id"] not in seen_ids:
            vector_results.append(result)
            seen_ids.add(result["id"])
            remaining_slots -= 1

    return {"keyword_results": keyword_results, "vector_results": vector_results}

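# Illustrative call (a minimal sketch; the query strings are hypothetical):
#
#     hits = hybrid_search(
#         keyword_search=["attention mechanism"],
#         embed_search=["how does self-attention work"],
#         results=20,
#     )
#     for hit in hits["vector_results"]:
#         print(hit["parent_id"], hit.get("similarity"))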