enable multiple chat sessions

This commit is contained in:
LUIS NOVO 2024-11-04 15:08:14 -03:00
parent 3be1ecae8a
commit 0f2216207b
8 changed files with 276 additions and 133 deletions

View file

@ -1,11 +1,9 @@
import os
from typing import Any, ClassVar, Dict, List, Literal, Optional
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from open_notebook.database.repository import (
repo_create,
repo_query,
)
from open_notebook.domain.base import ObjectModel
@ -68,6 +66,27 @@ class Notebook(ObjectModel):
logger.exception(e)
raise DatabaseOperationError(e)
@property
def chat_sessions(self) -> List["ChatSession"]:
try:
srcs = repo_query(f"""
select * from (
select
<- chat_session as chat_session
from refers_to
where out={self.id}
fetch chat_session
)
order by chat_session.updated desc
""")
return (
[ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
)
except Exception as e:
logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
logger.exception(e)
raise DatabaseOperationError(e)
class Asset(BaseModel):
file_path: Optional[str] = None
@ -99,6 +118,22 @@ class Source(ObjectModel):
else:
return dict(id=self.id, title=self.title, insights=self.insights)
@property
def embedded_chunks(self) -> int:
try:
result = repo_query(
f"""
select count() as chunks from source_embedding where source={self.id} GROUP ALL
"""
)
if len(result) == 0:
return 0
return result[0]["chunks"]
except Exception as e:
logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
logger.exception(e)
raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")
@property
def insights(self) -> List[SourceInsight]:
try:
@ -118,24 +153,6 @@ class Source(ObjectModel):
raise InvalidInputError("Notebook ID must be provided")
return self.relate("reference", notebook_id)
def save_chunks(self, text: str) -> None:
if not text:
raise InvalidInputError("Text cannot be empty")
try:
chunks = split_text(text, chunk=500000, overlap=1000)
logger.debug(f"Split into {len(chunks)} chunks")
for i, chunk in enumerate(chunks):
logger.debug(f"Saving chunk {i}")
data = {"source": self.id, "order": i, "content": surreal_clean(chunk)}
repo_create(
"source_chunk",
data,
)
except Exception as e:
logger.exception(e)
logger.error(f"Error saving chunks for source {self.id}: {str(e)}")
raise DatabaseOperationError(e)
def vectorize(self) -> None:
EMBEDDING_MODEL = model_manager.embedding_model
@ -144,8 +161,6 @@ class Source(ObjectModel):
return
chunks = split_text(
self.full_text,
chunk=int(os.environ.get("EMBEDDING_CHUNK_SIZE", 1000)),
overlap=int(os.environ.get("EMBEDDING_CHUNK_OVERLAP", 1000)),
)
logger.debug(f"Split into {len(chunks)} chunks")
@ -166,26 +181,26 @@ class Source(ObjectModel):
logger.exception(e)
raise DatabaseOperationError(e)
@classmethod
def search(cls, query: str) -> List[Dict[str, Any]]:
if not query:
raise InvalidInputError("Search query cannot be empty")
try:
result = repo_query(
"""
SELECT * omit full_text
FROM source
WHERE string::lowercase(title) CONTAINS $query or title @@ $query
OR string::lowercase(summary) CONTAINS $query or summary @@ $query
OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
""",
{"query": query},
)
return result
except Exception as e:
logger.error(f"Error searching sources: {str(e)}")
logger.exception(e)
raise DatabaseOperationError("Failed to search sources")
# @classmethod
# def search(cls, query: str) -> List[Dict[str, Any]]:
# if not query:
# raise InvalidInputError("Search query cannot be empty")
# try:
# result = repo_query(
# """
# SELECT * omit full_text
# FROM source
# WHERE string::lowercase(title) CONTAINS $query or title @@ $query
# OR string::lowercase(summary) CONTAINS $query or summary @@ $query
# OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
# """,
# {"query": query},
# )
# return result
# except Exception as e:
# logger.error(f"Error searching sources: {str(e)}")
# logger.exception(e)
# raise DatabaseOperationError("Failed to search sources")
def add_insight(self, insight_type: str, content: str) -> Any:
EMBEDDING_MODEL = model_manager.embedding_model
@ -246,6 +261,16 @@ class Note(ObjectModel):
return self.content
class ChatSession(ObjectModel):
table_name: ClassVar[str] = "chat_session"
title: Optional[str] = None
def relate_to_notebook(self, notebook_id: str) -> Any:
if not notebook_id:
raise InvalidInputError("Notebook ID must be provided")
return self.relate("refers_to", notebook_id)
def text_search(keyword: str, results: int, source: bool = True, note: bool = True):
if not keyword:
raise InvalidInputError("Search keyword cannot be empty")
@ -263,18 +288,6 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
raise DatabaseOperationError(e)
# def hybrid_search(
# keyword_search: List[str],
# embed_search: List[str],
# results: int = 50,
# source: bool = True,
# note: bool = True,
# ):
# EMBEDDING_MODEL = model_manager.embedding_model
# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None
# todo: mover o embedding pra ca
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
if not keyword:
raise InvalidInputError("Search keyword cannot be empty")