mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-28 11:30:00 +00:00
438 lines
18 KiB
Python
438 lines
18 KiB
Python
"""
|
|
Connection testing for AI providers.
|
|
|
|
This module provides functionality to test if a provider's API key is valid
|
|
by making minimal API calls to each provider, and to test individual model
|
|
configurations end-to-end.
|
|
"""
|
|
import io
|
|
import os
|
|
import struct
|
|
from typing import List, Optional, Tuple
|
|
|
|
import httpx
|
|
from esperanto.factory import AIFactory
|
|
from loguru import logger
|
|
|
|
from open_notebook.domain.credential import Credential
|
|
|
|
# Cheapest/smallest model used to probe each provider during a connection test.
# Each value is a (model_name, model_type) pair. A model_name of None means the
# provider has no fixed test model; in this module "ollama" and
# "openai_compatible" (like "azure") are probed by listing models over HTTP
# before this table is consulted.
TEST_MODELS = dict(
    openai=("gpt-3.5-turbo", "language"),
    anthropic=("claude-3-haiku-20240307", "language"),
    google=("gemini-2.0-flash", "language"),
    groq=("llama-3.1-8b-instant", "language"),
    mistral=("mistral-small-latest", "language"),
    deepseek=("deepseek-chat", "language"),
    xai=("grok-beta", "language"),
    openrouter=("openai/gpt-3.5-turbo", "language"),
    voyage=("voyage-3-lite", "embedding"),
    elevenlabs=("eleven_multilingual_v2", "text_to_speech"),
    ollama=(None, "language"),  # no fixed model; server is probed via /api/tags
    # Providers that need configuration beyond a plain API key
    vertex=("gemini-2.0-flash", "language"),  # Google Vertex AI
    azure=("gpt-35-turbo", "language"),  # Azure OpenAI deployment name
    openai_compatible=(None, "language"),  # no fixed model; server is probed via /models
)
|
|
|
|
|
|
async def _test_azure_connection(
    endpoint: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Check Azure OpenAI credentials by listing the models they can access.

    Azure invokes models through per-user deployment names, so there is no
    universal model to call. A successful ``GET {endpoint}/openai/models``
    request is used as proof that the endpoint and key work.

    Args:
        endpoint: Azure resource URL; falls back to ``AZURE_OPENAI_ENDPOINT``.
        api_key: Azure API key; falls back to ``AZURE_OPENAI_API_KEY``.
        api_version: ``api-version`` query value; falls back to
            ``AZURE_OPENAI_API_VERSION``, then to ``"2024-10-21"``.

    Returns:
        Tuple of (success, message). Missing configuration and network/HTTP
        failures are reported as ``(False, reason)`` rather than raised.
    """
    resolved_endpoint = endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
    resolved_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
    resolved_version = api_version or os.environ.get(
        "AZURE_OPENAI_API_VERSION", "2024-10-21"
    )

    # Validate configuration before touching the network.
    if not resolved_endpoint:
        return False, "No Azure endpoint configured"
    if not resolved_key:
        return False, "No Azure API key configured"

    # A trailing slash on the endpoint would otherwise yield "//openai/models".
    resolved_endpoint = resolved_endpoint.rstrip("/")

    status_messages = {
        401: "Invalid API key",
        403: "API key lacks required permissions",
    }

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(
                f"{resolved_endpoint}/openai/models?api-version={resolved_version}",
                headers={"api-key": resolved_key},
            )

            status = response.status_code
            if status != 200:
                return False, status_messages.get(
                    status, f"Azure returned status {status}"
                )

            listed = response.json().get("data", [])
            total = len(listed)
            if total == 0:
                return True, "Connected successfully (no models found)"

            # Show at most three model ids, then summarize the remainder.
            summary = ", ".join(entry.get("id", "unknown") for entry in listed[:3])
            if total > 3:
                summary += f" (+{total - 3} more)"
            return True, f"Connected. {total} models: {summary}"

    except httpx.ConnectError:
        return False, "Cannot connect to Azure endpoint. Check the URL."
    except httpx.TimeoutException:
        return False, "Connection timed out. Check the endpoint URL."
    except Exception as e:
        return False, f"Connection error: {str(e)[:100]}"
|
|
|
|
|
|
async def _test_ollama_connection(base_url: str) -> Tuple[bool, str]:
    """
    Test Ollama server connectivity by listing models from ``/api/tags``.

    Args:
        base_url: Root URL of the Ollama server (e.g. ``http://localhost:11434``).
            A trailing slash is tolerated.

    Returns:
        Tuple of (success, message). On HTTP 200 the message names up to three
        models. Network and HTTP failures are reported as ``(False, reason)``
        rather than raised.
    """
    try:
        # Strip a trailing slash so "http://host:11434/" does not turn into
        # the double-slash path "//api/tags" (a different route on the server).
        tags_url = f"{base_url.rstrip('/')}/api/tags"

        async with httpx.AsyncClient(timeout=10.0) as client:
            # /api/tags is the standard Ollama model-listing endpoint
            response = await client.get(tags_url)

            if response.status_code == 200:
                data = response.json()
                models = data.get("models", [])
                model_count = len(models)

                if model_count > 0:
                    model_names = [m.get("name", "unknown") for m in models[:3]]
                    model_list = ", ".join(model_names)
                    if model_count > 3:
                        model_list += f" (+{model_count - 3} more)"
                    return True, f"Connected. {model_count} models available: {model_list}"
                return True, "Connected successfully (no models listed)"
            elif response.status_code == 401:
                return False, "Invalid API key"
            elif response.status_code == 403:
                return False, "API key lacks required permissions"
            else:
                return False, f"Server returned status {response.status_code}"

    except httpx.ConnectError:
        return False, "Cannot connect to Ollama. Check if Ollama server is running."
    except httpx.TimeoutException:
        return False, "Connection timed out. Check if Ollama server is accessible."
    except Exception as e:
        return False, f"Connection error: {str(e)[:100]}"
|
|
|
|
|
|
async def _test_openai_compatible_connection(base_url: str, api_key: Optional[str] = None) -> Tuple[bool, str]:
    """
    Test OpenAI-compatible server connectivity by listing models from ``/models``.

    Args:
        base_url: API root of the server (e.g. ``http://host:8000/v1``).
            A trailing slash is tolerated.
        api_key: Optional key, sent as ``Authorization: Bearer <key>`` when set.

    Returns:
        Tuple of (success, message). On HTTP 200 the message names up to three
        model ids. Network and HTTP failures are reported as ``(False, reason)``
        rather than raised.
    """
    try:
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # Strip a trailing slash so ".../v1/" does not turn into the
        # double-slash path ".../v1//models", which many servers reject.
        models_url = f"{base_url.rstrip('/')}/models"

        async with httpx.AsyncClient(timeout=10.0) as client:
            # /models is the standard OpenAI-compatible model-listing endpoint
            response = await client.get(models_url, headers=headers)

            if response.status_code == 200:
                data = response.json()
                models = data.get("data", [])
                model_count = len(models)

                if model_count > 0:
                    model_names = [m.get("id", "unknown") for m in models[:3]]
                    model_list = ", ".join(model_names)
                    if model_count > 3:
                        model_list += f" (+{model_count - 3} more)"
                    return True, f"Connected. {model_count} models available: {model_list}"
                return True, "Connected successfully (no models listed)"
            elif response.status_code == 401:
                return False, "Invalid API key"
            elif response.status_code == 403:
                return False, "API key lacks required permissions"
            else:
                return False, f"Server returned status {response.status_code}"

    except httpx.ConnectError:
        return False, "Cannot connect to server. Check the URL is correct."
    except httpx.TimeoutException:
        return False, "Connection timed out. Check if server is accessible."
    except Exception as e:
        return False, f"Connection error: {str(e)[:100]}"
|
|
|
|
def _classify_provider_error(error_msg: str) -> Optional[Tuple[bool, str]]:
    """
    Map a raw provider error message to a user-friendly (success, message).

    Returns None when the message matches no known pattern, so the caller can
    fall back to showing (a truncated copy of) the raw text.
    """
    lower = error_msg.lower()

    if "401" in error_msg or "unauthorized" in lower:
        return False, "Invalid API key"
    if "403" in error_msg or "forbidden" in lower:
        return False, "API key lacks required permissions"
    # Match the rate-limit phrase itself. Checking for "rate" and "limit"
    # anywhere misfired on text like "failed to generate: token limit
    # exceeded" (gene-RATE), wrongly reporting a working connection.
    if any(
        marker in lower
        for marker in ("rate limit", "rate_limit", "rate-limit", "ratelimit", "too many requests")
    ):
        # Rate limit means the key is valid but we hit limits
        return True, "Rate limited - but connection works"
    # Checked before the generic "connection" match so that messages such as
    # "Connection timed out" get the timeout wording.
    if "timeout" in lower or "timed out" in lower:
        return False, "Connection timed out - check network/endpoint"
    if "connection" in lower or "network" in lower:
        return False, "Connection error - check network/endpoint"
    if "not found" in lower and "model" in lower:
        # Model not found but auth worked - this is actually a success for connectivity
        return True, "API key valid (test model not available)"
    return None


async def test_provider_connection(
    provider: str, model_type: str = "language", config_id: Optional[str] = None
) -> Tuple[bool, str]:
    """
    Test if a provider's API key is valid by making a minimal API call.

    ``ollama``, ``openai_compatible`` and ``azure`` are checked by listing
    models over HTTP. Every other provider in ``TEST_MODELS`` is checked by
    creating its test model through ``AIFactory`` and making one small call.

    Args:
        provider: Provider name (openai, anthropic, etc.). Hyphenated aliases
            such as ``openai-compatible`` are normalized to underscores.
        model_type: Type of model to test (language, embedding, etc.).
            Note: This is overridden by TEST_MODELS if provider is in that dict.
        config_id: Optional ID of a stored ``Credential`` to test. When given,
            its api_key / base_url / endpoint / api_version are used in place
            of environment variables.

    Returns:
        Tuple of (success: bool, message: str). Exceptions are not raised;
        they are mapped to a message. Rate-limit and model-not-found errors
        count as success because they show the credentials were accepted.
    """
    try:
        api_key: Optional[str] = None
        base_url: Optional[str] = None
        endpoint: Optional[str] = None
        api_version: Optional[str] = None

        if config_id:
            # Load specific credential from database
            try:
                cred = await Credential.get(config_id)
                config = cred.to_esperanto_config()
                api_key = config.get("api_key")
                base_url = config.get("base_url")
                endpoint = config.get("endpoint")
                api_version = config.get("api_version")
            except Exception as e:
                logger.debug(f"Could not load credential {config_id}: {e}")
                return False, f"Credential not found: {config_id}"

        # Normalize provider name (handle hyphenated aliases)
        normalized_provider = provider.replace("-", "_")

        # URL-based providers: check connectivity by listing models
        if normalized_provider == "ollama":
            test_base_url = base_url or os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
            return await _test_ollama_connection(test_base_url)

        if normalized_provider == "openai_compatible":
            test_base_url = base_url or os.environ.get("OPENAI_COMPATIBLE_BASE_URL")
            test_api_key = api_key or os.environ.get("OPENAI_COMPATIBLE_API_KEY")
            if not test_base_url:
                return False, "No base URL configured for OpenAI-compatible provider"
            return await _test_openai_compatible_connection(test_base_url, test_api_key)

        if normalized_provider == "azure":
            return await _test_azure_connection(endpoint, api_key, api_version)

        if normalized_provider not in TEST_MODELS:
            return False, f"Unknown provider: {provider}"

        test_model, test_model_type = TEST_MODELS[normalized_provider]

        # The providers without a fixed test model (ollama, openai_compatible)
        # are handled above; this guards against future table entries.
        if test_model is None:
            return False, f"No test model configured for {provider}"

        # Hand the credential's settings straight to the factory. Writing the
        # key into os.environ["<PROVIDER>_API_KEY"] (the previous approach)
        # permanently replaced the process-wide default key with whichever
        # credential happened to be tested last. With no credential this is
        # None, and the factory falls back to its usual environment lookup.
        factory_config = {
            key: value
            for key, value in (("api_key", api_key), ("base_url", base_url))
            if value
        } or None

        if test_model_type == "language":
            model = AIFactory.create_language(
                model_name=test_model, provider=provider, config=factory_config
            )
            # Convert to LangChain and make a minimal call
            lc_model = model.to_langchain()
            await lc_model.ainvoke("Hi")
            return True, "Connection successful"

        if test_model_type == "embedding":
            model = AIFactory.create_embedding(
                model_name=test_model, provider=provider, config=factory_config
            )
            # Embed a single short test string
            await model.aembed(["test"])
            return True, "Connection successful"

        if test_model_type == "text_to_speech":
            # Only build the model; a real synthesis call would be more
            # expensive. Creation alone may not contact the provider's API.
            AIFactory.create_text_to_speech(
                model_name=test_model, provider=provider, config=factory_config
            )
            return True, "Connection successful (key format valid)"

        return False, f"Unsupported model type for testing: {test_model_type}"

    except Exception as e:
        # Some exceptions (e.g. asyncio.TimeoutError) stringify to "". Fall
        # back to the type name so classification and the message have text.
        error_msg = str(e) or type(e).__name__

        classified = _classify_provider_error(error_msg)
        if classified is not None:
            return classified

        logger.debug(f"Test connection error for {provider}: {e}")
        # Truncate long error messages
        truncated = error_msg[:100] + "..." if len(error_msg) > 100 else error_msg
        return False, f"Error: {truncated}"
|
|
|
|
|
|
# Voice passed to agenerate_speech when a text-to-speech model is tested,
# keyed by provider. ElevenLabs is deliberately absent: its API takes a
# voice_id rather than a voice name, so a voice is looked up at test time.
DEFAULT_TEST_VOICES = dict(
    openai="alloy",
    azure="alloy",
    google="Kore",
    vertex="Kore",
    openai_compatible="alloy",
)
|
|
|
|
|
|
def _generate_test_wav() -> io.BytesIO:
    """
    Build a short silent WAV clip in memory for speech-to-text smoke tests.

    The clip is 0.5 s of 16 kHz, 16-bit, mono PCM silence: a 44-byte
    RIFF/WAVE header followed by 16,000 zero bytes. The buffer's ``name`` is
    set to ``"test.wav"`` (presumably so upload clients can infer a filename
    and format; confirm against the STT client).

    Returns:
        An ``io.BytesIO`` positioned at offset 0.
    """
    rate_hz = 16000
    channels = 1
    sample_bits = 16
    frames = rate_hz // 2  # 0.5 seconds
    frame_bytes = channels * sample_bits // 8  # block align
    bytes_per_second = rate_hz * channels * sample_bits // 8
    payload_len = frames * frame_bytes

    # RIFF header, 16-byte PCM "fmt " chunk, then the "data" chunk header;
    # "<" means little-endian with no padding between fields.
    header = struct.pack(
        "<4sI4s" "4sIHHIIHH" "4sI",
        b"RIFF", 36 + payload_len, b"WAVE",
        b"fmt ", 16, 1, channels, rate_hz, bytes_per_second, frame_bytes, sample_bits,
        b"data", payload_len,
    )

    clip = io.BytesIO(header + bytes(payload_len))  # all-zero samples = silence
    clip.name = "test.wav"
    return clip
|
|
|
|
|
|
def _normalize_error_message(error_msg: str) -> Tuple[bool, str]:
    """
    Normalize common error patterns into user-friendly messages.

    Args:
        error_msg: Raw error text, typically ``str(exception)``.

    Returns:
        ``(success, message)``. ``success`` is True only for rate-limit
        errors, which show the credentials were accepted. Unrecognized
        messages are returned unchanged, or as ``"Unknown error"`` when empty.
    """
    lower = error_msg.lower()

    if "401" in error_msg or "unauthorized" in lower:
        return False, "Invalid API key"
    if "403" in error_msg or "forbidden" in lower:
        return False, "API key lacks required permissions"
    # Match the rate-limit phrase itself. Checking for "rate" and "limit"
    # anywhere misfired on text like "failed to generate: token limit
    # exceeded" (gene-RATE), wrongly reporting success.
    if any(
        marker in lower
        for marker in ("rate limit", "rate_limit", "rate-limit", "ratelimit", "too many requests")
    ):
        return True, "Rate limited - but connection works"
    if "not found" in lower and "model" in lower:
        return False, "Model not found on this provider"
    # Checked before the generic "connection" match so that messages such as
    # "Connection timed out" / "read operation timed out" get the timeout wording.
    if "timeout" in lower or "timed out" in lower:
        return False, "Connection timed out - check network/endpoint"
    if "connection" in lower or "network" in lower:
        return False, "Connection error - check network/endpoint"

    # Some exceptions stringify to ""; never surface an empty message.
    return False, error_msg or "Unknown error"
|
|
|
|
|
|
async def test_individual_model(model) -> Tuple[bool, str]:
    """
    Test a specific model configuration end-to-end by making a real API call.

    The esperanto model is obtained via ``ModelManager().get_model(model.id)``
    and exercised according to ``model.type``:

    - ``language``: one short chat completion; the message shows up to the
      first 100 characters of the reply.
    - ``embedding``: embeds one sentence; the message shows the vector length.
    - ``text_to_speech``: synthesizes a short phrase; the message shows the
      audio size when it can be determined.
    - ``speech_to_text``: transcribes 0.5 s of generated silence.

    Args:
        model: A Model instance (from open_notebook.ai.models); its ``id``,
            ``type`` and ``provider`` attributes are read.

    Returns:
        Tuple of (success: bool, message: str). Exceptions are not raised;
        they are mapped to a message by ``_normalize_error_message``.
    """
    from open_notebook.ai.models import ModelManager

    try:
        manager = ModelManager()
        esp_model = await manager.get_model(model.id)

        if esp_model is None:
            return False, "Could not create model instance"

        if model.type == "language":
            response = await esp_model.achat_complete(
                messages=[{"role": "user", "content": "Hi!"}]
            )
            text = response.content[:100] if response.content else "(empty response)"
            return True, f"Response: {text}"

        elif model.type == "embedding":
            result = await esp_model.aembed(["This is a test."])
            if result and len(result) > 0:
                dims = len(result[0])
                return True, f"Embedding dimensions: {dims}"
            return True, "Embedding successful"

        elif model.type == "text_to_speech":
            # Providers missing from DEFAULT_TEST_VOICES (e.g. ElevenLabs, whose
            # API takes a voice_id rather than a name) use the first listed voice.
            voice = DEFAULT_TEST_VOICES.get(model.provider)
            if not voice and hasattr(esp_model, "available_voices"):
                try:
                    voices = esp_model.available_voices
                    if voices:
                        voice = next(iter(voices.keys()))
                except Exception as voice_err:
                    # Best effort: fall through to the default voice below,
                    # but leave a trace instead of swallowing the error.
                    logger.debug(
                        f"Could not list voices for {model.provider}: {voice_err}"
                    )
            if not voice:
                voice = "alloy"  # fallback

            result = await esp_model.agenerate_speech(
                text="Hello from Open Notebook", voice=voice
            )
            if result and hasattr(result, "content"):
                size = len(result.content)
                return True, f"Audio generated: {size} bytes"
            # NOTE(review): esperanto's audio response may expose the bytes as
            # ``audio_data`` rather than ``content`` — confirm; report the size
            # when such a bytes payload is present.
            audio_data = getattr(result, "audio_data", None) if result else None
            if isinstance(audio_data, (bytes, bytearray)):
                return True, f"Audio generated: {len(audio_data)} bytes"
            return True, "Speech generation successful"

        elif model.type == "speech_to_text":
            audio_file = _generate_test_wav()
            result = await esp_model.atranscribe(
                audio_file=audio_file, language="en"
            )
            text = str(result.text) if hasattr(result, "text") else str(result)
            return True, f"Transcription: {text[:100]}"

        else:
            return False, f"Unsupported model type: {model.type}"

    except Exception as e:
        # Some exceptions (e.g. asyncio.TimeoutError) stringify to "". Fall
        # back to the type name so the user still gets a meaningful message.
        error_msg = str(e) or type(e).__name__
        success, normalized = _normalize_error_message(error_msg)
        if not success:
            logger.debug(f"Test individual model error for {model.id}: {e}")
        return success, normalized
|