Define haiku + prompt engine takes a directory arg (#279)

This commit is contained in:
Kerem Yilmaz 2024-05-08 02:07:18 -07:00 committed by GitHub
parent 42d652f381
commit e5d094493e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 2 deletions

View file

@ -65,6 +65,9 @@ if SettingsManager.get_settings().ENABLE_ANTHROPIC:
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True) "ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
) )
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True)
)
if SettingsManager.get_settings().ENABLE_BEDROCK: if SettingsManager.get_settings().ENABLE_BEDROCK:
# Supported through AWS IAM authentication # Supported through AWS IAM authentication
@ -84,6 +87,14 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
True, True,
), ),
) )
LLMConfigRegistry.register_config(
"BEDROCK_ANTHROPIC_CLAUDE3_HAIKU",
LLMConfig(
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
["AWS_REGION"],
True,
),
)
if SettingsManager.get_settings().ENABLE_AZURE: if SettingsManager.get_settings().ENABLE_AZURE:
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(

View file

@ -21,11 +21,14 @@ class PromptEngine:
import glob import glob
import os import os
from difflib import get_close_matches from difflib import get_close_matches
from pathlib import Path
from typing import Any, List from typing import Any, List
import structlog import structlog
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from skyvern.constants import SKYVERN_DIR
LOG = structlog.get_logger() LOG = structlog.get_logger()
@ -34,7 +37,7 @@ class PromptEngine:
Class to handle loading and populating Jinja2 templates for prompts. Class to handle loading and populating Jinja2 templates for prompts.
""" """
def __init__(self, model: str): def __init__(self, model: str, prompts_dir: Path = SKYVERN_DIR / "forge" / "prompts") -> None:
""" """
Initialize the PromptEngine with the specified model. Initialize the PromptEngine with the specified model.
@ -45,7 +48,7 @@ class PromptEngine:
try: try:
# Get the list of all model directories # Get the list of all model directories
models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts")) models_dir = os.path.abspath(prompts_dir)
model_names = [ model_names = [
os.path.basename(os.path.normpath(d)) os.path.basename(os.path.normpath(d))
for d in glob.glob(os.path.join(models_dir, "*/")) for d in glob.glob(os.path.join(models_dir, "*/"))