# Mirror of https://github.com/lfnovo/open-notebook.git
# Synced 2026-04-29 12:00:00 +00:00
# Upstream file: 54 lines, 1.7 KiB, Python
import os
|
|
|
|
from langchain_core.runnables import (
|
|
RunnableConfig,
|
|
)
|
|
from langgraph.graph import END, START, StateGraph
|
|
from loguru import logger
|
|
from typing_extensions import TypedDict
|
|
|
|
from open_notebook.domain import Note, Notebook, Source
|
|
from open_notebook.model_configs import get_langchain_model
|
|
from open_notebook.prompter import Prompter
|
|
|
|
|
|
class DocQueryState(TypedDict):
    """Shared state passed between the nodes of the document-query graph."""

    # Identifier of the document being queried; read by get_content, which
    # looks for a "note:" or "source:" marker inside it.
    doc_id: str
    # Document context text; written by get_content.
    doc_content: str
    # The question to answer about the document.
    question: str
    # Written by call_model_with_messages with the value returned by
    # model.invoke().
    # NOTE(review): declared as str, but the stored value is whatever the
    # model returns -- confirm whether callers expect a message object.
    answer: str
    # Notebook associated with the query; not read directly by the nodes in
    # this file, but the whole state is handed to the prompt template.
    notebook: Notebook
|
|
|
|
|
|
def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict:
    """Answer the question in ``state`` by prompting the selected model.

    The model name is resolved in this order:

    1. ``config["configurable"]["model_name"]`` when present and truthy;
    2. the ``RETRIEVAL_MODEL`` environment variable, when set;
    3. the ``DEFAULT_MODEL`` environment variable.

    Args:
        state: Current graph state; passed whole to the ``ask_content``
            prompt template as its render data.
        config: Runnable configuration; only the optional
            ``configurable.model_name`` entry is read here.

    Returns:
        A partial state update ``{"answer": <value returned by model.invoke>}``.
        The model's return value is stored as-is, without conversion.

    Raises:
        KeyError: If no model name is configured and neither
            ``RETRIEVAL_MODEL`` nor ``DEFAULT_MODEL`` is set.
    """
    model_name = config.get("configurable", {}).get("model_name")
    if not model_name:
        # Consult DEFAULT_MODEL only when it is really needed. Passing
        # os.environ["DEFAULT_MODEL"] as the default argument of .get()
        # evaluates it eagerly and raises KeyError even when
        # RETRIEVAL_MODEL is set.
        model_name = os.environ.get("RETRIEVAL_MODEL")
        if model_name is None:
            model_name = os.environ["DEFAULT_MODEL"]

    model = get_langchain_model(model_name)
    system_prompt = Prompter(prompt_template="ask_content").render(data=state)
    logger.debug(f"System prompt: {system_prompt}")
    ai_message = model.invoke(system_prompt)
    return {"answer": ai_message}
|
|
|
|
|
|
# todo: there is probably a better way to do this and avoid repetition
|
|
def get_content(state: DocQueryState) -> dict:
    """Load the "long" context of the document referenced by ``state``.

    The loader is chosen from the id itself: an id containing ``"note:"`` is
    fetched through ``Note.get``, one containing ``"source:"`` through
    ``Source.get``.

    Args:
        state: Current graph state; only ``state["doc_id"]`` is read.

    Returns:
        A partial state update ``{"doc_content": <context>}``. The context is
        ``None`` when the id matches neither marker or when the lookup
        yields a falsy document.
    """
    doc_id = state["doc_id"]

    # Start from "no document" so an id with neither marker reaches the None
    # fallback below instead of failing on an unbound local variable.
    doc = None
    if "note:" in doc_id:
        doc = Note.get(id=doc_id)
    elif "source:" in doc_id:
        doc = Source.get(id=doc_id)

    doc_content = doc.get_context("long") if doc else None
    return {"doc_content": doc_content}
|
|
|
|
|
|
# The pipeline is declared as data: each node name is paired with the
# callable registered for it, and each edge is a (source, destination) pair.
# The order matches the linear flow START -> get_content -> agent -> END.
_NODES = (
    ("get_content", get_content),
    ("agent", call_model_with_messages),
)
_EDGES = (
    (START, "get_content"),
    ("get_content", "agent"),
    ("agent", END),
)

agent_state = StateGraph(DocQueryState)
for _node_name, _node_action in _NODES:
    agent_state.add_node(_node_name, _node_action)
for _edge_start, _edge_end in _EDGES:
    agent_state.add_edge(_edge_start, _edge_end)

# Compiled graph object exposed at module level.
graph = agent_state.compile()
|