mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-30 04:20:02 +00:00
make model rag work with vector only
This commit is contained in:
parent
e4b8fa8cc7
commit
80353a97c9
3 changed files with 17 additions and 15 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import operator
|
||||
from typing import Annotated, List, Literal
|
||||
from typing import Annotated, List
|
||||
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||
from langchain_core.runnables import (
|
||||
|
|
@ -7,10 +7,11 @@ from langchain_core.runnables import (
|
|||
)
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.types import Send
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.domain.notebook import text_search, vector_search
|
||||
from open_notebook.domain.notebook import vector_search
|
||||
from open_notebook.graphs.utils import provision_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ from open_notebook.prompter import Prompter
|
|||
class SubGraphState(TypedDict):
|
||||
question: str
|
||||
term: str
|
||||
type: Literal["text", "vector"]
|
||||
# type: Literal["text", "vector"]
|
||||
instructions: str
|
||||
results: dict
|
||||
answer: str
|
||||
|
|
@ -26,9 +27,9 @@ class SubGraphState(TypedDict):
|
|||
|
||||
class Search(BaseModel):
|
||||
term: str
|
||||
type: Literal["text", "vector"] = Field(
|
||||
description="The type of search. Use 'text' for keyword search and 'vector' for semantic search. If you are using text, search always for a single word"
|
||||
)
|
||||
# type: Literal["text", "vector"] = Field(
|
||||
# description="The type of search. Use 'text' for keyword search and 'vector' for semantic search. If you are using text, search always for a single word"
|
||||
# )
|
||||
instructions: str = Field(
|
||||
description="Tell the answeting LLM what information you need extracted from this search"
|
||||
)
|
||||
|
|
@ -62,6 +63,7 @@ async def call_model_with_messages(state: ThreadState, config: RunnableConfig) -
|
|||
)
|
||||
# model = model.bind_tools(tools)
|
||||
ai_message = (model | parser).invoke(system_prompt)
|
||||
logger.debug(ai_message)
|
||||
return {"strategy": ai_message}
|
||||
|
||||
|
||||
|
|
@ -73,7 +75,7 @@ async def trigger_queries(state: ThreadState, config: RunnableConfig):
|
|||
"question": state["question"],
|
||||
"instructions": s.instructions,
|
||||
"term": s.term,
|
||||
"type": s.type,
|
||||
# "type": s.type,
|
||||
},
|
||||
)
|
||||
for s in state["strategy"].searches
|
||||
|
|
@ -82,10 +84,10 @@ async def trigger_queries(state: ThreadState, config: RunnableConfig):
|
|||
|
||||
async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
|
||||
payload = state
|
||||
if state["type"] == "text":
|
||||
results = text_search(state["term"], 10, True, True)
|
||||
else:
|
||||
results = vector_search(state["term"], 10, True, True)
|
||||
# if state["type"] == "text":
|
||||
# results = text_search(state["term"], 10, True, True)
|
||||
# else:
|
||||
results = vector_search(state["term"], 10, True, True)
|
||||
if len(results) == 0:
|
||||
return {"answers": []}
|
||||
payload["results"] = results
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue