mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-29 12:00:00 +00:00
add model router and improve prompts
This commit is contained in:
parent
f96fc580b3
commit
9042b08ae3
13 changed files with 173 additions and 236 deletions
|
|
@ -4,12 +4,10 @@ from langchain_core.runnables import (
|
|||
RunnableConfig,
|
||||
)
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from loguru import logger
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.domain import Note, Notebook, Source
|
||||
from open_notebook.model_configs import get_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
|
||||
class DocQueryState(TypedDict):
|
||||
|
|
@ -20,17 +18,11 @@ class DocQueryState(TypedDict):
|
|||
notebook: Notebook
|
||||
|
||||
|
||||
def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict:
|
||||
if config.get("configurable", {}).get("model_name", None):
|
||||
model_name = config.get("configurable", {}).get("model_name", None)
|
||||
else:
|
||||
model_name = os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"])
|
||||
|
||||
model = get_langchain_model(model_name)
|
||||
system_prompt = Prompter(prompt_template="doc_query").render(data=state)
|
||||
logger.debug(f"System prompt: {system_prompt}")
|
||||
ai_message = model.invoke(system_prompt)
|
||||
return {"answer": ai_message}
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", os.environ.get("RETRIEVAL_MODEL")
|
||||
)
|
||||
return {"answer": run_pattern("doc_query", model_name, state)}
|
||||
|
||||
|
||||
# todo: there is probably a better way to do this and avoid repetition
|
||||
|
|
@ -46,7 +38,7 @@ def get_content(state: DocQueryState) -> dict:
|
|||
|
||||
agent_state = StateGraph(DocQueryState)
|
||||
agent_state.add_node("get_content", get_content)
|
||||
agent_state.add_node("agent", call_model_with_messages)
|
||||
agent_state.add_node("agent", call_model)
|
||||
agent_state.add_edge(START, "get_content")
|
||||
agent_state.add_edge("get_content", "agent")
|
||||
agent_state.add_edge("agent", END)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue