open-notebook/open_notebook/domain/notebook.py
2024-11-05 16:55:59 -03:00

458 lines
16 KiB
Python

from typing import Any, ClassVar, Dict, List, Literal, Optional
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from open_notebook.database.repository import (
repo_query,
)
from open_notebook.domain.base import ObjectModel
from open_notebook.domain.models import model_manager
from open_notebook.exceptions import (
DatabaseOperationError,
InvalidInputError,
)
from open_notebook.utils import split_text, surreal_clean
class Notebook(ObjectModel):
    """A named notebook that groups sources, notes and chat sessions.

    Related records are linked to the notebook through graph edges whose
    ``out`` side is the notebook: ``reference`` (sources), ``artifact``
    (notes) and ``refers_to`` (chat sessions).
    """

    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    archived: Optional[bool] = False

    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        """Reject a name that is empty or contains only whitespace."""
        if not v.strip():
            raise InvalidInputError("Notebook name cannot be empty")
        return v

    @property
    def sources(self) -> List["Source"]:
        """Return the sources linked to this notebook, most recently updated first.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            # NOTE(review): each outer row has the shape {source: [...]}, so a
            # top-level ``OMIT full_text`` likely omits nothing
            # (``source.full_text`` may be what was intended). Left unchanged
            # because callers may currently rely on full_text — confirm first.
            srcs = repo_query(
                f"""
                select * OMIT full_text from (
                    select
                    <- source as source
                    from reference
                    where out={self.id}
                    fetch source
                )
                order by source.updated desc
                """
            )
            return [Source(**src["source"][0]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    @property
    def notes(self) -> List["Note"]:
        """Return the notes linked to this notebook, most recently updated first.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            # Rows have the shape {note: [...]} and no top-level ``updated``,
            # so sort on the nested note's timestamp, as ``sources`` and
            # ``chat_sessions`` do.
            # NOTE(review): as in ``sources``, the top-level ``OMIT content``
            # likely omits nothing; left unchanged since callers may read
            # note content from these objects — confirm before changing.
            srcs = repo_query(
                f"""
                select * OMIT content from (
                    select
                    <- note as note
                    from artifact
                    where out={self.id}
                    fetch note
                )
                order by note.updated desc
                """
            )
            return [Note(**src["note"][0]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    @property
    def chat_sessions(self) -> List["ChatSession"]:
        """Return the chat sessions linked to this notebook, most recently updated first.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            srcs = repo_query(
                f"""
                select * from (
                    select
                    <- chat_session as chat_session
                    from refers_to
                    where out={self.id}
                    fetch chat_session
                )
                order by chat_session.updated desc
                """
            )
            return (
                [ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
            )
        except Exception as e:
            logger.error(
                f"Error fetching chat sessions for notebook {self.id}: {str(e)}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e) from e
class Asset(BaseModel):
    """Location of a Source's underlying material.

    Both fields are optional: an asset may carry a file path, a URL, both,
    or neither.
    """

    file_path: Optional[str] = Field(default=None)
    url: Optional[str] = Field(default=None)
class SourceEmbedding(ObjectModel):
    """One stored chunk of a Source's text (rows are created by ``Source.vectorize``)."""

    table_name: ClassVar[str] = "source_embedding"
    content: str

    @property
    def source(self) -> "Source":
        """Load the Source record this embedding row points to.

        Raises:
            DatabaseOperationError: wrapping any failure, including an empty result.
        """
        try:
            rows = repo_query(
                f"""
                select source.* from {self.id} fetch source
                """
            )
            first_row = rows[0]
            return Source(**first_row["source"])
        except Exception as e:
            logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)
class SourceInsight(ObjectModel):
    """A typed insight attached to a Source (rows are created by ``Source.add_insight``)."""

    table_name: ClassVar[str] = "source_insight"
    insight_type: str
    content: str

    @property
    def source(self) -> "Source":
        """Load the Source record this insight is attached to.

        Raises:
            DatabaseOperationError: wrapping any failure, including an empty result.
        """
        try:
            fetched = repo_query(
                f"""
                select source.* from {self.id} fetch source
                """
            )
            source_fields = fetched[0]["source"]
            return Source(**source_fields)
        except Exception as e:
            logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)
class Source(ObjectModel):
    """A document that can be attached to notebooks.

    A source is optionally backed by an ``Asset``. Its ``full_text`` can be
    chunked into ``source_embedding`` rows, and it can be annotated with
    ``source_insight`` rows.
    """

    table_name: ClassVar[str] = "source"
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    full_text: Optional[str] = None

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a context dict describing this source.

        Both sizes include ``id``, ``title`` and ``insights``, where each
        insight is a plain dict. ``"long"`` also includes ``full_text``.
        """
        # Serialize insights in both modes so callers always get plain dicts.
        # Previously "short" returned SourceInsight model instances while
        # "long" returned dicts.
        insights = [insight.model_dump() for insight in self.insights]
        if context_size == "long":
            return dict(
                id=self.id,
                title=self.title,
                insights=insights,
                full_text=self.full_text,
            )
        return dict(id=self.id, title=self.title, insights=insights)

    @property
    def embedded_chunks(self) -> int:
        """Return the number of source_embedding rows that reference this source.

        Raises:
            DatabaseOperationError: if the count query fails.
        """
        try:
            result = repo_query(
                f"""
                select count() as chunks from source_embedding where source={self.id} GROUP ALL
                """
            )
            # An empty (or missing) result means no chunks matched.
            if not result:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error counting chunks for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(
                f"Failed to count chunks for source: {str(e)}"
            ) from e

    @property
    def insights(self) -> List[SourceInsight]:
        """Return every source_insight row that references this source.

        Raises:
            DatabaseOperationError: if the query or model construction fails.
        """
        try:
            result = repo_query(
                f"""
                SELECT * FROM source_insight WHERE source={self.id}
                """
            )
            return [SourceInsight(**insight) for insight in result or []]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source") from e

    def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this source to a notebook through a ``reference`` edge.

        Raises:
            InvalidInputError: if ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return self.relate("reference", notebook_id)

    def vectorize(self) -> None:
        """Split ``full_text`` into chunks and store each with its embedding.

        Each chunk becomes one source_embedding row holding this source's id,
        the chunk's position (``order``), the cleaned chunk text and the
        chunk's embedding. Does nothing when ``full_text`` is empty.

        Raises:
            DatabaseOperationError: if splitting, embedding or storing fails.
        """
        embedding_model = model_manager.embedding_model
        try:
            if not self.full_text:
                return
            chunks = split_text(self.full_text)
            logger.debug(f"Split into {len(chunks)} chunks")
            # future: we can increase the batch size after surreal launches their new SDK
            for i, chunk in enumerate(chunks):
                # Pass the embedding as a query variable (as vector_search does)
                # rather than interpolating a Python list repr into SurrealQL.
                repo_query(
                    f"""
                    CREATE source_embedding CONTENT {{
                        "source": {self.id},
                        "order": {i},
                        "content": $content,
                        "embedding": $embedding,
                    }};""",
                    {
                        "content": surreal_clean(chunk),
                        "embedding": embedding_model.embed(chunk),
                    },
                )
        except Exception as e:
            logger.error(f"Error vectorizing source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    def add_insight(self, insight_type: str, content: str) -> Any:
        """Create a source_insight row, with its embedding, linked to this source.

        Args:
            insight_type: Label for the insight. Stored verbatim.
            content: Insight text. It is embedded as given and stored after
                ``surreal_clean``.

        Returns:
            Whatever ``repo_query`` returns for the CREATE statement.

        Raises:
            InvalidInputError: if ``insight_type`` or ``content`` is empty.
            DatabaseOperationError: if embedding or storing fails.
        """
        embedding_model = model_manager.embedding_model
        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")
        try:
            embedding = embedding_model.embed(content)
            # insight_type is bound as a variable. Quoting it into the query
            # text broke on (or was injectable through) values containing quotes.
            return repo_query(
                f"""
                CREATE source_insight CONTENT {{
                    "source": {self.id},
                    "insight_type": $insight_type,
                    "content": $content,
                    "embedding": $embedding,
                }};""",
                {
                    "insight_type": insight_type,
                    "content": surreal_clean(content),
                    "embedding": embedding,
                },
            )
        except Exception as e:
            logger.error(f"Error adding insight to source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e
class Note(ObjectModel):
    """A note (human- or AI-authored) that can be linked to notebooks."""

    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        """Allow missing content (None) but reject blank or whitespace-only text."""
        is_blank = v is not None and v.strip() == ""
        if is_blank:
            raise InvalidInputError("Note content cannot be empty")
        return v

    def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this note to a notebook through an ``artifact`` edge.

        Raises:
            InvalidInputError: if ``notebook_id`` is empty.
        """
        if notebook_id:
            return self.relate("artifact", notebook_id)
        raise InvalidInputError("Notebook ID must be provided")

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Return id, title and content.

        In ``"long"`` mode the content is returned in full. In any other mode
        only its first 100 characters are returned, or None when content is
        empty.
        """
        if context_size != "long":
            preview = self.content[:100] if self.content else None
            return dict(id=self.id, title=self.title, content=preview)
        return dict(id=self.id, title=self.title, content=self.content)

    def needs_embedding(self) -> bool:
        """Notes always request an embedding."""
        return True

    def get_embedding_content(self) -> Optional[str]:
        """The text to embed for a note is its content."""
        return self.content
class ChatSession(ObjectModel):
    """A chat conversation that can be linked to a notebook."""

    table_name: ClassVar[str] = "chat_session"
    title: Optional[str] = None

    def relate_to_notebook(self, notebook_id: str) -> Any:
        """Link this session to a notebook through a ``refers_to`` edge.

        Raises:
            InvalidInputError: if ``notebook_id`` is empty.
        """
        if notebook_id:
            return self.relate("refers_to", notebook_id)
        raise InvalidInputError("Notebook ID must be provided")
def text_search(keyword: str, results: int, source: bool = True, note: bool = True):
    """Run the database's ``fn::text_search`` function for one keyword.

    Args:
        keyword: Search term. Must contain at least one non-whitespace character.
        results: Passed through as ``$results``. Presumably a result limit;
            confirm against the fn::text_search definition.
        source: Passed through as ``$source``. Presumably toggles searching sources.
        note: Passed through as ``$note``. Presumably toggles searching notes.

    Returns:
        The rows returned by ``repo_query``.

    Raises:
        InvalidInputError: if ``keyword`` is empty or whitespace-only.
        DatabaseOperationError: if the query fails.
    """
    if not keyword or not keyword.strip():
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        # Use a separate name so the ``results`` parameter is not rebound.
        matches = repo_query(
            """
            SELECT * FROM fn::text_search($keyword, $results, $source, $note);
            """,
            {"keyword": keyword, "results": results, "source": source, "note": note},
        )
    except Exception as e:
        logger.error(f"Error performing text search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e) from e
    return matches
def vector_search(
    keyword: str,
    results: int,
    source: bool = True,
    note: bool = True,
    min_similarity: float = 0.15,
):
    """Embed ``keyword`` and run the database's ``fn::vector_search`` function.

    Args:
        keyword: Text to embed and search for. Must contain at least one
            non-whitespace character.
        results: Passed through as ``$results``. Presumably a result limit;
            confirm against the fn::vector_search definition.
        source: Passed through as ``$source``. Presumably toggles searching sources.
        note: Passed through as ``$note``. Presumably toggles searching notes.
        min_similarity: Final argument to fn::vector_search. The default 0.15
            matches the value that was previously hard-coded. Presumably a
            minimum similarity threshold.

    Returns:
        The rows returned by ``repo_query``.

    Raises:
        InvalidInputError: if ``keyword`` is empty or whitespace-only.
        DatabaseOperationError: if embedding or the query fails.
    """
    if not keyword or not keyword.strip():
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        embedding_model = model_manager.embedding_model
        embed = embedding_model.embed(keyword)
        # Use a separate name so the ``results`` parameter is not rebound.
        matches = repo_query(
            """
            SELECT * FROM fn::vector_search($embed, $results, $source, $note, $min_similarity);
            """,
            {
                "embed": embed,
                "results": results,
                "source": source,
                "note": note,
                "min_similarity": min_similarity,
            },
        )
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e) from e
    return matches
def _cap_chunks_per_parent(
    items: Optional[List[Dict]], score_key: str, max_chunks_per_doc: int
) -> List[Dict]:
    """Rank items by ``score_key`` (highest first) and keep at most
    ``max_chunks_per_doc`` items per ``parent_id``. At least one item per
    parent is always kept."""
    ranked = sorted(items or [], key=lambda x: x.get(score_key, 0), reverse=True)
    # The original logic always admitted a parent's first hit, even when
    # max_chunks_per_doc <= 0, so the effective limit is never below 1.
    per_parent_limit = max(1, max_chunks_per_doc)
    kept_per_parent: Dict[Any, int] = {}
    kept: List[Dict] = []
    for item in ranked:
        parent_id = item["parent_id"]
        count = kept_per_parent.get(parent_id, 0)
        if count < per_parent_limit:
            kept_per_parent[parent_id] = count + 1
            kept.append(item)
    return kept


def _merge_query_results(
    per_query: Dict[str, List[Dict]],
    score_key: str,
    limit: int,
    min_results_per_query: int,
) -> List[Dict]:
    """Merge per-query hit lists into one list with no duplicate ids.

    Every query's top ``min_results_per_query`` hits are included first
    (which may exceed ``limit``). Leftover hits from all queries, ranked by
    ``score_key``, then fill whatever slots remain up to ``limit``.
    """
    merged: List[Dict] = []
    seen_ids = set()
    remaining_slots = limit

    # First pass: guarantee each query's top hits. Skip ids already taken
    # from another query so the same hit is not returned twice.
    for query_results in per_query.values():
        for item in query_results[:min_results_per_query]:
            if item["id"] not in seen_ids:
                seen_ids.add(item["id"])
                merged.append(item)
                remaining_slots -= 1

    # Second pass: best leftover hits across all queries.
    leftovers = sorted(
        (
            item
            for query_results in per_query.values()
            for item in query_results[min_results_per_query:]
        ),
        key=lambda x: x.get(score_key, 0),
        reverse=True,
    )
    for item in leftovers:
        if remaining_slots <= 0:
            break
        if item["id"] not in seen_ids:
            seen_ids.add(item["id"])
            merged.append(item)
            remaining_slots -= 1
    return merged


def hybrid_search(
    keyword_search: List[str],
    embed_search: List[str],
    results: int = 50,
    source: bool = True,
    note: bool = True,
    max_chunks_per_doc: int = 3,
    min_results_per_query: int = 3,
) -> Dict[str, List[Dict]]:
    """Run keyword and vector searches for several terms and merge the hits.

    Each term is searched on its own: keyword terms via ``text_search`` and
    ranked by ``relevance``, embedding terms via ``vector_search`` and ranked
    by ``similarity``. Each term's hits are capped at ``max_chunks_per_doc``
    per ``parent_id``. The per-term lists are then merged separately for each
    search kind:

    - the top ``min_results_per_query`` hits of every term are always
      included, even if that exceeds ``results``;
    - the best remaining hits fill the rest of the ``results`` slots;
    - duplicate hit ids are dropped.

    If the search for one term fails, the failure is logged as a warning and
    that term is skipped.

    Returns:
        ``{"keyword_results": [...], "vector_results": [...]}``

    Raises:
        InvalidInputError: if both term lists are empty.
    """
    if not keyword_search and not embed_search:
        raise InvalidInputError("At least one search term required")

    keyword_per_query: Dict[str, List[Dict]] = {}
    for keyword in keyword_search or []:
        try:
            found = text_search(keyword, results, source, note)
            keyword_per_query[keyword] = _cap_chunks_per_parent(
                found, "relevance", max_chunks_per_doc
            )
        except Exception as e:
            logger.warning(f"Error in keyword search for term '{keyword}': {str(e)}")

    vector_per_query: Dict[str, List[Dict]] = {}
    for embed in embed_search or []:
        try:
            found = vector_search(embed, results, source, note)
            vector_per_query[embed] = _cap_chunks_per_parent(
                found, "similarity", max_chunks_per_doc
            )
        except Exception as e:
            logger.warning(f"Error in vector search for term '{embed}': {str(e)}")

    return {
        "keyword_results": _merge_query_results(
            keyword_per_query, "relevance", results, min_results_per_query
        ),
        "vector_results": _merge_query_results(
            vector_per_query, "similarity", results, min_results_per_query
        ),
    }