diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py
index 0c0e517..e403ace 100644
--- a/open_notebook/graphs/chat.py
+++ b/open_notebook/graphs/chat.py
@@ -11,7 +11,7 @@ from typing_extensions import TypedDict
 
 from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
 from open_notebook.domain.notebook import Notebook
-from open_notebook.graphs.utils import provision_model
+from open_notebook.graphs.utils import provision_langchain_model
 from open_notebook.prompter import Prompter
 
 
@@ -25,7 +25,7 @@ class ThreadState(TypedDict):
 
 def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
     system_prompt = Prompter(prompt_template="chat").render(data=state)
     payload = [system_prompt] + state.get("messages", [])
-    model = provision_model(str(payload), config, "chat")
+    model = provision_langchain_model(str(payload), config, "chat")
     ai_message = model.invoke(payload)
     return {"messages": ai_message}
diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py
index 3f84f3b..a3874e8 100644
--- a/open_notebook/graphs/utils.py
+++ b/open_notebook/graphs/utils.py
@@ -1,13 +1,14 @@
-from langchain.output_parsers import OutputFixingParser
-from langchain_core.messages import AIMessage
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import BaseMessage
 from loguru import logger
 
 from open_notebook.domain.models import model_manager
+from open_notebook.models.llms import LanguageModel
 from open_notebook.prompter import Prompter
 from open_notebook.utils import token_count
 
 
-def provision_model(content, config, default_type):
+def provision_langchain_model(content, config, default_type) -> BaseChatModel:
     """
     Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
     If context > 105_000, returns the large_context_model
@@ -20,13 +21,14 @@ def provision_model(content, config, default_type):
         logger.debug(
             f"Using large context model because the content has {tokens} tokens"
         )
-        return model_manager.get_default_model("large_context").to_langchain()
+        model = model_manager.get_default_model("large_context")
     elif config.get("configurable", {}).get("model_id"):
-        return model_manager.get_model(
-            config.get("configurable", {}).get("model_id")
-        ).to_langchain()
+        model = model_manager.get_model(config.get("configurable", {}).get("model_id"))
     else:
-        return model_manager.get_default_model(default_type).to_langchain()
+        model = model_manager.get_default_model(default_type)
+
+    assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}"
+    return model.to_langchain()
 
 
 # todo: turn into a graph
@@ -36,23 +38,12 @@ def run_pattern(
     messages=[],
     state: dict = {},
     parser=None,
-    output_fixing_model_id=None,
-) -> AIMessage:
+) -> BaseMessage:
     system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
         data=state
     )
 
     payload = [system_prompt] + messages
-    chain = provision_model(str(payload), config, "transformation")
-
-    if parser:
-        chain = chain | parser
-
-    if output_fixing_model_id and parser:
-        output_fix_model = model_manager.get_model(output_fixing_model_id)
-        chain = chain | OutputFixingParser.from_llm(
-            parser=parser,
-            llm=output_fix_model,
-        )
+    chain = provision_langchain_model(str(payload), config, "transformation")
 
     response = chain.invoke(payload)
diff --git a/open_notebook/utils.py b/open_notebook/utils.py
index 340762e..86479e2 100644
--- a/open_notebook/utils.py
+++ b/open_notebook/utils.py
@@ -28,7 +28,7 @@ def split_text(txt: str, chunk=1000, overlap=0, separator=" "):
     return text_splitter.split_text(txt)
 
 
-def token_count(input_string):
+def token_count(input_string) -> int:
     """
     Count the number of tokens in the input string using the 'o200k_base' encoding.
 
@@ -46,7 +46,7 @@ def token_count(input_string):
     return token_count
 
 
-def token_cost(token_count, cost_per_million=0.150):
+def token_cost(token_count, cost_per_million=0.150) -> float:
     """
     Calculate the cost of tokens based on the token count and cost per million tokens.
 
@@ -60,11 +60,11 @@ def token_cost(token_count, cost_per_million=0.150):
     return cost_per_million * (token_count / 1_000_000)
 
 
-def remove_non_ascii(text):
+def remove_non_ascii(text) -> str:
     return re.sub(r"[^\x00-\x7F]+", "", text)
 
 
-def remove_non_printable(text):
+def remove_non_printable(text) -> str:
     # Remove control characters, except newlines and tabs
     text = "".join(
         char for char in text if unicodedata.category(char)[0] != "C" or char in "\n\t"
     )
@@ -74,7 +74,7 @@ def remove_non_printable(text):
     # Remove any remaining non-printable characters
     return re.sub(r"[^\w\s.,!?\-\n\t]", "", text, flags=re.UNICODE)
 
 
-def surreal_clean(text):
+def surreal_clean(text) -> str:
     """
     Clean the input text by removing non-ASCII and non-printable characters,
     and adjusting colon placement for SurrealDB compatibility.
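A minimal usage sketch of the renamed helper, based only on what this diff shows; the prompt strings and model_id below are illustrative, not values from the codebase:

from langchain_core.runnables import RunnableConfig

from open_notebook.graphs.utils import provision_langchain_model

# Illustrative payload; in chat.py this is the rendered system prompt plus the thread messages.
payload = ["You are a helpful research assistant.", "Summarize my notebook sources."]

# A "model_id" under "configurable" overrides the default model for the requested type,
# except that payloads over 105_000 tokens are routed to the "large_context" default first.
config: RunnableConfig = {"configurable": {"model_id": "example-model-id"}}

model = provision_langchain_model(str(payload), config, "chat")  # returns a BaseChatModel
ai_message = model.invoke(payload)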