mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-05 23:37:58 +00:00
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
from typing import Annotated
|
|
|
|
from langchain_core.runnables import (
|
|
RunnableConfig,
|
|
)
|
|
from langgraph.graph import START, StateGraph
|
|
from langgraph.graph.message import add_messages
|
|
from langgraph.prebuilt import ToolNode, tools_condition
|
|
from typing_extensions import TypedDict
|
|
|
|
from open_notebook.graphs.tools import repository_search
|
|
from open_notebook.graphs.utils import provision_langchain_model
|
|
from open_notebook.prompter import Prompter
|
|
|
|
tools = [repository_search]
|
|
tool_node = ToolNode(tools)
|
|
|
|
|
|
class ThreadState(TypedDict):
|
|
messages: Annotated[list, add_messages]
|
|
# notebook: Optional[Notebook]
|
|
# context: Optional[str]
|
|
# context_config: Optional[dict]
|
|
|
|
|
|
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
|
system_prompt = Prompter(prompt_template="rag").render(data=state)
|
|
payload = [system_prompt] + state.get("messages", [])
|
|
model = provision_langchain_model(str(payload), config, "tools", max_tokens=2000)
|
|
model = model.bind_tools(tools)
|
|
ai_message = model.invoke(payload)
|
|
return {"messages": ai_message}
|
|
|
|
|
|
agent_state = StateGraph(ThreadState)
|
|
agent_state.add_node("agent", call_model_with_messages)
|
|
agent_state.add_node("tools", tool_node)
|
|
agent_state.add_edge(START, "agent")
|
|
agent_state.add_conditional_edges(
|
|
"agent",
|
|
tools_condition,
|
|
)
|
|
agent_state.add_edge("tools", "agent")
|
|
graph = agent_state.compile()
|