mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-02 18:50:24 +00:00
Define haiku + prompt engine takes a directory arg (#279)
This commit is contained in:
parent
42d652f381
commit
e5d094493e
2 changed files with 16 additions and 2 deletions
|
@ -65,6 +65,9 @@ if SettingsManager.get_settings().ENABLE_ANTHROPIC:
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
"ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
|
"ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
|
||||||
)
|
)
|
||||||
|
LLMConfigRegistry.register_config(
|
||||||
|
"ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True)
|
||||||
|
)
|
||||||
|
|
||||||
if SettingsManager.get_settings().ENABLE_BEDROCK:
|
if SettingsManager.get_settings().ENABLE_BEDROCK:
|
||||||
# Supported through AWS IAM authentication
|
# Supported through AWS IAM authentication
|
||||||
|
@ -84,6 +87,14 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
|
||||||
True,
|
True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
LLMConfigRegistry.register_config(
|
||||||
|
"BEDROCK_ANTHROPIC_CLAUDE3_HAIKU",
|
||||||
|
LLMConfig(
|
||||||
|
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
["AWS_REGION"],
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if SettingsManager.get_settings().ENABLE_AZURE:
|
if SettingsManager.get_settings().ENABLE_AZURE:
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
|
|
|
@ -21,11 +21,14 @@ class PromptEngine:
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from difflib import get_close_matches
|
from difflib import get_close_matches
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from jinja2 import Environment, FileSystemLoader
|
from jinja2 import Environment, FileSystemLoader
|
||||||
|
|
||||||
|
from skyvern.constants import SKYVERN_DIR
|
||||||
|
|
||||||
LOG = structlog.get_logger()
|
LOG = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,7 +37,7 @@ class PromptEngine:
|
||||||
Class to handle loading and populating Jinja2 templates for prompts.
|
Class to handle loading and populating Jinja2 templates for prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: str):
|
def __init__(self, model: str, prompts_dir: Path = SKYVERN_DIR / "forge" / "prompts") -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the PromptEngine with the specified model.
|
Initialize the PromptEngine with the specified model.
|
||||||
|
|
||||||
|
@ -45,7 +48,7 @@ class PromptEngine:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the list of all model directories
|
# Get the list of all model directories
|
||||||
models_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../prompts"))
|
models_dir = os.path.abspath(prompts_dir)
|
||||||
model_names = [
|
model_names = [
|
||||||
os.path.basename(os.path.normpath(d))
|
os.path.basename(os.path.normpath(d))
|
||||||
for d in glob.glob(os.path.join(models_dir, "*/"))
|
for d in glob.glob(os.path.join(models_dir, "*/"))
|
||||||
|
|
Loading…
Add table
Reference in a new issue