add model router and improve prompts

This commit is contained in:
LUIS NOVO 2024-10-22 18:24:24 -03:00
parent f96fc580b3
commit 9042b08ae3
13 changed files with 173 additions and 236 deletions

View file

@ -4,12 +4,10 @@ from langchain_core.runnables import (
RunnableConfig,
)
from langgraph.graph import END, START, StateGraph
from loguru import logger
from typing_extensions import TypedDict
from open_notebook.domain import Note, Notebook, Source
from open_notebook.model_configs import get_langchain_model
from open_notebook.prompter import Prompter
from open_notebook.graphs.utils import run_pattern
class DocQueryState(TypedDict):
@ -20,17 +18,11 @@ class DocQueryState(TypedDict):
notebook: Notebook
def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict:
if config.get("configurable", {}).get("model_name", None):
model_name = config.get("configurable", {}).get("model_name", None)
else:
model_name = os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"])
model = get_langchain_model(model_name)
system_prompt = Prompter(prompt_template="doc_query").render(data=state)
logger.debug(f"System prompt: {system_prompt}")
ai_message = model.invoke(system_prompt)
return {"answer": ai_message}
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", os.environ.get("RETRIEVAL_MODEL")
)
return {"answer": run_pattern("doc_query", model_name, state)}
# todo: there is probably a better way to do this and avoid repetition
@ -46,7 +38,7 @@ def get_content(state: DocQueryState) -> dict:
agent_state = StateGraph(DocQueryState)
agent_state.add_node("get_content", get_content)
agent_state.add_node("agent", call_model_with_messages)
agent_state.add_node("agent", call_model)
agent_state.add_edge(START, "get_content")
agent_state.add_edge("get_content", "agent")
agent_state.add_edge("agent", END)