diff --git a/llm_config.py b/llm_config.py
index b5f99c5..9f92e1a 100644
--- a/llm_config.py
+++ b/llm_config.py
@@ -1,6 +1,6 @@
 # llm_config.py
 
-LLM_TYPE = "ollama"  # Options: 'llama_cpp', 'ollama'
+LLM_TYPE = "anthropic"  # Options: 'llama_cpp', 'ollama', 'openai', 'anthropic'
 
 # LLM settings for llama_cpp
 MODEL_PATH = "/home/james/llama.cpp/models/gemma-2-9b-it-Q6_K.gguf"  # Replace with your llama.cpp models filepath
@@ -31,10 +31,39 @@ LLM_CONFIG_OLLAMA = {
     "stop": ["User:", "\n\n"]
 }
 
+# LLM settings for OpenAI
+LLM_CONFIG_OPENAI = {
+    "llm_type": "openai",
+    "api_key": "",  # Set via environment variable OPENAI_API_KEY
+    "base_url": None,  # Optional: Set to use alternative OpenAI-compatible endpoints
+    "model_name": "gpt-4o",  # Required: Specify the model to use
+    "temperature": 0.7,
+    "top_p": 0.9,
+    "max_tokens": 4096,
+    "stop": ["User:", "\n\n"],
+    "presence_penalty": 0,
+    "frequency_penalty": 0
+}
+
+# LLM settings for Anthropic
+LLM_CONFIG_ANTHROPIC = {
+    "llm_type": "anthropic",
+    "api_key": "",  # Set via environment variable ANTHROPIC_API_KEY
+    "model_name": "claude-3-5-sonnet-latest",  # Required: Specify the model to use
+    "temperature": 0.7,
+    "top_p": 0.9,
+    "max_tokens": 4096,
+    "stop": ["User:", "\n\n"]
+}
+
 def get_llm_config():
     if LLM_TYPE == "llama_cpp":
         return LLM_CONFIG_LLAMA_CPP
     elif LLM_TYPE == "ollama":
         return LLM_CONFIG_OLLAMA
+    elif LLM_TYPE == "openai":
+        return LLM_CONFIG_OPENAI
+    elif LLM_TYPE == "anthropic":
+        return LLM_CONFIG_ANTHROPIC
     else:
         raise ValueError(f"Invalid LLM_TYPE: {LLM_TYPE}")
diff --git a/llm_wrapper.py b/llm_wrapper.py
index f8b97c0..ac0cdb4 100644
--- a/llm_wrapper.py
+++ b/llm_wrapper.py
@@ -1,17 +1,25 @@
+import os
 from llama_cpp import Llama
 import requests
 import json
 from llm_config import get_llm_config
+from openai import OpenAI
+from anthropic import Anthropic
 
 class LLMWrapper:
     def __init__(self):
         self.llm_config = get_llm_config()
         self.llm_type = self.llm_config.get('llm_type', 'llama_cpp')
+
         if self.llm_type == 'llama_cpp':
             self.llm = self._initialize_llama_cpp()
         elif self.llm_type == 'ollama':
             self.base_url = self.llm_config.get('base_url', 'http://localhost:11434')
             self.model_name = self.llm_config.get('model_name', 'your_model_name')
+        elif self.llm_type == 'openai':
+            self._initialize_openai()
+        elif self.llm_type == 'anthropic':
+            self._initialize_anthropic()
         else:
             raise ValueError(f"Unsupported LLM type: {self.llm_type}")
 
@@ -24,6 +32,36 @@ class LLMWrapper:
             verbose=False
         )
 
+    def _initialize_openai(self):
+        api_key = os.getenv('OPENAI_API_KEY') or self.llm_config.get('api_key')
+        if not api_key:
+            raise ValueError("OpenAI API key not found. Set OPENAI_API_KEY environment variable.")
+
+        base_url = self.llm_config.get('base_url')
+        model_name = self.llm_config.get('model_name')
+
+        if not model_name:
+            raise ValueError("OpenAI model name not specified in config")
+
+        client_kwargs = {'api_key': api_key}
+        if base_url:
+            client_kwargs['base_url'] = base_url
+
+        self.client = OpenAI(**client_kwargs)
+        self.model_name = model_name
+
+    def _initialize_anthropic(self):
+        api_key = os.getenv('ANTHROPIC_API_KEY') or self.llm_config.get('api_key')
+        if not api_key:
+            raise ValueError("Anthropic API key not found. Set ANTHROPIC_API_KEY environment variable.")
+
+        model_name = self.llm_config.get('model_name')
+        if not model_name:
+            raise ValueError("Anthropic model name not specified in config")
+
+        self.client = Anthropic(api_key=api_key)
+        self.model_name = model_name
+
     def generate(self, prompt, **kwargs):
         if self.llm_type == 'llama_cpp':
             llama_kwargs = self._prepare_llama_kwargs(kwargs)
@@ -31,6 +69,10 @@ class LLMWrapper:
             return response['choices'][0]['text'].strip()
         elif self.llm_type == 'ollama':
             return self._ollama_generate(prompt, **kwargs)
+        elif self.llm_type == 'openai':
+            return self._openai_generate(prompt, **kwargs)
+        elif self.llm_type == 'anthropic':
+            return self._anthropic_generate(prompt, **kwargs)
         else:
             raise ValueError(f"Unsupported LLM type: {self.llm_type}")
 
@@ -53,6 +95,38 @@ class LLMWrapper:
         text = ''.join(json.loads(line)['response'] for line in response.iter_lines() if line)
         return text.strip()
 
+    def _openai_generate(self, prompt, **kwargs):
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=kwargs.get('temperature', self.llm_config.get('temperature', 0.7)),
+                top_p=kwargs.get('top_p', self.llm_config.get('top_p', 0.9)),
+                max_tokens=kwargs.get('max_tokens', self.llm_config.get('max_tokens', 4096)),
+                stop=kwargs.get('stop', self.llm_config.get('stop', [])),
+                presence_penalty=self.llm_config.get('presence_penalty', 0),
+                frequency_penalty=self.llm_config.get('frequency_penalty', 0)
+            )
+            return response.choices[0].message.content.strip()
+        except Exception as e:
+            raise Exception(f"OpenAI API request failed: {str(e)}")
+
+    def _anthropic_generate(self, prompt, **kwargs):
+        try:
+            response = self.client.messages.create(
+                model=self.model_name,
+                max_tokens=kwargs.get('max_tokens', self.llm_config.get('max_tokens', 4096)),
+                temperature=kwargs.get('temperature', self.llm_config.get('temperature', 0.7)),
+                top_p=kwargs.get('top_p', self.llm_config.get('top_p', 0.9)),
+                messages=[{
+                    "role": "user",
+                    "content": prompt
+                }]
+            )
+            return response.content[0].text.strip()
+        except Exception as e:
+            raise Exception(f"Anthropic API request failed: {str(e)}")
+
     def _cleanup(self):
         """Force terminate any running LLM processes"""
         if self.llm_type == 'ollama':
diff --git a/requirements.txt b/requirements.txt
index 491c599..71b5c6b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,5 @@ keyboard
 curses-windows; sys_platform == 'win32'
 tqdm
 urllib3
+openai>=1.0.0
+anthropic>=0.7.0