mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-30 12:30:01 +00:00
rename doc_query tool
This commit is contained in:
parent
e0c3fe26c8
commit
809ecb45e1
2 changed files with 18 additions and 14 deletions
|
|
@ -3,16 +3,16 @@ import os
|
|||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from loguru import logger
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.domain import Note, Notebook, Source
|
||||
from open_notebook.model_configs import get_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
class AskState(TypedDict):
|
||||
class DocQueryState(TypedDict):
|
||||
doc_id: str
|
||||
doc_content: str
|
||||
question: str
|
||||
|
|
@ -20,11 +20,13 @@ class AskState(TypedDict):
|
|||
notebook: Notebook
|
||||
|
||||
|
||||
def call_model_with_messages(state: AskState, config: RunnableConfig) -> dict:
|
||||
model = ChatOpenAI(
|
||||
model=os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"]),
|
||||
temperature=0,
|
||||
)
|
||||
def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict:
|
||||
if config.get("configurable", {}).get("model_name", None):
|
||||
model_name = config.get("configurable", {}).get("model_name", None)
|
||||
else:
|
||||
model_name = os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"])
|
||||
|
||||
model = get_langchain_model(model_name)
|
||||
system_prompt = Prompter(prompt_template="ask_content").render(data=state)
|
||||
logger.debug(f"System prompt: {system_prompt}")
|
||||
ai_message = model.invoke(system_prompt)
|
||||
|
|
@ -32,7 +34,7 @@ def call_model_with_messages(state: AskState, config: RunnableConfig) -> dict:
|
|||
|
||||
|
||||
# todo: there is probably a better way to do this and avoid repetition
|
||||
def get_content(state: AskState) -> dict:
|
||||
def get_content(state: DocQueryState) -> dict:
|
||||
doc_id = state["doc_id"]
|
||||
if "note:" in doc_id:
|
||||
doc: Note = Note.get(id=doc_id)
|
||||
|
|
@ -42,7 +44,7 @@ def get_content(state: AskState) -> dict:
|
|||
return {"doc_content": doc_content}
|
||||
|
||||
|
||||
agent_state = StateGraph(AskState)
|
||||
agent_state = StateGraph(DocQueryState)
|
||||
agent_state.add_node("get_content", get_content)
|
||||
agent_state.add_node("agent", call_model_with_messages)
|
||||
agent_state.add_edge(START, "get_content")
|
||||
|
|
@ -6,19 +6,21 @@ from langchain.tools import tool
|
|||
@tool
|
||||
def get_current_timestamp() -> str:
|
||||
"""
|
||||
name: get_current_timestamp
|
||||
Returns the current timestamp in the format YYYYMMDDHHmmss.
|
||||
"""
|
||||
return datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
|
||||
|
||||
@tool
|
||||
def ask_the_document(doc_id: str, question: str):
|
||||
def doc_query(doc_id: str, question: str):
|
||||
"""
|
||||
Use this tool to ask a question to the document.
|
||||
Another LLM will ready the document and answer the question.
|
||||
Be specific and complete in your query given the LLM that will process it is very capable.
|
||||
name: doc_query
|
||||
Use this tool if you need to investigate into a particular document.
|
||||
Another LLM will read the document and answer the question that you might have.
|
||||
Use this when the user question cannot be answered with the content you have in context.
|
||||
"""
|
||||
from open_notebook.graphs.ask_content import graph
|
||||
from open_notebook.graphs.doc_query import graph
|
||||
|
||||
result = graph.invoke({"doc_id": doc_id, "question": question})
|
||||
return result["answer"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue