rename doc_query tool

This commit is contained in:
LUIS NOVO 2024-10-22 16:36:21 -03:00
parent e0c3fe26c8
commit 809ecb45e1
2 changed files with 18 additions and 14 deletions

View file

@ -3,16 +3,16 @@ import os
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from loguru import logger
from typing_extensions import TypedDict
from open_notebook.domain import Note, Notebook, Source
from open_notebook.model_configs import get_langchain_model
from open_notebook.prompter import Prompter
class AskState(TypedDict):
class DocQueryState(TypedDict):
doc_id: str
doc_content: str
question: str
@ -20,11 +20,13 @@ class AskState(TypedDict):
notebook: Notebook
def call_model_with_messages(state: AskState, config: RunnableConfig) -> dict:
model = ChatOpenAI(
model=os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"]),
temperature=0,
)
def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict:
if config.get("configurable", {}).get("model_name", None):
model_name = config.get("configurable", {}).get("model_name", None)
else:
model_name = os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"])
model = get_langchain_model(model_name)
system_prompt = Prompter(prompt_template="ask_content").render(data=state)
logger.debug(f"System prompt: {system_prompt}")
ai_message = model.invoke(system_prompt)
@ -32,7 +34,7 @@ def call_model_with_messages(state: AskState, config: RunnableConfig) -> dict:
# todo: there is probably a better way to do this and avoid repetition
def get_content(state: AskState) -> dict:
def get_content(state: DocQueryState) -> dict:
doc_id = state["doc_id"]
if "note:" in doc_id:
doc: Note = Note.get(id=doc_id)
@ -42,7 +44,7 @@ def get_content(state: AskState) -> dict:
return {"doc_content": doc_content}
agent_state = StateGraph(AskState)
agent_state = StateGraph(DocQueryState)
agent_state.add_node("get_content", get_content)
agent_state.add_node("agent", call_model_with_messages)
agent_state.add_edge(START, "get_content")

View file

@ -6,19 +6,21 @@ from langchain.tools import tool
@tool
def get_current_timestamp() -> str:
"""
name: get_current_timestamp
Returns the current timestamp in the format YYYYMMDDHHmmss.
"""
return datetime.now().strftime("%Y%m%d%H%M%S")
@tool
def ask_the_document(doc_id: str, question: str):
def doc_query(doc_id: str, question: str):
"""
Use this tool to ask a question to the document.
Another LLM will ready the document and answer the question.
Be specific and complete in your query given the LLM that will process it is very capable.
name: doc_query
Use this tool if you need to investigate into a particular document.
Another LLM will read the document and answer the question that you might have.
Use this when the user question cannot be answered with the content you have in context.
"""
from open_notebook.graphs.ask_content import graph
from open_notebook.graphs.doc_query import graph
result = graph.invoke({"doc_id": doc_id, "question": question})
return result["answer"]