diff --git a/open_notebook/graphs/ask.py b/open_notebook/graphs/ask.py index c320642..3297627 100644 --- a/open_notebook/graphs/ask.py +++ b/open_notebook/graphs/ask.py @@ -1,6 +1,7 @@ import operator from typing import Annotated, List +from ai_prompter import Prompter from langchain_core.output_parsers.pydantic import PydanticOutputParser from langchain_core.runnables import ( RunnableConfig, @@ -12,7 +13,6 @@ from typing_extensions import TypedDict from open_notebook.domain.notebook import vector_search from open_notebook.graphs.utils import provision_langchain_model -from open_notebook.prompter import Prompter class SubGraphState(TypedDict): diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 8d6835a..afbf054 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -1,6 +1,7 @@ import sqlite3 from typing import Annotated, Optional +from ai_prompter import Prompter from langchain_core.messages import SystemMessage from langchain_core.runnables import ( RunnableConfig, @@ -13,7 +14,6 @@ from typing_extensions import TypedDict from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE from open_notebook.domain.notebook import Notebook from open_notebook.graphs.utils import provision_langchain_model -from open_notebook.prompter import Prompter class ThreadState(TypedDict): diff --git a/open_notebook/graphs/prompt.py b/open_notebook/graphs/prompt.py index 176576c..574e399 100644 --- a/open_notebook/graphs/prompt.py +++ b/open_notebook/graphs/prompt.py @@ -1,15 +1,13 @@ from typing import Any, Optional +from ai_prompter import Prompter from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.runnables import ( - RunnableConfig, -) +from langchain_core.runnables import RunnableConfig from langgraph.graph import END, START, StateGraph from loguru import logger from typing_extensions import TypedDict from open_notebook.graphs.utils import provision_langchain_model -from open_notebook.prompter import Prompter class PatternChainState(TypedDict): @@ -22,7 +20,7 @@ class PatternChainState(TypedDict): def call_model(state: dict, config: RunnableConfig) -> dict: content = state["input_text"] system_prompt = Prompter( - prompt_text=state["prompt"], parser=state.get("parser") + template_text=state["prompt"], parser=state.get("parser") ).render(data=state) logger.warning(content) payload = [SystemMessage(content=system_prompt)] + [HumanMessage(content=content)] diff --git a/open_notebook/graphs/transformation.py b/open_notebook/graphs/transformation.py index ab5c799..f610945 100644 --- a/open_notebook/graphs/transformation.py +++ b/open_notebook/graphs/transformation.py @@ -1,14 +1,12 @@ +from ai_prompter import Prompter from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.runnables import ( - RunnableConfig, -) +from langchain_core.runnables import RunnableConfig from langgraph.graph import END, START, StateGraph from typing_extensions import TypedDict from open_notebook.domain.notebook import Source from open_notebook.domain.transformation import DefaultPrompts, Transformation from open_notebook.graphs.utils import provision_langchain_model -from open_notebook.prompter import Prompter class TransformationState(TypedDict): @@ -25,14 +23,16 @@ def run_transformation(state: dict, config: RunnableConfig) -> dict: transformation: Transformation = state["transformation"] if not content: content = source.full_text - transformation_prompt_text = transformation.prompt + transformation_template_text = transformation.prompt default_prompts: DefaultPrompts = DefaultPrompts() if default_prompts.transformation_instructions: - transformation_prompt_text = f"{default_prompts.transformation_instructions}\n\n{transformation_prompt_text}" + transformation_template_text = f"{default_prompts.transformation_instructions}\n\n{transformation_template_text}" - transformation_prompt_text = f"{transformation_prompt_text}\n\n# INPUT" + transformation_template_text = f"{transformation_template_text}\n\n# INPUT" - system_prompt = Prompter(prompt_text=transformation_prompt_text).render(data=state) + system_prompt = Prompter(template_text=transformation_template_text).render( + data=state + ) payload = [SystemMessage(content=system_prompt)] + [HumanMessage(content=content)] chain = provision_langchain_model( str(payload), diff --git a/open_notebook/prompter.py b/open_notebook/prompter.py deleted file mode 100644 index 12d6971..0000000 --- a/open_notebook/prompter.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -A prompt management module using Jinja to generate complex prompts with simple templates. -""" - -import os -from dataclasses import dataclass -from datetime import datetime -from typing import Any, Optional, Union - -from jinja2 import Environment, FileSystemLoader, Template - -current_dir = os.path.dirname(os.path.abspath(__file__)) - -project_root = os.path.dirname(current_dir) - -env = Environment( - loader=FileSystemLoader( - os.path.join(project_root, os.environ.get("PROMPT_PATH", "prompts")) - ) -) - - -@dataclass -class Prompter: - """ - A class for managing and rendering prompt templates. - - Attributes: - prompt_template (str, optional): The name of the prompt template file. - prompt_variation (str, optional): The variation of the prompt template. - prompt_text (str, optional): The raw prompt text. - template (Union[str, Template], optional): The Jinja2 template object. - """ - - prompt_template: Optional[str] = None - prompt_variation: Optional[str] = "default" - prompt_text: Optional[str] = None - template: Optional[Union[str, Template]] = None - parser: Optional[Any] = None - - def __init__(self, prompt_template=None, prompt_text=None, parser=None): - """ - Initialize the Prompter with either a template file or raw text. - - Args: - prompt_template (str, optional): The name of the prompt template file. - prompt_text (str, optional): The raw prompt text. - """ - self.prompt_template = prompt_template - self.prompt_text = prompt_text - self.parser = parser - self.setup() - - def setup(self): - """ - Set up the Jinja2 template based on the provided template file or text. - Raises: - ValueError: If neither prompt_template nor prompt_text is provided. - """ - if self.prompt_template: - self.template = env.get_template(f"{self.prompt_template}.jinja") - elif self.prompt_text: - self.template = Template(self.prompt_text) - else: - raise ValueError("Prompter must have a prompt_template or prompt_text") - - assert self.prompt_template or self.prompt_text, "Prompt is required" - - @classmethod - def from_text(cls, text: str): - """ - Create a Prompter instance from raw text, which can contain Jinja code. - - Args: - text (str): The raw prompt text. - - Returns: - Prompter: A new Prompter instance. - """ - return cls(prompt_text=text) - - def render(self, data) -> str: - """ - Render the prompt template with the given data. - - Args: - data (dict): The data to be used in rendering the template. - - Returns: - str: The rendered prompt text. - - Raises: - AssertionError: If the template is not defined or not a Jinja2 Template. - """ - data["current_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - if self.parser: - data["format_instructions"] = self.parser.get_format_instructions() - assert self.template, "Prompter template is not defined" - assert isinstance( - self.template, Template - ), "Prompter template is not a Jinja2 Template" - return self.template.render(data) diff --git a/pyproject.toml b/pyproject.toml index 7c59511..c722717 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "open-notebook" -version = "0.2.0" +version = "0.2.1" description = "An open source implementation of a research assistant, inspired by Google Notebook LM" authors = [ {name = "Luis Novo", email = "lfnovo@gmail.com"} @@ -41,6 +41,7 @@ dependencies = [ "podcastfy", "nest-asyncio>=1.6.0", "content-core>=1.0.0", + "ai-prompter>=0.3", ] [tool.setuptools] diff --git a/uv.lock b/uv.lock index 31a246c..da70186 100644 --- a/uv.lock +++ b/uv.lock @@ -20,16 +20,16 @@ resolution-markers = [ [[package]] name = "ai-prompter" -version = "0.2.3" +version = "0.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2" }, { name = "pip" }, { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/39/ff/cf13c31b88c06e11a1ffeed505601c167293b23d3e2e4e02adac93cc9300/ai_prompter-0.2.3.tar.gz", hash = "sha256:40f55c18f87df250a13f84d0cf7a4e8b31815a01f27666039386d6592849694b", size = 72955 } +sdist = { url = "https://files.pythonhosted.org/packages/88/1a/263b2fb49a485d1b394ead887361cb8855ab28daa20a184cef0d2a0f8f2c/ai_prompter-0.3.0.tar.gz", hash = "sha256:3369555345386c6b9eebb7edbbb96df268977ab2657acb2890c217290bf92569", size = 74091 } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/11/9e3712b8393dbef152258c68617baec343040c3d08b372d77b57e51d8e5d/ai_prompter-0.2.3-py3-none-any.whl", hash = "sha256:e8c0becbb3c8bdff399e372830e2c0a3cc3292e02d67921e2b255871329ee477", size = 7345 }, + { url = "https://files.pythonhosted.org/packages/90/ae/cc493d9d37cd1501e442154aa7265fa05814d0e8519ddf549ebd2f5fcb1b/ai_prompter-0.3.0-py3-none-any.whl", hash = "sha256:b70569bf6a64258ab3453e1ff99a7a4cd1c7709296093dc2a35127230d408e7b", size = 8419 }, ] [[package]] @@ -2701,9 +2701,10 @@ wheels = [ [[package]] name = "open-notebook" -version = "0.2.0" +version = "0.2.1" source = { editable = "." } dependencies = [ + { name = "ai-prompter" }, { name = "content-core" }, { name = "google-generativeai" }, { name = "groq" }, @@ -2752,6 +2753,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ai-prompter", specifier = ">=0.3" }, { name = "content-core", specifier = ">=1.0.0" }, { name = "google-generativeai", specifier = ">=0.8.3" }, { name = "groq", specifier = ">=0.12.0" },