Define haiku + prompt engine takes a directory arg (#279)

This commit is contained in:
Kerem Yilmaz 2024-05-08 02:07:18 -07:00 committed by GitHub
parent 42d652f381
commit e5d094493e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 2 deletions

View file

@ -65,6 +65,9 @@ if SettingsManager.get_settings().ENABLE_ANTHROPIC:
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
)
LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True)
)
if SettingsManager.get_settings().ENABLE_BEDROCK:
# Supported through AWS IAM authentication
@ -84,6 +87,14 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
True,
),
)
LLMConfigRegistry.register_config(
"BEDROCK_ANTHROPIC_CLAUDE3_HAIKU",
LLMConfig(
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
["AWS_REGION"],
True,
),
)
if SettingsManager.get_settings().ENABLE_AZURE:
LLMConfigRegistry.register_config(

View file

@ -21,11 +21,14 @@ class PromptEngine:
import glob
import os
from difflib import get_close_matches
from pathlib import Path
from typing import Any, List
import structlog
from jinja2 import Environment, FileSystemLoader
from skyvern.constants import SKYVERN_DIR
LOG = structlog.get_logger()
@ -34,7 +37,7 @@ class PromptEngine:
Class to handle loading and populating Jinja2 templates for prompts.
"""
def __init__(self, model: str):
def __init__(self, model: str, prompts_dir: Path = SKYVERN_DIR / "forge" / "prompts") -> None:
"""
Initialize the PromptEngine with the specified model.
@ -45,7 +48,7 @@ class PromptEngine:
try:
# Get the list of all model directories
models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
models_dir = os.path.abspath(prompts_dir)
model_names = [
os.path.basename(os.path.normpath(d))
for d in glob.glob(os.path.join(models_dir, "*/"))