mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Improve deterministic error surfacing across stream and API
This commit is contained in:
parent
7f2612d2df
commit
34757511a0
17 changed files with 207 additions and 54 deletions
|
|
@ -6,6 +6,7 @@ from loguru import logger
|
|||
from config.settings import Settings
|
||||
from config.settings import get_settings as _get_settings
|
||||
from providers.base import BaseProvider, ProviderConfig
|
||||
from providers.common import get_user_facing_error_message
|
||||
from providers.exceptions import AuthenticationError
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
||||
|
|
@ -70,14 +71,14 @@ def _create_provider(settings: Settings) -> BaseProvider:
|
|||
provider = LMStudioProvider(config)
|
||||
else:
|
||||
logger.error(
|
||||
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
settings.provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{settings.provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||
)
|
||||
logger.info("Provider initialized: %s", settings.provider_type)
|
||||
logger.info("Provider initialized: {}", settings.provider_type)
|
||||
return provider
|
||||
|
||||
|
||||
|
|
@ -88,7 +89,9 @@ def get_provider() -> BaseProvider:
|
|||
try:
|
||||
_provider = _create_provider(get_settings())
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(status_code=503, detail=str(e)) from e
|
||||
raise HTTPException(
|
||||
status_code=503, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
return _provider
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from loguru import logger
|
|||
|
||||
from config.settings import Settings
|
||||
from providers.base import BaseProvider
|
||||
from providers.common import get_user_facing_error_message
|
||||
from providers.exceptions import InvalidRequestError, ProviderError
|
||||
from providers.logging_utils import build_request_summary, log_request_compact
|
||||
|
||||
|
|
@ -69,7 +70,8 @@ async def create_message(
|
|||
except Exception as e:
|
||||
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
|
||||
raise HTTPException(
|
||||
status_code=getattr(e, "status_code", 500), detail=str(e)
|
||||
status_code=getattr(e, "status_code", 500),
|
||||
detail=get_user_facing_error_message(e),
|
||||
) from e
|
||||
|
||||
|
||||
|
|
@ -85,16 +87,18 @@ async def count_tokens(request_data: TokenCountRequest):
|
|||
summary = build_request_summary(request_data)
|
||||
summary["request_id"] = request_id
|
||||
summary["input_tokens"] = tokens
|
||||
logger.info("COUNT_TOKENS: %s", json.dumps(summary))
|
||||
logger.info("COUNT_TOKENS: {}", json.dumps(summary))
|
||||
return TokenCountResponse(input_tokens=tokens)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"COUNT_TOKENS_ERROR: request_id=%s error=%s\n%s",
|
||||
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
|
||||
request_id,
|
||||
str(e),
|
||||
get_user_facing_error_message(e),
|
||||
traceback.format_exc(),
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
raise HTTPException(
|
||||
status_code=500, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/")
|
||||
|
|
@ -127,5 +131,5 @@ async def stop_cli(request: Request):
|
|||
raise HTTPException(status_code=503, detail="Messaging system not initialized")
|
||||
|
||||
count = await handler.stop_all_tasks()
|
||||
logger.info("STOP_CLI: source=handler cancelled_count=%d", count)
|
||||
logger.info("STOP_CLI: source=handler cancelled_count={}", count)
|
||||
return {"status": "stopped", "cancelled_count": count}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ import time
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common import get_user_facing_error_message
|
||||
|
||||
from .commands import (
|
||||
handle_clear_command,
|
||||
handle_stats_command,
|
||||
|
|
@ -161,7 +163,7 @@ class ClaudeMessageHandler:
|
|||
if len(incoming.text or "") > 80:
|
||||
text_preview += "..."
|
||||
logger.info(
|
||||
"HANDLER_ENTRY: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
|
||||
"HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}",
|
||||
incoming.chat_id,
|
||||
incoming.message_id,
|
||||
incoming.reply_to_message_id,
|
||||
|
|
@ -492,7 +494,7 @@ class ClaudeMessageHandler:
|
|||
return
|
||||
if display and display != last_displayed_text:
|
||||
logger.debug(
|
||||
"PLATFORM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d",
|
||||
"PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
|
||||
node_id,
|
||||
chat_id,
|
||||
status_msg_id,
|
||||
|
|
@ -501,13 +503,13 @@ class ClaudeMessageHandler:
|
|||
len(display),
|
||||
)
|
||||
if os.getenv("DEBUG_TELEGRAM_EDITS") == "1":
|
||||
logger.debug("PLATFORM_EDIT_TEXT:\n%s", display)
|
||||
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
|
||||
else:
|
||||
head = display[:500]
|
||||
tail = display[-500:] if len(display) > 500 else ""
|
||||
logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n%s", head)
|
||||
logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n{}", head)
|
||||
if tail:
|
||||
logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n%s", tail)
|
||||
logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n{}", tail)
|
||||
last_displayed_text = display
|
||||
try:
|
||||
await self.platform.queue_edit_message(
|
||||
|
|
@ -531,14 +533,17 @@ class ClaudeMessageHandler:
|
|||
else:
|
||||
captured_session_id = session_or_temp_id
|
||||
except RuntimeError as e:
|
||||
transcript.apply({"type": "error", "message": str(e)})
|
||||
error_message = get_user_facing_error_message(e)
|
||||
transcript.apply({"type": "error", "message": error_message})
|
||||
await update_ui(
|
||||
self.format_status("⏳", "Session limit reached"),
|
||||
force=True,
|
||||
)
|
||||
if tree:
|
||||
await tree.update_state(
|
||||
node_id, MessageState.ERROR, error_message=str(e)
|
||||
node_id,
|
||||
MessageState.ERROR,
|
||||
error_message=error_message,
|
||||
)
|
||||
return
|
||||
|
||||
|
|
@ -607,7 +612,7 @@ class ClaudeMessageHandler:
|
|||
logger.error(
|
||||
f"HANDLER: Task failed with exception: {type(e).__name__}: {e}"
|
||||
)
|
||||
error_msg = str(e)[:200]
|
||||
error_msg = get_user_facing_error_message(e)[:200]
|
||||
transcript.apply({"type": "error", "message": error_msg})
|
||||
await update_ui(self.format_status("💥", "Task Failed"), force=True)
|
||||
if tree:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from typing import Any, cast
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common import get_user_facing_error_message
|
||||
|
||||
from ..models import IncomingMessage
|
||||
from ..rendering.discord_markdown import format_status_discord
|
||||
from .base import MessagingPlatform
|
||||
|
|
@ -235,7 +237,7 @@ class DiscordPlatform(MessagingPlatform):
|
|||
)
|
||||
|
||||
logger.info(
|
||||
"DISCORD_VOICE: chat_id=%s message_id=%s transcribed=%r",
|
||||
"DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
|
||||
channel_id,
|
||||
message_id,
|
||||
(transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
|
||||
|
|
@ -244,10 +246,10 @@ class DiscordPlatform(MessagingPlatform):
|
|||
await self._message_handler(incoming)
|
||||
return True
|
||||
except ValueError as e:
|
||||
await message.reply(str(e)[:200])
|
||||
await message.reply(get_user_facing_error_message(e)[:200])
|
||||
return True
|
||||
except ImportError as e:
|
||||
await message.reply(str(e)[:200])
|
||||
await message.reply(get_user_facing_error_message(e)[:200])
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Voice transcription failed: {e}")
|
||||
|
|
@ -289,7 +291,7 @@ class DiscordPlatform(MessagingPlatform):
|
|||
if len(message.content or "") > 80:
|
||||
text_preview += "..."
|
||||
logger.info(
|
||||
"DISCORD_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
|
||||
"DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
|
||||
channel_id,
|
||||
message_id,
|
||||
reply_to,
|
||||
|
|
@ -317,7 +319,9 @@ class DiscordPlatform(MessagingPlatform):
|
|||
with contextlib.suppress(Exception):
|
||||
await self.send_message(
|
||||
channel_id,
|
||||
format_status_discord("Error:", str(e)[:200]),
|
||||
format_status_discord(
|
||||
"Error:", get_user_facing_error_message(e)[:200]
|
||||
),
|
||||
reply_to=message_id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ from typing import TYPE_CHECKING, Any
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common import get_user_facing_error_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from telegram import Update
|
||||
from telegram.ext import ContextTypes
|
||||
|
|
@ -508,7 +510,7 @@ class TelegramPlatform(MessagingPlatform):
|
|||
if len(update.message.text or "") > 80:
|
||||
text_preview += "..."
|
||||
logger.info(
|
||||
"TELEGRAM_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
|
||||
"TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
|
||||
chat_id,
|
||||
message_id,
|
||||
reply_to,
|
||||
|
|
@ -536,7 +538,7 @@ class TelegramPlatform(MessagingPlatform):
|
|||
with contextlib.suppress(Exception):
|
||||
await self.send_message(
|
||||
chat_id,
|
||||
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(str(e)[:200])}",
|
||||
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(get_user_facing_error_message(e)[:200])}",
|
||||
reply_to=incoming.message_id,
|
||||
message_thread_id=thread_id,
|
||||
parse_mode="MarkdownV2",
|
||||
|
|
@ -638,7 +640,7 @@ class TelegramPlatform(MessagingPlatform):
|
|||
)
|
||||
|
||||
logger.info(
|
||||
"TELEGRAM_VOICE: chat_id=%s message_id=%s transcribed=%r",
|
||||
"TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
|
||||
chat_id,
|
||||
message_id,
|
||||
(transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
|
||||
|
|
@ -646,9 +648,9 @@ class TelegramPlatform(MessagingPlatform):
|
|||
|
||||
await self._message_handler(incoming)
|
||||
except ValueError as e:
|
||||
await update.message.reply_text(str(e)[:200])
|
||||
await update.message.reply_text(get_user_facing_error_message(e)[:200])
|
||||
except ImportError as e:
|
||||
await update.message.reply_text(str(e)[:200])
|
||||
await update.message.reply_text(get_user_facing_error_message(e)[:200])
|
||||
except Exception as e:
|
||||
logger.error(f"Voice transcription failed: {e}")
|
||||
await update.message.reply_text(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from collections.abc import Awaitable, Callable
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from providers.common import get_user_facing_error_message
|
||||
|
||||
from .data import MessageNode, MessageState, MessageTree
|
||||
|
||||
|
||||
|
|
@ -83,7 +85,9 @@ class TreeQueueProcessor:
|
|||
except Exception as e:
|
||||
logger.error(f"Error processing node {node.node_id}: {e}")
|
||||
await tree.update_state(
|
||||
node.node_id, MessageState.ERROR, error_message=str(e)
|
||||
node.node_id,
|
||||
MessageState.ERROR,
|
||||
error_message=get_user_facing_error_message(e),
|
||||
)
|
||||
finally:
|
||||
tree.clear_current_node()
|
||||
|
|
|
|||
|
|
@ -39,12 +39,12 @@ class TreeRepository:
|
|||
"""Add a new tree to the repository."""
|
||||
self._trees[root_id] = tree
|
||||
self._node_to_tree[root_id] = root_id
|
||||
logger.debug("TREE_REPO: add_tree root_id=%s", root_id)
|
||||
logger.debug("TREE_REPO: add_tree root_id={}", root_id)
|
||||
|
||||
def register_node(self, node_id: str, root_id: str) -> None:
|
||||
"""Register a node ID to a tree."""
|
||||
self._node_to_tree[node_id] = root_id
|
||||
logger.debug("TREE_REPO: register_node node_id=%s root_id=%s", node_id, root_id)
|
||||
logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
"""Check if a node is registered in any tree."""
|
||||
|
|
@ -146,7 +146,7 @@ class TreeRepository:
|
|||
return None
|
||||
for node in tree.all_nodes():
|
||||
self._node_to_tree.pop(node.node_id, None)
|
||||
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
|
||||
logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
|
||||
return tree
|
||||
|
||||
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Shared provider utilities used by NIM, OpenRouter, and LM Studio."""
|
||||
|
||||
from .error_mapping import map_error
|
||||
from .error_mapping import append_request_id, get_user_facing_error_message, map_error
|
||||
from .heuristic_tool_parser import HeuristicToolParser
|
||||
from .message_converter import (
|
||||
AnthropicToOpenAIConverter,
|
||||
|
|
@ -20,9 +20,11 @@ __all__ = [
|
|||
"HeuristicToolParser",
|
||||
"SSEBuilder",
|
||||
"ThinkTagParser",
|
||||
"append_request_id",
|
||||
"build_base_request_body",
|
||||
"get_block_attr",
|
||||
"get_block_type",
|
||||
"get_user_facing_error_message",
|
||||
"map_error",
|
||||
"map_stop_reason",
|
||||
"set_if_not_none",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Error mapping for OpenAI-compatible providers (NIM, OpenRouter, LM Studio)."""
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
||||
from providers.exceptions import (
|
||||
|
|
@ -7,29 +8,78 @@ from providers.exceptions import (
|
|||
AuthenticationError,
|
||||
InvalidRequestError,
|
||||
OverloadedError,
|
||||
ProviderError,
|
||||
RateLimitError,
|
||||
)
|
||||
from providers.rate_limit import GlobalRateLimiter
|
||||
|
||||
|
||||
def get_user_facing_error_message(
|
||||
e: Exception,
|
||||
*,
|
||||
read_timeout_s: float | None = None,
|
||||
) -> str:
|
||||
"""Return a readable, non-empty error message for users."""
|
||||
message = str(e).strip()
|
||||
if message:
|
||||
return message
|
||||
|
||||
if isinstance(e, httpx.ReadTimeout):
|
||||
if read_timeout_s is not None:
|
||||
return f"Provider request timed out after {read_timeout_s:g}s."
|
||||
return "Provider request timed out."
|
||||
if isinstance(e, httpx.ConnectTimeout):
|
||||
return "Could not connect to provider."
|
||||
if isinstance(e, TimeoutError):
|
||||
if read_timeout_s is not None:
|
||||
return f"Provider request timed out after {read_timeout_s:g}s."
|
||||
return "Request timed out."
|
||||
|
||||
if isinstance(e, (RateLimitError, openai.RateLimitError)):
|
||||
return "Provider rate limit reached. Please retry shortly."
|
||||
if isinstance(e, (AuthenticationError, openai.AuthenticationError)):
|
||||
return "Provider authentication failed. Check API key."
|
||||
if isinstance(e, (InvalidRequestError, openai.BadRequestError)):
|
||||
return "Invalid request sent to provider."
|
||||
if isinstance(e, OverloadedError):
|
||||
return "Provider is currently overloaded. Please retry."
|
||||
if isinstance(e, APIError):
|
||||
if e.status_code in (502, 503, 504):
|
||||
return "Provider is temporarily unavailable. Please retry."
|
||||
return "Provider API request failed."
|
||||
if isinstance(e, ProviderError):
|
||||
return "Provider request failed."
|
||||
|
||||
return "Provider request failed unexpectedly."
|
||||
|
||||
|
||||
def append_request_id(message: str, request_id: str | None) -> str:
|
||||
"""Append request_id suffix when available."""
|
||||
base = message.strip() or "Provider request failed unexpectedly."
|
||||
if request_id:
|
||||
return f"{base} (request_id={request_id})"
|
||||
return base
|
||||
|
||||
|
||||
def map_error(e: Exception) -> Exception:
|
||||
"""Map OpenAI exception to specific ProviderError."""
|
||||
message = get_user_facing_error_message(e)
|
||||
if isinstance(e, openai.AuthenticationError):
|
||||
return AuthenticationError(str(e), raw_error=str(e))
|
||||
return AuthenticationError(message, raw_error=str(e))
|
||||
if isinstance(e, openai.RateLimitError):
|
||||
# Trigger global rate limit block
|
||||
GlobalRateLimiter.get_instance().set_blocked(60) # Default 60s cooldown
|
||||
return RateLimitError(str(e), raw_error=str(e))
|
||||
return RateLimitError(message, raw_error=str(e))
|
||||
if isinstance(e, openai.BadRequestError):
|
||||
return InvalidRequestError(str(e), raw_error=str(e))
|
||||
return InvalidRequestError(message, raw_error=str(e))
|
||||
if isinstance(e, openai.InternalServerError):
|
||||
message = str(e)
|
||||
if "overloaded" in message.lower() or "capacity" in message.lower():
|
||||
return OverloadedError(message, raw_error=str(e))
|
||||
raw_message = str(e)
|
||||
if "overloaded" in raw_message.lower() or "capacity" in raw_message.lower():
|
||||
return OverloadedError(message, raw_error=raw_message)
|
||||
return APIError(message, status_code=500, raw_error=str(e))
|
||||
if isinstance(e, openai.APIError):
|
||||
return APIError(
|
||||
str(e), status_code=getattr(e, "status_code", 500), raw_error=str(e)
|
||||
message, status_code=getattr(e, "status_code", 500), raw_error=str(e)
|
||||
)
|
||||
|
||||
return e
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class ContentBlockManager:
|
|||
except Exception as e:
|
||||
prefix = state.task_arg_buffer[:120]
|
||||
logger.warning(
|
||||
"Task args invalid JSON (id=%s len=%d prefix=%r): %s",
|
||||
"Task args invalid JSON (id={} len={} prefix={!r}): {}",
|
||||
state.tool_id or "unknown",
|
||||
len(state.task_arg_buffer),
|
||||
prefix,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ LMSTUDIO_DEFAULT_MAX_TOKENS = 81920
|
|||
def build_request_body(request_data: Any) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request for LM Studio."""
|
||||
logger.debug(
|
||||
"LMSTUDIO_REQUEST: conversion start model=%s msgs=%d",
|
||||
"LMSTUDIO_REQUEST: conversion start model={} msgs={}",
|
||||
getattr(request_data, "model", "?"),
|
||||
len(getattr(request_data, "messages", [])),
|
||||
)
|
||||
|
|
@ -21,7 +21,7 @@ def build_request_body(request_data: Any) -> dict:
|
|||
)
|
||||
|
||||
logger.debug(
|
||||
"LMSTUDIO_REQUEST: conversion done model=%s msgs=%d tools=%d",
|
||||
"LMSTUDIO_REQUEST: conversion done model={} msgs={} tools={}",
|
||||
body.get("model"),
|
||||
len(body.get("messages", [])),
|
||||
len(body.get("tools", [])),
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def _set_extra(
|
|||
def build_request_body(request_data: Any, nim: NimSettings) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request."""
|
||||
logger.debug(
|
||||
"NIM_REQUEST: conversion start model=%s msgs=%d",
|
||||
"NIM_REQUEST: conversion start model={} msgs={}",
|
||||
getattr(request_data, "model", "?"),
|
||||
len(getattr(request_data, "messages", [])),
|
||||
)
|
||||
|
|
@ -96,7 +96,7 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict:
|
|||
body["extra_body"] = extra_body
|
||||
|
||||
logger.debug(
|
||||
"NIM_REQUEST: conversion done model=%s msgs=%d tools=%d",
|
||||
"NIM_REQUEST: conversion done model={} msgs={} tools={}",
|
||||
body.get("model"),
|
||||
len(body.get("messages", [])),
|
||||
len(body.get("tools", [])),
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ OPENROUTER_DEFAULT_MAX_TOKENS = 81920
|
|||
def build_request_body(request_data: Any) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request for OpenRouter."""
|
||||
logger.debug(
|
||||
"OPENROUTER_REQUEST: conversion start model=%s msgs=%d",
|
||||
"OPENROUTER_REQUEST: conversion start model={} msgs={}",
|
||||
getattr(request_data, "model", "?"),
|
||||
len(getattr(request_data, "messages", [])),
|
||||
)
|
||||
|
|
@ -39,7 +39,7 @@ def build_request_body(request_data: Any) -> dict:
|
|||
body["extra_body"] = extra_body
|
||||
|
||||
logger.debug(
|
||||
"OPENROUTER_REQUEST: conversion done model=%s msgs=%d tools=%d",
|
||||
"OPENROUTER_REQUEST: conversion done model={} msgs={} tools={}",
|
||||
body.get("model"),
|
||||
len(body.get("messages", [])),
|
||||
len(body.get("tools", [])),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ from providers.common import (
|
|||
HeuristicToolParser,
|
||||
SSEBuilder,
|
||||
ThinkTagParser,
|
||||
append_request_id,
|
||||
get_user_facing_error_message,
|
||||
map_error,
|
||||
map_stop_reason,
|
||||
)
|
||||
|
|
@ -139,7 +141,7 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
body = self._build_request_body(request)
|
||||
req_tag = f" request_id={request_id}" if request_id else ""
|
||||
logger.info(
|
||||
"%s_STREAM:%s model=%s msgs=%d tools=%d",
|
||||
"{}_STREAM:{} model={} msgs={} tools={}",
|
||||
tag,
|
||||
req_tag,
|
||||
body.get("model"),
|
||||
|
|
@ -176,7 +178,7 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
logger.debug("%s finish_reason: %s", tag, finish_reason)
|
||||
logger.debug("{} finish_reason: {}", tag, finish_reason)
|
||||
|
||||
# Handle reasoning_content (OpenAI extended format)
|
||||
reasoning = getattr(delta, "reasoning_content", None)
|
||||
|
|
@ -245,12 +247,17 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error("%s_ERROR:%s %s: %s", tag, req_tag, type(e).__name__, e)
|
||||
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
|
||||
mapped_e = map_error(e)
|
||||
error_occurred = True
|
||||
error_message = str(mapped_e)
|
||||
error_message = append_request_id(
|
||||
get_user_facing_error_message(
|
||||
mapped_e, read_timeout_s=self._config.http_read_timeout
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
logger.info(
|
||||
"%s_STREAM: Emitting SSE error event for %s%s",
|
||||
"{}_STREAM: Emitting SSE error event for {}{}",
|
||||
tag,
|
||||
type(e).__name__,
|
||||
req_tag,
|
||||
|
|
@ -318,7 +325,7 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
provider_input = usage_info.prompt_tokens
|
||||
if isinstance(provider_input, int):
|
||||
logger.debug(
|
||||
"TOKEN_ESTIMATE: our=%d provider=%d diff=%+d",
|
||||
"TOKEN_ESTIMATE: our={} provider={} diff={:+d}",
|
||||
input_tokens,
|
||||
provider_input,
|
||||
provider_input - input_tokens,
|
||||
|
|
|
|||
|
|
@ -164,6 +164,31 @@ def test_generic_exception_with_status_code():
|
|||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
|
||||
def test_generic_exception_empty_message_returns_non_empty_detail():
|
||||
"""Exceptions with empty __str__ still return a readable HTTP detail."""
|
||||
|
||||
class SilentError(RuntimeError):
|
||||
def __str__(self):
|
||||
return ""
|
||||
|
||||
def _raise_silent(*args, **kwargs):
|
||||
raise SilentError()
|
||||
|
||||
mock_provider.stream_response = _raise_silent
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
json={
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"max_tokens": 10,
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert response.json()["detail"] != ""
|
||||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
|
||||
def test_count_tokens_endpoint():
|
||||
"""count_tokens endpoint returns token count."""
|
||||
response = client.post(
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import openai
|
||||
import pytest
|
||||
from httpx import Request, Response
|
||||
from httpx import ReadTimeout, Request, Response
|
||||
|
||||
from providers.common import map_error
|
||||
from providers.common import append_request_id, get_user_facing_error_message, map_error
|
||||
from providers.exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
|
|
@ -120,3 +120,16 @@ class TestMapError:
|
|||
with patch("providers.common.error_mapping.GlobalRateLimiter"):
|
||||
result = map_error(exc)
|
||||
assert isinstance(result, expected_cls)
|
||||
|
||||
|
||||
def test_user_facing_message_read_timeout_empty_string():
|
||||
"""ReadTimeout wrapping TimeoutError should still produce readable text."""
|
||||
timeout_exc = ReadTimeout("")
|
||||
message = get_user_facing_error_message(timeout_exc, read_timeout_s=60)
|
||||
assert message == "Provider request timed out after 60s."
|
||||
|
||||
|
||||
def test_append_request_id_suffix():
|
||||
"""Request id suffix should be appended deterministically."""
|
||||
message = append_request_id("Provider request failed.", "req_abc123")
|
||||
assert message == "Provider request failed. (request_id=req_abc123)"
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from config.nim import NimSettings
|
||||
|
|
@ -116,6 +117,39 @@ class TestStreamingExceptionHandling:
|
|||
assert "API failed" in event_text
|
||||
assert "message_stop" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_timeout_with_empty_message_emits_fallback(self):
|
||||
"""ReadTimeout(TimeoutError()) should emit a visible, non-empty timeout message."""
|
||||
provider = _make_provider()
|
||||
request = _make_request()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
provider._client.chat.completions,
|
||||
"create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=httpx.ReadTimeout(""),
|
||||
),
|
||||
patch.object(
|
||||
provider._global_rate_limiter,
|
||||
"wait_if_blocked",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
events = [
|
||||
e
|
||||
async for e in provider.stream_response(
|
||||
request,
|
||||
request_id="req_timeout123",
|
||||
)
|
||||
]
|
||||
|
||||
event_text = "".join(events)
|
||||
assert "timed out after" in event_text
|
||||
assert "request_id=req_timeout123" in event_text
|
||||
assert "message_stop" in event_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_after_partial_content(self):
|
||||
"""Error after partial content: blocks closed, error emitted."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue