Improve deterministic error surfacing across stream and API

Alishahryar1 2026-03-01 01:32:52 -08:00
parent 7f2612d2df
commit 34757511a0
17 changed files with 207 additions and 54 deletions


@@ -6,6 +6,7 @@ from loguru import logger
from config.settings import Settings
from config.settings import get_settings as _get_settings
from providers.base import BaseProvider, ProviderConfig
from providers.common import get_user_facing_error_message
from providers.exceptions import AuthenticationError
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
@@ -70,14 +71,14 @@ def _create_provider(settings: Settings) -> BaseProvider:
provider = LMStudioProvider(config)
else:
logger.error(
"Unknown provider_type: '%s'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
settings.provider_type,
)
raise ValueError(
f"Unknown provider_type: '{settings.provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
logger.info("Provider initialized: %s", settings.provider_type)
logger.info("Provider initialized: {}", settings.provider_type)
return provider
@@ -88,7 +89,9 @@ def get_provider() -> BaseProvider:
try:
_provider = _create_provider(get_settings())
except AuthenticationError as e:
raise HTTPException(status_code=503, detail=str(e)) from e
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
return _provider
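
The logger changes in this file repeat across the whole commit: loguru interpolates lazily with str.format-style {} placeholders, so the old stdlib-style %s/%d/%r specifiers were emitted verbatim and their arguments silently dropped. A minimal standalone sketch of the difference, assuming only loguru is installed:

from loguru import logger

# str.format() ignores positional args that have no {} placeholder,
# so the old calls logged the literal "%s" and discarded the value.
logger.info("Provider initialized: %s", "nvidia_nim")  # logs "... %s"
logger.info("Provider initialized: {}", "nvidia_nim")  # logs "... nvidia_nim"

# printf-style conversions map onto format specs: %r -> {!r}, %+d -> {:+d}
logger.debug("diff={:+d} preview={!r}", 5, "hi")  # logs "diff=+5 preview='hi'"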


@@ -10,6 +10,7 @@ from loguru import logger
from config.settings import Settings
from providers.base import BaseProvider
from providers.common import get_user_facing_error_message
from providers.exceptions import InvalidRequestError, ProviderError
from providers.logging_utils import build_request_summary, log_request_compact
@@ -69,7 +70,8 @@ async def create_message(
except Exception as e:
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
raise HTTPException(
status_code=getattr(e, "status_code", 500), detail=str(e)
status_code=getattr(e, "status_code", 500),
detail=get_user_facing_error_message(e),
) from e
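
The detail string is now always the readable user-facing text, while the HTTP status still comes from the provider exception when it carries one. A sketch of the same dispatch with a hypothetical helper (the helper name is illustrative, not part of the commit):

from fastapi import HTTPException

from providers.common import get_user_facing_error_message

def to_http_exception(e: Exception) -> HTTPException:
    # Provider exceptions such as APIError carry a status_code; anything
    # else falls back to 500. The detail is guaranteed non-empty.
    return HTTPException(
        status_code=getattr(e, "status_code", 500),
        detail=get_user_facing_error_message(e),
    )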
@@ -85,16 +87,18 @@ async def count_tokens(request_data: TokenCountRequest):
summary = build_request_summary(request_data)
summary["request_id"] = request_id
summary["input_tokens"] = tokens
logger.info("COUNT_TOKENS: %s", json.dumps(summary))
logger.info("COUNT_TOKENS: {}", json.dumps(summary))
return TokenCountResponse(input_tokens=tokens)
except Exception as e:
logger.error(
"COUNT_TOKENS_ERROR: request_id=%s error=%s\n%s",
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
request_id,
str(e),
get_user_facing_error_message(e),
traceback.format_exc(),
)
raise HTTPException(status_code=500, detail=str(e)) from e
raise HTTPException(
status_code=500, detail=get_user_facing_error_message(e)
) from e
@router.get("/")
@@ -127,5 +131,5 @@ async def stop_cli(request: Request):
raise HTTPException(status_code=503, detail="Messaging system not initialized")
count = await handler.stop_all_tasks()
logger.info("STOP_CLI: source=handler cancelled_count=%d", count)
logger.info("STOP_CLI: source=handler cancelled_count={}", count)
return {"status": "stopped", "cancelled_count": count}


@@ -12,6 +12,8 @@ import time
from loguru import logger
from providers.common import get_user_facing_error_message
from .commands import (
handle_clear_command,
handle_stats_command,
@@ -161,7 +163,7 @@ class ClaudeMessageHandler:
if len(incoming.text or "") > 80:
text_preview += "..."
logger.info(
"HANDLER_ENTRY: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
"HANDLER_ENTRY: chat_id={} message_id={} reply_to={} text_preview={!r}",
incoming.chat_id,
incoming.message_id,
incoming.reply_to_message_id,
@@ -492,7 +494,7 @@ class ClaudeMessageHandler:
return
if display and display != last_displayed_text:
logger.debug(
"PLATFORM_EDIT: node_id=%s chat_id=%s msg_id=%s force=%s status=%r chars=%d",
"PLATFORM_EDIT: node_id={} chat_id={} msg_id={} force={} status={!r} chars={}",
node_id,
chat_id,
status_msg_id,
@@ -501,13 +503,13 @@ class ClaudeMessageHandler:
len(display),
)
if os.getenv("DEBUG_TELEGRAM_EDITS") == "1":
logger.debug("PLATFORM_EDIT_TEXT:\n%s", display)
logger.debug("PLATFORM_EDIT_TEXT:\n{}", display)
else:
head = display[:500]
tail = display[-500:] if len(display) > 500 else ""
logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n%s", head)
logger.debug("PLATFORM_EDIT_PREVIEW_HEAD:\n{}", head)
if tail:
logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n%s", tail)
logger.debug("PLATFORM_EDIT_PREVIEW_TAIL:\n{}", tail)
last_displayed_text = display
try:
await self.platform.queue_edit_message(
@@ -531,14 +533,17 @@ class ClaudeMessageHandler:
else:
captured_session_id = session_or_temp_id
except RuntimeError as e:
transcript.apply({"type": "error", "message": str(e)})
error_message = get_user_facing_error_message(e)
transcript.apply({"type": "error", "message": error_message})
await update_ui(
self.format_status("", "Session limit reached"),
force=True,
)
if tree:
await tree.update_state(
node_id, MessageState.ERROR, error_message=str(e)
node_id,
MessageState.ERROR,
error_message=error_message,
)
return
@@ -607,7 +612,7 @@ class ClaudeMessageHandler:
logger.error(
f"HANDLER: Task failed with exception: {type(e).__name__}: {e}"
)
error_msg = str(e)[:200]
error_msg = get_user_facing_error_message(e)[:200]
transcript.apply({"type": "error", "message": error_msg})
await update_ui(self.format_status("💥", "Task Failed"), force=True)
if tree:


@@ -14,6 +14,8 @@ from typing import Any, cast
from loguru import logger
from providers.common import get_user_facing_error_message
from ..models import IncomingMessage
from ..rendering.discord_markdown import format_status_discord
from .base import MessagingPlatform
@@ -235,7 +237,7 @@ class DiscordPlatform(MessagingPlatform):
)
logger.info(
"DISCORD_VOICE: chat_id=%s message_id=%s transcribed=%r",
"DISCORD_VOICE: chat_id={} message_id={} transcribed={!r}",
channel_id,
message_id,
(transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
@@ -244,10 +246,10 @@ class DiscordPlatform(MessagingPlatform):
await self._message_handler(incoming)
return True
except ValueError as e:
await message.reply(str(e)[:200])
await message.reply(get_user_facing_error_message(e)[:200])
return True
except ImportError as e:
await message.reply(str(e)[:200])
await message.reply(get_user_facing_error_message(e)[:200])
return True
except Exception as e:
logger.error(f"Voice transcription failed: {e}")
@@ -289,7 +291,7 @@ class DiscordPlatform(MessagingPlatform):
if len(message.content or "") > 80:
text_preview += "..."
logger.info(
"DISCORD_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
"DISCORD_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
channel_id,
message_id,
reply_to,
@@ -317,7 +319,9 @@
with contextlib.suppress(Exception):
await self.send_message(
channel_id,
format_status_discord("Error:", str(e)[:200]),
format_status_discord(
"Error:", get_user_facing_error_message(e)[:200]
),
reply_to=message_id,
)


@@ -19,6 +19,8 @@ from typing import TYPE_CHECKING, Any
from loguru import logger
from providers.common import get_user_facing_error_message
if TYPE_CHECKING:
from telegram import Update
from telegram.ext import ContextTypes
@@ -508,7 +510,7 @@ class TelegramPlatform(MessagingPlatform):
if len(update.message.text or "") > 80:
text_preview += "..."
logger.info(
"TELEGRAM_MSG: chat_id=%s message_id=%s reply_to=%s text_preview=%r",
"TELEGRAM_MSG: chat_id={} message_id={} reply_to={} text_preview={!r}",
chat_id,
message_id,
reply_to,
@@ -536,7 +538,7 @@
with contextlib.suppress(Exception):
await self.send_message(
chat_id,
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(str(e)[:200])}",
f"❌ *{escape_md_v2('Error:')}* {escape_md_v2(get_user_facing_error_message(e)[:200])}",
reply_to=incoming.message_id,
message_thread_id=thread_id,
parse_mode="MarkdownV2",
@@ -638,7 +640,7 @@
)
logger.info(
"TELEGRAM_VOICE: chat_id=%s message_id=%s transcribed=%r",
"TELEGRAM_VOICE: chat_id={} message_id={} transcribed={!r}",
chat_id,
message_id,
(transcribed[:80] + "..." if len(transcribed) > 80 else transcribed),
@@ -646,9 +648,9 @@
await self._message_handler(incoming)
except ValueError as e:
await update.message.reply_text(str(e)[:200])
await update.message.reply_text(get_user_facing_error_message(e)[:200])
except ImportError as e:
await update.message.reply_text(str(e)[:200])
await update.message.reply_text(get_user_facing_error_message(e)[:200])
except Exception as e:
logger.error(f"Voice transcription failed: {e}")
await update.message.reply_text(


@@ -8,6 +8,8 @@ from collections.abc import Awaitable, Callable
from loguru import logger
from providers.common import get_user_facing_error_message
from .data import MessageNode, MessageState, MessageTree
@@ -83,7 +85,9 @@ class TreeQueueProcessor:
except Exception as e:
logger.error(f"Error processing node {node.node_id}: {e}")
await tree.update_state(
node.node_id, MessageState.ERROR, error_message=str(e)
node.node_id,
MessageState.ERROR,
error_message=get_user_facing_error_message(e),
)
finally:
tree.clear_current_node()


@@ -39,12 +39,12 @@ class TreeRepository:
"""Add a new tree to the repository."""
self._trees[root_id] = tree
self._node_to_tree[root_id] = root_id
logger.debug("TREE_REPO: add_tree root_id=%s", root_id)
logger.debug("TREE_REPO: add_tree root_id={}", root_id)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree."""
self._node_to_tree[node_id] = root_id
logger.debug("TREE_REPO: register_node node_id=%s root_id=%s", node_id, root_id)
logger.debug("TREE_REPO: register_node node_id={} root_id={}", node_id, root_id)
def has_node(self, node_id: str) -> bool:
"""Check if a node is registered in any tree."""
@@ -146,7 +146,7 @@ class TreeRepository:
return None
for node in tree.all_nodes():
self._node_to_tree.pop(node.node_id, None)
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
logger.debug("TREE_REPO: remove_tree root_id={}", root_id)
return tree
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:


@@ -1,6 +1,6 @@
"""Shared provider utilities used by NIM, OpenRouter, and LM Studio."""
from .error_mapping import map_error
from .error_mapping import append_request_id, get_user_facing_error_message, map_error
from .heuristic_tool_parser import HeuristicToolParser
from .message_converter import (
AnthropicToOpenAIConverter,
@@ -20,9 +20,11 @@ __all__ = [
"HeuristicToolParser",
"SSEBuilder",
"ThinkTagParser",
"append_request_id",
"build_base_request_body",
"get_block_attr",
"get_block_type",
"get_user_facing_error_message",
"map_error",
"map_stop_reason",
"set_if_not_none",


@@ -1,5 +1,6 @@
"""Error mapping for OpenAI-compatible providers (NIM, OpenRouter, LM Studio)."""
import httpx
import openai
from providers.exceptions import (
@@ -7,29 +8,78 @@ from providers.exceptions import (
AuthenticationError,
InvalidRequestError,
OverloadedError,
ProviderError,
RateLimitError,
)
from providers.rate_limit import GlobalRateLimiter
def get_user_facing_error_message(
e: Exception,
*,
read_timeout_s: float | None = None,
) -> str:
"""Return a readable, non-empty error message for users."""
message = str(e).strip()
if message:
return message
if isinstance(e, httpx.ReadTimeout):
if read_timeout_s is not None:
return f"Provider request timed out after {read_timeout_s:g}s."
return "Provider request timed out."
if isinstance(e, httpx.ConnectTimeout):
return "Could not connect to provider."
if isinstance(e, TimeoutError):
if read_timeout_s is not None:
return f"Provider request timed out after {read_timeout_s:g}s."
return "Request timed out."
if isinstance(e, (RateLimitError, openai.RateLimitError)):
return "Provider rate limit reached. Please retry shortly."
if isinstance(e, (AuthenticationError, openai.AuthenticationError)):
return "Provider authentication failed. Check API key."
if isinstance(e, (InvalidRequestError, openai.BadRequestError)):
return "Invalid request sent to provider."
if isinstance(e, OverloadedError):
return "Provider is currently overloaded. Please retry."
if isinstance(e, APIError):
if e.status_code in (502, 503, 504):
return "Provider is temporarily unavailable. Please retry."
return "Provider API request failed."
if isinstance(e, ProviderError):
return "Provider request failed."
return "Provider request failed unexpectedly."
def append_request_id(message: str, request_id: str | None) -> str:
"""Append request_id suffix when available."""
base = message.strip() or "Provider request failed unexpectedly."
if request_id:
return f"{base} (request_id={request_id})"
return base
def map_error(e: Exception) -> Exception:
"""Map OpenAI exception to specific ProviderError."""
message = get_user_facing_error_message(e)
if isinstance(e, openai.AuthenticationError):
return AuthenticationError(str(e), raw_error=str(e))
return AuthenticationError(message, raw_error=str(e))
if isinstance(e, openai.RateLimitError):
# Trigger global rate limit block
GlobalRateLimiter.get_instance().set_blocked(60) # Default 60s cooldown
return RateLimitError(str(e), raw_error=str(e))
return RateLimitError(message, raw_error=str(e))
if isinstance(e, openai.BadRequestError):
return InvalidRequestError(str(e), raw_error=str(e))
return InvalidRequestError(message, raw_error=str(e))
if isinstance(e, openai.InternalServerError):
message = str(e)
if "overloaded" in message.lower() or "capacity" in message.lower():
return OverloadedError(message, raw_error=str(e))
raw_message = str(e)
if "overloaded" in raw_message.lower() or "capacity" in raw_message.lower():
return OverloadedError(message, raw_error=raw_message)
return APIError(message, status_code=500, raw_error=str(e))
if isinstance(e, openai.APIError):
return APIError(
str(e), status_code=getattr(e, "status_code", 500), raw_error=str(e)
message, status_code=getattr(e, "status_code", 500), raw_error=str(e)
)
return e
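
Together the two new helpers guarantee a non-empty, deterministic detail string even for exceptions whose __str__ is empty. A short usage sketch mirroring the tests further down (the 60s timeout is just an example value):

import httpx

from providers.common import append_request_id, get_user_facing_error_message

# str(httpx.ReadTimeout("")) is empty, so the isinstance fallback supplies text.
exc = httpx.ReadTimeout("")
detail = append_request_id(
    get_user_facing_error_message(exc, read_timeout_s=60),
    "req_abc123",
)
assert detail == "Provider request timed out after 60s. (request_id=req_abc123)"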


@@ -118,7 +118,7 @@ class ContentBlockManager:
except Exception as e:
prefix = state.task_arg_buffer[:120]
logger.warning(
"Task args invalid JSON (id=%s len=%d prefix=%r): %s",
"Task args invalid JSON (id={} len={} prefix={!r}): {}",
state.tool_id or "unknown",
len(state.task_arg_buffer),
prefix,


@@ -12,7 +12,7 @@ LMSTUDIO_DEFAULT_MAX_TOKENS = 81920
def build_request_body(request_data: Any) -> dict:
"""Build OpenAI-format request body from Anthropic request for LM Studio."""
logger.debug(
"LMSTUDIO_REQUEST: conversion start model=%s msgs=%d",
"LMSTUDIO_REQUEST: conversion start model={} msgs={}",
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
@@ -21,7 +21,7 @@
)
logger.debug(
"LMSTUDIO_REQUEST: conversion done model=%s msgs=%d tools=%d",
"LMSTUDIO_REQUEST: conversion done model={} msgs={} tools={}",
body.get("model"),
len(body.get("messages", [])),
len(body.get("tools", [])),


@@ -24,7 +24,7 @@ def _set_extra(
def build_request_body(request_data: Any, nim: NimSettings) -> dict:
"""Build OpenAI-format request body from Anthropic request."""
logger.debug(
"NIM_REQUEST: conversion start model=%s msgs=%d",
"NIM_REQUEST: conversion start model={} msgs={}",
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
@@ -96,7 +96,7 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict:
body["extra_body"] = extra_body
logger.debug(
"NIM_REQUEST: conversion done model=%s msgs=%d tools=%d",
"NIM_REQUEST: conversion done model={} msgs={} tools={}",
body.get("model"),
len(body.get("messages", [])),
len(body.get("tools", [])),


@@ -12,7 +12,7 @@ OPENROUTER_DEFAULT_MAX_TOKENS = 81920
def build_request_body(request_data: Any) -> dict:
"""Build OpenAI-format request body from Anthropic request for OpenRouter."""
logger.debug(
"OPENROUTER_REQUEST: conversion start model=%s msgs=%d",
"OPENROUTER_REQUEST: conversion start model={} msgs={}",
getattr(request_data, "model", "?"),
len(getattr(request_data, "messages", [])),
)
@@ -39,7 +39,7 @@ def build_request_body(request_data: Any) -> dict:
body["extra_body"] = extra_body
logger.debug(
"OPENROUTER_REQUEST: conversion done model=%s msgs=%d tools=%d",
"OPENROUTER_REQUEST: conversion done model={} msgs={} tools={}",
body.get("model"),
len(body.get("messages", [])),
len(body.get("tools", [])),


@@ -16,6 +16,8 @@ from providers.common import (
HeuristicToolParser,
SSEBuilder,
ThinkTagParser,
append_request_id,
get_user_facing_error_message,
map_error,
map_stop_reason,
)
@@ -139,7 +141,7 @@ class OpenAICompatibleProvider(BaseProvider):
body = self._build_request_body(request)
req_tag = f" request_id={request_id}" if request_id else ""
logger.info(
"%s_STREAM:%s model=%s msgs=%d tools=%d",
"{}_STREAM:{} model={} msgs={} tools={}",
tag,
req_tag,
body.get("model"),
@@ -176,7 +178,7 @@
if choice.finish_reason:
finish_reason = choice.finish_reason
logger.debug("%s finish_reason: %s", tag, finish_reason)
logger.debug("{} finish_reason: {}", tag, finish_reason)
# Handle reasoning_content (OpenAI extended format)
reasoning = getattr(delta, "reasoning_content", None)
@@ -245,12 +247,17 @@
yield event
except Exception as e:
logger.error("%s_ERROR:%s %s: %s", tag, req_tag, type(e).__name__, e)
logger.error("{}_ERROR:{} {}: {}", tag, req_tag, type(e).__name__, e)
mapped_e = map_error(e)
error_occurred = True
error_message = str(mapped_e)
error_message = append_request_id(
get_user_facing_error_message(
mapped_e, read_timeout_s=self._config.http_read_timeout
),
request_id,
)
logger.info(
"%s_STREAM: Emitting SSE error event for %s%s",
"{}_STREAM: Emitting SSE error event for {}{}",
tag,
type(e).__name__,
req_tag,
@@ -318,7 +325,7 @@
provider_input = usage_info.prompt_tokens
if isinstance(provider_input, int):
logger.debug(
"TOKEN_ESTIMATE: our=%d provider=%d diff=%+d",
"TOKEN_ESTIMATE: our={} provider={} diff={:+d}",
input_tokens,
provider_input,
provider_input - input_tokens,
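
For context on "Emitting SSE error event" above: in Anthropic-style streaming, an error is surfaced as an error event followed by message_stop rather than by aborting the connection. A hedged sketch of such an event; SSEBuilder's exact output is not shown in this diff, so the payload shape is an assumption based on the public Anthropic SSE format:

import json

def error_sse(message: str) -> str:
    # Assumed Anthropic-style error event; the tests below only assert that
    # the message text and a trailing message_stop appear in the stream.
    payload = {"type": "error", "error": {"type": "api_error", "message": message}}
    return f"event: error\ndata: {json.dumps(payload)}\n\n"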


@@ -164,6 +164,31 @@ def test_generic_exception_with_status_code():
mock_provider.stream_response = _mock_stream_response
def test_generic_exception_empty_message_returns_non_empty_detail():
"""Exceptions with empty __str__ still return a readable HTTP detail."""
class SilentError(RuntimeError):
def __str__(self):
return ""
def _raise_silent(*args, **kwargs):
raise SilentError()
mock_provider.stream_response = _raise_silent
response = client.post(
"/v1/messages",
json={
"model": "test",
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 10,
"stream": True,
},
)
assert response.status_code == 500
assert response.json()["detail"] != ""
mock_provider.stream_response = _mock_stream_response
def test_count_tokens_endpoint():
"""count_tokens endpoint returns token count."""
response = client.post(


@@ -4,9 +4,9 @@ from unittest.mock import MagicMock, patch
import openai
import pytest
from httpx import Request, Response
from httpx import ReadTimeout, Request, Response
from providers.common import map_error
from providers.common import append_request_id, get_user_facing_error_message, map_error
from providers.exceptions import (
APIError,
AuthenticationError,
@@ -120,3 +120,16 @@ class TestMapError:
with patch("providers.common.error_mapping.GlobalRateLimiter"):
result = map_error(exc)
assert isinstance(result, expected_cls)
def test_user_facing_message_read_timeout_empty_string():
"""ReadTimeout wrapping TimeoutError should still produce readable text."""
timeout_exc = ReadTimeout("")
message = get_user_facing_error_message(timeout_exc, read_timeout_s=60)
assert message == "Provider request timed out after 60s."
def test_append_request_id_suffix():
"""Request id suffix should be appended deterministically."""
message = append_request_id("Provider request failed.", "req_abc123")
assert message == "Provider request failed. (request_id=req_abc123)"


@@ -3,6 +3,7 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from config.nim import NimSettings
@@ -116,6 +117,39 @@ class TestStreamingExceptionHandling:
assert "API failed" in event_text
assert "message_stop" in event_text
@pytest.mark.asyncio
async def test_read_timeout_with_empty_message_emits_fallback(self):
"""ReadTimeout(TimeoutError()) should emit a visible, non-empty timeout message."""
provider = _make_provider()
request = _make_request()
with (
patch.object(
provider._client.chat.completions,
"create",
new_callable=AsyncMock,
side_effect=httpx.ReadTimeout(""),
),
patch.object(
provider._global_rate_limiter,
"wait_if_blocked",
new_callable=AsyncMock,
return_value=False,
),
):
events = [
e
async for e in provider.stream_response(
request,
request_id="req_timeout123",
)
]
event_text = "".join(events)
assert "timed out after" in event_text
assert "request_id=req_timeout123" in event_text
assert "message_stop" in event_text
@pytest.mark.asyncio
async def test_error_after_partial_content(self):
"""Error after partial content: blocks closed, error emitted."""