open-notebook/open_notebook/domain/credential.py
Luis Novo 0c2522074d fix: narrow exception handling and support migrate_to for broken credentials
- Catch only ValueError (decryption errors) instead of broad Exception
  so NotFoundError and other failures propagate correctly
- Support migrate_to parameter in the fallback delete path so linked
  models can be reassigned instead of always cascade-deleted
- Sanitize decryption_error message to not expose raw exception details
2026-04-14 10:34:32 -03:00

231 lines
8.7 KiB
Python

"""
Credential domain model for storing individual provider credentials.
Each credential is a standalone record in the 'credential' table, replacing
the old ProviderConfig singleton. Credentials store API keys (encrypted at
rest) and provider-specific configuration fields.
Usage:
cred = Credential(
name="Production",
provider="openai",
modalities=["language", "embedding"],
api_key=SecretStr("sk-..."),
)
await cred.save()
"""
from datetime import datetime
from typing import Any, ClassVar, Dict, List, Optional
from loguru import logger
from pydantic import SecretStr
from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.base import ObjectModel
from open_notebook.utils.encryption import decrypt_value, encrypt_value
class Credential(ObjectModel):
    """
    Individual credential record for an AI provider.

    Each record stores authentication and configuration for a single provider
    account. Models link to credentials via the credential field.

    The API key is encrypted with ``encrypt_value`` in the save payload and
    decrypted with ``decrypt_value`` when records are loaded through ``get``,
    ``get_all`` or ``get_by_provider``. ``decryption_error`` is never written
    to the database; ``get_all`` fills it in for rows that could not be loaded.
    """

    table_name: ClassVar[str] = "credential"
    # Fields included in the save payload even when their value is None
    # (see _prepare_save_data), so clearing one of them is sent to the DB.
    nullable_fields: ClassVar[set[str]] = {
        "api_key",
        "base_url",
        "endpoint",
        "api_version",
        "endpoint_llm",
        "endpoint_embedding",
        "endpoint_stt",
        "endpoint_tts",
        "project",
        "location",
        "credentials_path",
    }

    name: str
    provider: str
    modalities: List[str] = []
    # Held decrypted in memory; only the save payload carries the encrypted form.
    api_key: Optional[SecretStr] = None
    # Load-time error message; excluded from the save payload.
    decryption_error: Optional[str] = None
    base_url: Optional[str] = None
    endpoint: Optional[str] = None
    api_version: Optional[str] = None
    endpoint_llm: Optional[str] = None
    endpoint_embedding: Optional[str] = None
    endpoint_stt: Optional[str] = None
    endpoint_tts: Optional[str] = None
    project: Optional[str] = None
    location: Optional[str] = None
    credentials_path: Optional[str] = None

    def to_esperanto_config(self) -> Dict[str, Any]:
        """
        Build config dict for AIFactory.create_*() calls.

        Returns a dict that can be passed as the 'config' parameter to
        Esperanto's AIFactory methods, overriding env var lookup. Only truthy
        values are included; ``api_key`` is emitted as its plain secret value.
        """
        config: Dict[str, Any] = {}
        if self.api_key:
            config["api_key"] = self.api_key.get_secret_value()
        # Non-secret fields are forwarded verbatim, in this fixed order.
        for field in (
            "base_url",
            "endpoint",
            "api_version",
            "endpoint_llm",
            "endpoint_embedding",
            "endpoint_stt",
            "endpoint_tts",
            "project",
            "location",
            "credentials_path",
        ):
            value = getattr(self, field)
            if value:
                config[field] = value
        return config

    @staticmethod
    def _decrypt_api_key(value: Any) -> SecretStr:
        """
        Decrypt a stored api_key and wrap the plaintext in a SecretStr.

        Accepts either the raw encrypted string or a SecretStr wrapping it,
        since Pydantic may auto-wrap DB strings in SecretStr. Any error raised
        by ``decrypt_value`` propagates to the caller.
        """
        raw = value.get_secret_value() if isinstance(value, SecretStr) else value
        return SecretStr(decrypt_value(raw))

    @classmethod
    async def get_by_provider(cls, provider: str) -> List["Credential"]:
        """
        Get all credentials for a provider, oldest first.

        The provider comparison is case-insensitive. Rows that fail to load
        (e.g. decryption or validation errors) are logged and skipped.
        """
        results = await repo_query(
            "SELECT * FROM credential WHERE string::lowercase(provider) = string::lowercase($provider) ORDER BY created ASC",
            {"provider": provider},
        )
        credentials = []
        for row in results:
            try:
                credentials.append(cls._from_db_row(row))
            except Exception as e:
                logger.warning(f"Skipping invalid credential: {e}")
        return credentials

    @classmethod
    async def get(cls, id: str) -> "Credential":
        """
        Override get() to handle api_key decryption.

        Errors from the base lookup or from decryption are not caught here;
        they propagate to the caller.
        """
        instance = await super().get(id)
        if instance.api_key:
            # Bypass Pydantic validation/assignment hooks when swapping in
            # the decrypted value.
            object.__setattr__(
                instance, "api_key", cls._decrypt_api_key(instance.api_key)
            )
        return instance

    @classmethod
    async def get_all(cls, order_by=None) -> List["Credential"]:
        """
        Override get_all() to handle api_key decryption with per-row error handling.

        Args:
            order_by: Optional ORDER BY expression. It is interpolated
                verbatim into the query string, so it must come from trusted
                code and never from user input.

        Returns:
            Every row in the table. Rows that fail to load are returned as
            placeholder credentials with ``decryption_error`` set, rather than
            being dropped, unless even the placeholder cannot be built.
        """
        order_clause = f" ORDER BY {order_by}" if order_by else ""
        results = await repo_query(
            f"SELECT * FROM {cls.table_name}{order_clause}",
            {},
        )
        credentials = []
        for row in results:
            try:
                credentials.append(cls._from_db_row(row))
            except Exception as e:
                logger.warning(
                    f"Failed to decrypt credential {row.get('id', 'unknown')}: {e}"
                )
                error_cred = cls._error_credential_from_row(row)
                if error_cred is not None:
                    credentials.append(error_cred)
        return credentials

    @classmethod
    def _error_credential_from_row(cls, row: dict) -> Optional["Credential"]:
        """
        Build a minimal placeholder credential from raw DB fields.

        Used by get_all() when a row cannot be loaded normally. The result
        carries a generic ``decryption_error`` and, if the row had an API key,
        the literal placeholder ``"UNDECRYPTABLE"`` as its api_key.

        NOTE(review): _prepare_save_data encrypts whatever api_key holds, so
        saving this placeholder would persist the literal placeholder value.

        Returns None (after logging) if even the placeholder cannot be built.
        """
        try:
            name = row.get("name")
            provider = row.get("provider")
            modalities = row.get("modalities")
            # Treat explicit NULLs like missing keys so the placeholder still
            # validates instead of being silently dropped.
            error_cred = cls(
                name=name if name is not None else "Unknown",
                provider=provider if provider is not None else "unknown",
                modalities=modalities if modalities is not None else [],
                decryption_error="Failed to decrypt API key. The encryption key may have changed.",
            )
            # Preserve the DB id, created, updated from the raw row
            if row.get("id"):
                object.__setattr__(error_cred, "id", str(row["id"]))
            if row.get("created"):
                object.__setattr__(error_cred, "created", row["created"])
            if row.get("updated"):
                object.__setattr__(error_cred, "updated", row["updated"])
            # Mark that it had an api_key (even though we can't decrypt it)
            if row.get("api_key"):
                object.__setattr__(error_cred, "api_key", SecretStr("UNDECRYPTABLE"))
            return error_cred
        except Exception as inner_e:
            logger.error(
                f"Failed to create error credential for {row.get('id', 'unknown')}: {inner_e}"
            )
            return None

    async def get_linked_models(self) -> list:
        """
        Get all models linked to this credential.

        Returns an empty list for an unsaved credential (no id).
        """
        if not self.id:
            return []
        # Imported here rather than at module level, presumably to avoid a
        # circular import between domain and ai models.
        from open_notebook.ai.models import Model

        results = await repo_query(
            "SELECT * FROM model WHERE credential = $cred_id",
            {"cred_id": ensure_record_id(self.id)},
        )
        return [Model(**row) for row in results]

    def _prepare_save_data(self) -> Dict[str, Any]:
        """
        Override to encrypt api_key before storage.

        ``decryption_error`` is never persisted. Other fields are included
        when they are not None or are listed in ``nullable_fields``.
        """
        data = {}
        for key, value in self.model_dump().items():
            if key == "decryption_error":
                continue
            if key == "api_key":
                # Encrypt the plain secret; a falsy key is stored as None.
                data["api_key"] = (
                    encrypt_value(self.api_key.get_secret_value())
                    if self.api_key
                    else None
                )
            elif value is not None or key in self.__class__.nullable_fields:
                data[key] = value
        return data

    async def save(self) -> None:
        """
        Save the credential, keeping api_key as a decrypted SecretStr afterwards.

        The base save may overwrite api_key with the encrypted value from the
        DB result. If a key was set before saving, that original SecretStr is
        restored. Otherwise any value left behind, whether a raw string or
        SecretStr-wrapped, is decrypted.
        """
        original_api_key = self.api_key
        await super().save()
        if original_api_key:
            object.__setattr__(self, "api_key", original_api_key)
        elif self.api_key:
            object.__setattr__(self, "api_key", self._decrypt_api_key(self.api_key))

    @classmethod
    def _from_db_row(cls, row: dict) -> "Credential":
        """
        Create a Credential from a database row, decrypting api_key.

        Works on a shallow copy so the caller's row is not mutated. Errors
        from decryption or model validation propagate to the caller.
        """
        data = dict(row)
        api_key_val = data.get("api_key")
        if api_key_val and isinstance(api_key_val, str):
            data["api_key"] = cls._decrypt_api_key(api_key_val)
        elif api_key_val is None:
            data["api_key"] = None
        return cls(**data)