diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py
index e403ace..5e3b4ca 100644
--- a/open_notebook/graphs/chat.py
+++ b/open_notebook/graphs/chat.py
@@ -25,7 +25,7 @@ class ThreadState(TypedDict):
 def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
     system_prompt = Prompter(prompt_template="chat").render(data=state)
     payload = [system_prompt] + state.get("messages", [])
-    model = provision_langchain_model(str(payload), config, "chat")
+    model = provision_langchain_model(str(payload), config, "chat", max_tokens=2000)
     ai_message = model.invoke(payload)
     return {"messages": ai_message}
 
diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py
index a3874e8..07365ea 100644
--- a/open_notebook/graphs/utils.py
+++ b/open_notebook/graphs/utils.py
@@ -8,7 +8,7 @@
 from open_notebook.prompter import Prompter
 from open_notebook.utils import token_count
 
-def provision_langchain_model(content, config, default_type) -> BaseChatModel:
+def provision_langchain_model(content, config, default_type, **kwargs) -> BaseChatModel:
     """
     Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
     If context > 105_000, returns the large_context_model
@@ -21,11 +21,13 @@
         logger.debug(
             f"Using large context model because the content has {tokens} tokens"
         )
-        model = model_manager.get_default_model("large_context")
+        model = model_manager.get_default_model("large_context", **kwargs)
     elif config.get("configurable", {}).get("model_id"):
-        model = model_manager.get_model(config.get("configurable", {}).get("model_id"))
+        model = model_manager.get_model(
+            config.get("configurable", {}).get("model_id"), **kwargs
+        )
     else:
-        model = model_manager.get_default_model(default_type)
+        model = model_manager.get_default_model(default_type, **kwargs)
     assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}"
     return model.to_langchain()
 