open-notebook/open_notebook/graphs/rag.py
2024-11-04 09:53:49 -03:00

44 lines
1.4 KiB
Python

from typing import Annotated

from langchain_core.messages import SystemMessage
from langchain_core.runnables import (
    RunnableConfig,
)
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict

from open_notebook.graphs.tools import repository_search
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.prompter import Prompter
# Tools available to the agent. The same list is bound to the chat model
# (so it can request calls) and handed to the ToolNode (which executes them).
tools = [repository_search]

# Graph node that runs whichever of ``tools`` the model asked for.
tool_node = ToolNode(tools=tools)
class ThreadState(TypedDict):
    """State carried between the nodes of the RAG agent graph.

    The whole state is also passed as template data when the "rag" system
    prompt is rendered.
    """

    # Conversation so far. The ``add_messages`` reducer makes node updates
    # merge into this list instead of overwriting it.
    messages: Annotated[list, add_messages]
    # NOTE(review): earlier drafts sketched optional ``notebook``, ``context``
    # and ``context_config`` keys here; add them if the prompt template needs them.
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
    """Run one agent turn: prompt the tool-enabled model with the conversation.

    The "rag" prompt template is rendered against the current graph state and
    sent as a system message ahead of the accumulated conversation. The model
    is provisioned for the "tools" role, bound to the module-level ``tools``,
    and invoked once.

    Args:
        state: Current graph state; its ``messages`` list holds the conversation.
        config: LangGraph runnable config, forwarded to model provisioning.

    Returns:
        A partial state update ``{"messages": ai_message}``. The ``add_messages``
        reducer on ``ThreadState.messages`` merges it into the conversation.
    """
    system_prompt = Prompter(prompt_template="rag").render(data=state)
    # Wrap the rendered prompt in a SystemMessage. LangChain coerces a bare
    # string inside a message list into a HumanMessage, which would present
    # the instructions to the model as if the user had typed them.
    payload = [SystemMessage(content=system_prompt)] + state.get("messages", [])
    # The stringified payload goes to the provisioning helper, presumably so
    # it can size/select the model by prompt length (TODO confirm).
    model = provision_langchain_model(str(payload), config, "tools", max_tokens=2000)
    model = model.bind_tools(tools)
    ai_message = model.invoke(payload)
    return {"messages": ai_message}
# Agent loop:
#   START -> "agent"
#   "agent" -> "tools" when the last AI message requests tool calls,
#              otherwise the run ends (routing decided by tools_condition)
#   "tools" -> "agent" so the model sees the tool results
agent_state = StateGraph(ThreadState)

# Nodes
agent_state.add_node("agent", call_model_with_messages)
agent_state.add_node("tools", tool_node)

# Edges
agent_state.add_edge(START, "agent")
agent_state.add_conditional_edges("agent", tools_condition)
agent_state.add_edge("tools", "agent")

# Compiled, runnable graph exported by this module.
graph = agent_state.compile()