improved typing

This commit is contained in:
LUIS NOVO 2024-11-01 22:50:27 -03:00
parent 7dc37a3ac7
commit d9c0c93deb
3 changed files with 19 additions and 28 deletions

View file

@ -11,7 +11,7 @@ from typing_extensions import TypedDict
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain.notebook import Notebook
from open_notebook.graphs.utils import provision_model
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.prompter import Prompter
@ -25,7 +25,7 @@ class ThreadState(TypedDict):
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
    """Run one chat turn against the provisioned chat model.

    The "chat" prompt template is rendered with the current thread state and
    placed in front of any messages already stored in the state. The
    resulting payload is sent to the model in a single ``invoke`` call.

    Args:
        state: Current thread state. Used as the template data and read for
            its optional ``"messages"`` entry (treated as empty when absent).
        config: Runnable configuration forwarded to
            ``provision_langchain_model`` for model selection.

    Returns:
        A dict with a single ``"messages"`` key holding the model's reply.
    """
    system_prompt = Prompter(prompt_template="chat").render(data=state)
    payload = [system_prompt] + state.get("messages", [])
    # Provision the model exactly once, through the renamed helper. The
    # earlier `provision_model(...)` call was a leftover whose result was
    # immediately overwritten.
    # The stringified payload is handed to the helper alongside the config;
    # "chat" is the default model type requested.
    model = provision_langchain_model(str(payload), config, "chat")
    ai_message = model.invoke(payload)
    return {"messages": ai_message}

View file

@ -1,13 +1,14 @@
from langchain.output_parsers import OutputFixingParser
from langchain_core.messages import AIMessage
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from loguru import logger
from open_notebook.domain.models import model_manager
from open_notebook.models.llms import LanguageModel
from open_notebook.prompter import Prompter
from open_notebook.utils import token_count
def provision_model(content, config, default_type):
def provision_langchain_model(content, config, default_type) -> BaseChatModel:
"""
Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
If context > 105_000, returns the large_context_model
@ -20,13 +21,14 @@ def provision_model(content, config, default_type):
logger.debug(
f"Using large context model because the content has {tokens} tokens"
)
return model_manager.get_default_model("large_context").to_langchain()
model = model_manager.get_default_model("large_context")
elif config.get("configurable", {}).get("model_id"):
return model_manager.get_model(
config.get("configurable", {}).get("model_id")
).to_langchain()
model = model_manager.get_model(config.get("configurable", {}).get("model_id"))
else:
return model_manager.get_default_model(default_type).to_langchain()
model = model_manager.get_default_model(default_type)
assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}"
return model.to_langchain()
# todo: turn into a graph
@ -36,23 +38,12 @@ def run_pattern(
messages=[],
state: dict = {},
parser=None,
output_fixing_model_id=None,
) -> AIMessage:
) -> BaseMessage:
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
data=state
)
payload = [system_prompt] + messages
chain = provision_model(str(payload), config, "transformation")
if parser:
chain = chain | parser
if output_fixing_model_id and parser:
output_fix_model = model_manager.get_model(output_fixing_model_id)
chain = chain | OutputFixingParser.from_llm(
parser=parser,
llm=output_fix_model,
)
chain = provision_langchain_model(str(payload), config, "transformation")
response = chain.invoke(payload)