From 539854fe7b94b42d2c0c758f78064c92e5efab9e Mon Sep 17 00:00:00 2001 From: Alishahryar1 Date: Sun, 15 Feb 2026 21:58:03 -0800 Subject: [PATCH] Refactor done using GLM-5 --- .github/workflows/tests.yml | 3 + api/app.py | 4 +- api/dependencies.py | 4 +- api/models/anthropic.py | 4 +- api/optimization_handlers.py | 4 +- api/request_utils.py | 9 +- api/routes.py | 8 +- cli/manager.py | 4 +- cli/process_registry.py | 3 +- cli/session.py | 6 +- messaging/event_parser.py | 4 +- messaging/factory.py | 5 +- messaging/handler.py | 8 +- messaging/limiter.py | 4 +- messaging/session.py | 4 +- messaging/telegram.py | 4 +- messaging/transcript.py | 5 +- messaging/tree_data.py | 4 +- messaging/tree_processor.py | 4 +- messaging/tree_queue.py | 4 +- messaging/tree_repository.py | 5 +- providers/lmstudio/client.py | 70 +----- providers/lmstudio/request.py | 3 +- providers/logging_utils.py | 8 +- providers/nvidia_nim/client.py | 86 +------ providers/nvidia_nim/request.py | 4 +- providers/nvidia_nim/response.py | 88 ------- providers/nvidia_nim/utils/__init__.py | 4 - .../nvidia_nim/utils/heuristic_tool_parser.py | 3 +- providers/nvidia_nim/utils/sse_builder.py | 76 +++++- providers/nvidia_nim/utils/think_parser.py | 44 +--- providers/open_router/client.py | 70 +----- providers/open_router/request.py | 3 +- providers/rate_limit.py | 4 +- tests/conftest.py | 22 ++ tests/test_api.py | 2 +- tests/test_limiter.py | 3 - tests/test_lmstudio.py | 5 +- tests/test_logging_config.py | 1 - tests/test_messaging.py | 1 - tests/test_messaging_factory.py | 1 - tests/test_nvidia_nim.py | 1 - tests/test_provider_rate_limit.py | 4 - tests/test_reliability.py | 2 - tests/test_response_conversion.py | 216 ------------------ tests/test_response_models.py | 2 - tests/test_restart_reply_restore.py | 2 +- tests/test_server_module.py | 2 +- tests/test_session_store_edge_cases.py | 1 - tests/test_sse_builder.py | 1 - tests/test_streaming_errors.py | 4 +- tests/test_subagent_interception.py | 21 +- tests/test_tree_concurrency.py | 1 - 53 files changed, 172 insertions(+), 683 deletions(-) delete mode 100644 providers/nvidia_nim/response.py delete mode 100644 tests/test_response_conversion.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 12f658e..97df164 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,6 +38,9 @@ jobs: enable-cache: true cache-python: true + - name: Lint + run: uv run ruff check + - name: Type check run: uv run ty check diff --git a/api/app.py b/api/app.py index 4981335..16fa3b8 100644 --- a/api/app.py +++ b/api/app.py @@ -6,7 +6,6 @@ import os # Opt-in to future behavior for python-telegram-bot os.environ["PTB_TIMEDELTA"] = "1" -import logging from contextlib import asynccontextmanager from fastapi import FastAPI, Request from fastapi.responses import JSONResponse @@ -18,11 +17,12 @@ from config.logging_config import configure_logging _settings = get_settings() configure_logging(_settings.log_file) +from loguru import logger + from .routes import router from .dependencies import cleanup_provider from providers.exceptions import ProviderError -logger = logging.getLogger(__name__) _SHUTDOWN_TIMEOUT_S = 5.0 diff --git a/api/dependencies.py b/api/dependencies.py index b6e7290..2007b1b 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -1,12 +1,12 @@ """Dependency injection for FastAPI.""" -import logging from typing import Optional +from loguru import logger + from config.settings import Settings, get_settings as _get_settings, NVIDIA_NIM_BASE_URL from providers.base import BaseProvider, ProviderConfig -logger = logging.getLogger(__name__) # Global provider instance (singleton) _provider: Optional[BaseProvider] = None diff --git a/api/models/anthropic.py b/api/models/anthropic.py index b2677f4..dbc0bb9 100644 --- a/api/models/anthropic.py +++ b/api/models/anthropic.py @@ -1,6 +1,5 @@ """Pydantic models for Anthropic-compatible requests.""" -import logging from enum import Enum from typing import List, Dict, Any, Optional, Union, Literal @@ -8,8 +7,7 @@ from pydantic import BaseModel, field_validator, model_validator from config.settings import get_settings from providers.model_utils import normalize_model_name - -logger = logging.getLogger(__name__) +from loguru import logger # ============================================================================= diff --git a/api/optimization_handlers.py b/api/optimization_handlers.py index 86fcd1e..0bc1b9e 100644 --- a/api/optimization_handlers.py +++ b/api/optimization_handlers.py @@ -4,7 +4,6 @@ Each handler returns a MessagesResponse if the request matches and the optimization is enabled, otherwise None. """ -import logging import uuid from typing import Optional @@ -19,8 +18,7 @@ from .detection import ( ) from .command_utils import extract_command_prefix, extract_filepaths_from_command from config.settings import Settings - -logger = logging.getLogger(__name__) +from loguru import logger def try_prefix_detection( diff --git a/api/request_utils.py b/api/request_utils.py index 83df194..8eee54d 100644 --- a/api/request_utils.py +++ b/api/request_utils.py @@ -4,21 +4,20 @@ Contains token counting for API requests. """ import json -import logging -from typing import Any, List, Optional, Union, cast +from typing import Any, List, Optional, Union import tiktoken +from loguru import logger -logger = logging.getLogger(__name__) ENCODER = tiktoken.get_encoding("cl100k_base") __all__ = ["get_token_count"] -def _get_block_attr(block: object, key: str, default: Any = "") -> Any: +def _get_block_attr(block: Any, key: str, default: Any = "") -> Any: """Get attribute from block (object or dict).""" if isinstance(block, dict): - return cast(dict[str, Any], block).get(key, default) + return block.get(key, default) return getattr(block, key, default) diff --git a/api/routes.py b/api/routes.py index 60743de..cff45d5 100644 --- a/api/routes.py +++ b/api/routes.py @@ -1,15 +1,14 @@ """FastAPI route handlers.""" import json -import logging import uuid from fastapi import APIRouter, Request, Depends, HTTPException -from loguru import logger as loguru_logger +from loguru import logger from fastapi.responses import StreamingResponse from .models.anthropic import MessagesRequest, TokenCountRequest -from .models.responses import MessagesResponse, TokenCountResponse +from .models.responses import TokenCountResponse from .dependencies import get_provider, get_settings from .request_utils import get_token_count from .optimization_handlers import try_optimizations @@ -18,7 +17,6 @@ from providers.base import BaseProvider from providers.exceptions import ProviderError from providers.logging_utils import build_request_summary, log_request_compact -logger = logging.getLogger(__name__) router = APIRouter() @@ -75,7 +73,7 @@ async def create_message( async def count_tokens(request_data: TokenCountRequest): """Count tokens for a request.""" request_id = f"req_{uuid.uuid4().hex[:12]}" - with loguru_logger.contextualize(request_id=request_id): + with logger.contextualize(request_id=request_id): try: tokens = get_token_count( request_data.messages, request_data.system, request_data.tools diff --git a/cli/manager.py b/cli/manager.py index b60b4c8..2d24a8c 100644 --- a/cli/manager.py +++ b/cli/manager.py @@ -8,12 +8,10 @@ simultaneously in separate CLI processes. import asyncio import uuid -import logging from typing import Dict, Optional, Tuple, List from .session import CLISession - -logger = logging.getLogger(__name__) +from loguru import logger class CLISessionManager: diff --git a/cli/process_registry.py b/cli/process_registry.py index 5a7a1bd..cc3510a 100644 --- a/cli/process_registry.py +++ b/cli/process_registry.py @@ -8,13 +8,12 @@ spawn so we don't accidentally kill unrelated system processes. from __future__ import annotations import atexit -import logging import os import subprocess import threading from typing import Set +from loguru import logger -logger = logging.getLogger(__name__) _lock = threading.Lock() _pids: Set[int] = set() diff --git a/cli/session.py b/cli/session.py index 6ba8e4e..334f2c7 100644 --- a/cli/session.py +++ b/cli/session.py @@ -3,12 +3,10 @@ import asyncio import os import json -import logging -from typing import AsyncGenerator, Optional, Dict, List, Any +from typing import AsyncGenerator, Optional, List, Any from .process_registry import register_pid, unregister_pid - -logger = logging.getLogger(__name__) +from loguru import logger class CLISession: diff --git a/messaging/event_parser.py b/messaging/event_parser.py index 7d9ffcb..c82f96e 100644 --- a/messaging/event_parser.py +++ b/messaging/event_parser.py @@ -4,10 +4,8 @@ This parser emits an ordered stream of low-level events suitable for building a Claude Code-like transcript in messaging UIs. """ -import logging from typing import Dict, List, Any - -logger = logging.getLogger(__name__) +from loguru import logger def parse_cli_event(event: Any) -> List[Dict]: diff --git a/messaging/factory.py b/messaging/factory.py index dcec356..18ccc6f 100644 --- a/messaging/factory.py +++ b/messaging/factory.py @@ -6,12 +6,11 @@ To add a new platform (e.g. Discord, Slack): 2. Add a case to create_messaging_platform() below """ -import logging from typing import Optional -from .base import MessagingPlatform +from loguru import logger -logger = logging.getLogger(__name__) +from .base import MessagingPlatform def create_messaging_platform( diff --git a/messaging/handler.py b/messaging/handler.py index 978a6ba..c21cca6 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -8,7 +8,6 @@ Uses tree-based queuing for message ordering. import time import asyncio -import logging import os from typing import List, Optional, Tuple @@ -26,9 +25,8 @@ from .telegram_markdown import ( format_status, render_markdown_to_mdv2, ) -from loguru import logger as loguru_logger +from loguru import logger -logger = logging.getLogger(__name__) # Status message prefixes used to filter our own messages (ignore echo) STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄") @@ -122,7 +120,7 @@ class ClaudeMessageHandler: text_preview, ) - with loguru_logger.contextualize( + with logger.contextualize( chat_id=incoming.chat_id, node_id=incoming.message_id ): await self._handle_message_impl(incoming) @@ -388,7 +386,7 @@ class ClaudeMessageHandler: status_msg_id = node.status_message_id chat_id = incoming.chat_id - with loguru_logger.contextualize(node_id=node_id, chat_id=chat_id): + with logger.contextualize(node_id=node_id, chat_id=chat_id): await self._process_node_impl(node_id, node, chat_id, status_msg_id) async def _process_node_impl( diff --git a/messaging/limiter.py b/messaging/limiter.py index bb51297..4057ceb 100644 --- a/messaging/limiter.py +++ b/messaging/limiter.py @@ -6,13 +6,11 @@ using a strict sliding window algorithm and a task queue. """ import asyncio -import logging import os from collections import deque import time from typing import Awaitable, Callable, Any, Optional, Dict - -logger = logging.getLogger(__name__) +from loguru import logger class SlidingWindowLimiter: diff --git a/messaging/session.py b/messaging/session.py index d368f4a..cc19eb8 100644 --- a/messaging/session.py +++ b/messaging/session.py @@ -7,13 +7,11 @@ and message trees for conversation continuation. import json import os -import logging from datetime import datetime, timezone from typing import Optional, Dict, List, Any from dataclasses import dataclass, asdict import threading - -logger = logging.getLogger(__name__) +from loguru import logger @dataclass diff --git a/messaging/telegram.py b/messaging/telegram.py index fcce494..24f3ac4 100644 --- a/messaging/telegram.py +++ b/messaging/telegram.py @@ -5,7 +5,6 @@ Implements MessagingPlatform for Telegram using python-telegram-bot. """ import asyncio -import logging import os # Opt-in to future behavior for python-telegram-bot (retry_after as timedelta) @@ -14,11 +13,12 @@ os.environ["PTB_TIMEDELTA"] = "1" from typing import Callable, Awaitable, Optional, Any +from loguru import logger + from .base import MessagingPlatform from .models import IncomingMessage from .telegram_markdown import escape_md_v2 -logger = logging.getLogger(__name__) # Optional import - python-telegram-bot may not be installed try: diff --git a/messaging/transcript.py b/messaging/transcript.py index 26e7fc5..55da4a4 100644 --- a/messaging/transcript.py +++ b/messaging/transcript.py @@ -9,14 +9,11 @@ the transcript grows over time and older content must be truncated. from __future__ import annotations import json -import logging import os from collections import deque from dataclasses import dataclass, field from typing import Any, Callable, Dict, Iterable, List, Optional - - -logger = logging.getLogger(__name__) +from loguru import logger def _safe_json_dumps(obj: Any) -> str: diff --git a/messaging/tree_data.py b/messaging/tree_data.py index e38f44c..ef6a2bf 100644 --- a/messaging/tree_data.py +++ b/messaging/tree_data.py @@ -4,7 +4,6 @@ Contains MessageState, MessageNode, and MessageTree classes. """ import asyncio -import logging from collections import deque from contextlib import asynccontextmanager from enum import Enum @@ -13,8 +12,7 @@ from typing import Dict, Optional, List, Any, cast from dataclasses import dataclass, field from .models import IncomingMessage - -logger = logging.getLogger(__name__) +from loguru import logger class MessageState(Enum): diff --git a/messaging/tree_processor.py b/messaging/tree_processor.py index 503176e..9ecf09e 100644 --- a/messaging/tree_processor.py +++ b/messaging/tree_processor.py @@ -4,12 +4,10 @@ Handles the async processing lifecycle of tree nodes. """ import asyncio -import logging from typing import Callable, Awaitable, Optional from .tree_data import MessageTree, MessageNode, MessageState - -logger = logging.getLogger(__name__) +from loguru import logger class TreeQueueProcessor: diff --git a/messaging/tree_queue.py b/messaging/tree_queue.py index 354e8c0..d58d9b2 100644 --- a/messaging/tree_queue.py +++ b/messaging/tree_queue.py @@ -5,7 +5,6 @@ Uses TreeRepository for data, TreeQueueProcessor for async logic. """ import asyncio -import logging from datetime import datetime, timezone from typing import Callable, Awaitable, List, Optional @@ -13,6 +12,7 @@ from .models import IncomingMessage from .tree_data import MessageState, MessageNode, MessageTree from .tree_repository import TreeRepository from .tree_processor import TreeQueueProcessor +from loguru import logger # Backward compatibility: re-export moved classes __all__ = [ @@ -22,8 +22,6 @@ __all__ = [ "MessageTree", ] -logger = logging.getLogger(__name__) - class TreeQueueManager: """ diff --git a/messaging/tree_repository.py b/messaging/tree_repository.py index 254ba70..3f3fa47 100644 --- a/messaging/tree_repository.py +++ b/messaging/tree_repository.py @@ -3,12 +3,11 @@ Provides data access layer for managing trees and node mappings. """ -import logging from typing import Dict, Optional, List -from .tree_data import MessageTree, MessageNode, MessageState +from loguru import logger -logger = logging.getLogger(__name__) +from .tree_data import MessageTree, MessageNode, MessageState class TreeRepository: diff --git a/providers/lmstudio/client.py b/providers/lmstudio/client.py index dfb9f7f..921ff82 100644 --- a/providers/lmstudio/client.py +++ b/providers/lmstudio/client.py @@ -1,11 +1,10 @@ """LM Studio provider implementation.""" import json -import logging import uuid from typing import Any, AsyncIterator -from loguru import logger as loguru_logger +from loguru import logger from openai import AsyncOpenAI from providers.base import BaseProvider, ProviderConfig @@ -21,7 +20,6 @@ from providers.nvidia_nim.utils import ( from .request import build_request_body -logger = logging.getLogger(__name__) LMSTUDIO_DEFAULT_BASE_URL = "http://localhost:1234/v1" @@ -56,7 +54,7 @@ class LMStudioProvider(BaseProvider): request_id: str | None = None, ) -> AsyncIterator[str]: """Stream response in Anthropic SSE format.""" - with loguru_logger.contextualize(request_id=request_id): + with logger.contextualize(request_id=request_id): async for event in self._stream_response_impl( request, input_tokens, request_id ): @@ -266,20 +264,7 @@ class LMStudioProvider(BaseProvider): fn_delta = tc.get("function", {}) incoming_name = fn_delta.get("name") if incoming_name is not None: - prev = sse.blocks.tool_names.get(tc_index, "") - if not prev: - sse.blocks.tool_names[tc_index] = incoming_name - elif prev == incoming_name: - pass - elif isinstance(prev, str) and isinstance(incoming_name, str): - if incoming_name.startswith(prev): - sse.blocks.tool_names[tc_index] = incoming_name - elif prev.startswith(incoming_name): - pass - else: - sse.blocks.tool_names[tc_index] = prev + incoming_name - else: - sse.blocks.tool_names[tc_index] = str(prev) + str(incoming_name) + sse.blocks.register_tool_name(tc_index, incoming_name) if tc_index not in sse.blocks.tool_indices: name = sse.blocks.tool_names.get(tc_index, "") @@ -305,55 +290,14 @@ class LMStudioProvider(BaseProvider): current_name = sse.blocks.tool_names.get(tc_index, "") if current_name == "Task": - if not sse.blocks.task_args_emitted.get(tc_index, False): - buf = sse.blocks.task_arg_buffer.get(tc_index, "") + args - sse.blocks.task_arg_buffer[tc_index] = buf - try: - args_json = json.loads(buf) - except Exception: - return - if args_json.get("run_in_background") is not False: - logger.info( - "LMSTUDIO_INTERCEPT: Forcing run_in_background=False for Task %s", - tc.get("id") - or sse.blocks.tool_ids.get(tc_index, "unknown"), - ) - args_json["run_in_background"] = False - sse.blocks.task_args_emitted[tc_index] = True - sse.blocks.task_arg_buffer.pop(tc_index, None) - yield sse.emit_tool_delta(tc_index, json.dumps(args_json)) + parsed = sse.blocks.buffer_task_args(tc_index, args) + if parsed is not None: + yield sse.emit_tool_delta(tc_index, json.dumps(parsed)) return yield sse.emit_tool_delta(tc_index, args) def _flush_task_arg_buffers(self, sse: Any): """Emit buffered Task args as a single JSON delta (best-effort).""" - for tool_index, buf in list(sse.blocks.task_arg_buffer.items()): - if sse.blocks.task_args_emitted.get(tool_index, False): - sse.blocks.task_arg_buffer.pop(tool_index, None) - continue - - tool_id = sse.blocks.tool_ids.get(tool_index, "unknown") - out = "{}" - try: - args_json = json.loads(buf) - if args_json.get("run_in_background") is not False: - logger.info( - "LMSTUDIO_INTERCEPT: Forcing run_in_background=False for Task %s", - tool_id, - ) - args_json["run_in_background"] = False - out = json.dumps(args_json) - except Exception as e: - prefix = buf[:120] - logger.warning( - "LMSTUDIO_INTERCEPT: Task args invalid JSON (id=%s len=%d prefix=%r): %s", - tool_id, - len(buf), - prefix, - e, - ) - - sse.blocks.task_args_emitted[tool_index] = True - sse.blocks.task_arg_buffer.pop(tool_index, None) + for tool_index, out in sse.blocks.flush_task_arg_buffers(): yield sse.emit_tool_delta(tool_index, out) diff --git a/providers/lmstudio/request.py b/providers/lmstudio/request.py index 32f44a6..aac8cec 100644 --- a/providers/lmstudio/request.py +++ b/providers/lmstudio/request.py @@ -1,11 +1,10 @@ """Request builder for LM Studio provider.""" -import logging from typing import Any, Dict from providers.nvidia_nim.utils.message_converter import AnthropicToOpenAIConverter +from loguru import logger -logger = logging.getLogger(__name__) LMSTUDIO_DEFAULT_MAX_TOKENS = 81920 diff --git a/providers/logging_utils.py b/providers/logging_utils.py index 6401cb3..aebd28a 100644 --- a/providers/logging_utils.py +++ b/providers/logging_utils.py @@ -6,12 +6,10 @@ while maintaining full traceability through request IDs and content hashes. import hashlib import json -import logging from typing import Any, Dict, List from utils.text import extract_text_from_content - -logger = logging.getLogger(__name__) +from loguru import logger def generate_request_fingerprint(messages: List[Any]) -> str: @@ -99,7 +97,7 @@ def build_request_summary(request_data: Any) -> Dict[str, Any]: def log_full_payload( - logger_instance: logging.Logger, request_id: str, payload: Dict[str, Any] + logger_instance: Any, request_id: str, payload: Dict[str, Any] ) -> None: """Log full payload to the standard logger.""" logger_instance.debug( @@ -108,7 +106,7 @@ def log_full_payload( def log_request_compact( - logger_instance: logging.Logger, + logger_instance: Any, request_id: str, request_data: Any, prefix: str = "API_REQUEST", diff --git a/providers/nvidia_nim/client.py b/providers/nvidia_nim/client.py index 8ff75f6..a47c4e6 100644 --- a/providers/nvidia_nim/client.py +++ b/providers/nvidia_nim/client.py @@ -1,11 +1,10 @@ """NVIDIA NIM provider implementation.""" import json -import logging import uuid from typing import Any, AsyncIterator -from loguru import logger as loguru_logger +from loguru import logger from openai import AsyncOpenAI from providers.base import BaseProvider, ProviderConfig @@ -20,8 +19,6 @@ from .utils import ( ContentType, ) -logger = logging.getLogger(__name__) - class NvidiaNimProvider(BaseProvider): """NVIDIA NIM provider using official OpenAI client.""" @@ -56,7 +53,7 @@ class NvidiaNimProvider(BaseProvider): request_id: str | None = None, ) -> AsyncIterator[str]: """Stream response in Anthropic SSE format.""" - with loguru_logger.contextualize(request_id=request_id): + with logger.contextualize(request_id=request_id): async for event in self._stream_response_impl( request, input_tokens, request_id ): @@ -260,13 +257,7 @@ class NvidiaNimProvider(BaseProvider): yield sse.done() def _process_tool_call(self, tc: dict, sse: Any): - """Process a single tool call delta and yield SSE events. - - Args: - tc: Tool call delta info dict - sse: SSEBuilder instance - """ - + """Process a single tool call delta and yield SSE events.""" tc_index = tc.get("index", 0) if tc_index < 0: tc_index = len(sse.blocks.tool_indices) @@ -274,22 +265,7 @@ class NvidiaNimProvider(BaseProvider): fn_delta = tc.get("function", {}) incoming_name = fn_delta.get("name") if incoming_name is not None: - # Some providers stream tool names as fragments; others resend the full name. - # Avoid "TaskTask" while still supporting fragment streams. - prev = sse.blocks.tool_names.get(tc_index, "") - if not prev: - sse.blocks.tool_names[tc_index] = incoming_name - elif prev == incoming_name: - pass - elif isinstance(prev, str) and isinstance(incoming_name, str): - if incoming_name.startswith(prev): - sse.blocks.tool_names[tc_index] = incoming_name - elif prev.startswith(incoming_name): - pass - else: - sse.blocks.tool_names[tc_index] = prev + incoming_name - else: - sse.blocks.tool_names[tc_index] = str(prev) + str(incoming_name) + sse.blocks.register_tool_name(tc_index, incoming_name) if tc_index not in sse.blocks.tool_indices: name = sse.blocks.tool_names.get(tc_index, "") @@ -310,65 +286,19 @@ class NvidiaNimProvider(BaseProvider): if not sse.blocks.tool_started.get(tc_index): tool_id = tc.get("id") or f"tool_{uuid.uuid4()}" name = sse.blocks.tool_names.get(tc_index, "tool_call") or "tool_call" - yield sse.start_tool_block(tc_index, tool_id, name) sse.blocks.tool_started[tc_index] = True current_name = sse.blocks.tool_names.get(tc_index, "") - # INTERCEPTION: Task args can stream in many partial chunks. Buffer until we - # have valid JSON, then emit a single delta with run_in_background forced off. if current_name == "Task": - if not sse.blocks.task_args_emitted.get(tc_index, False): - buf = sse.blocks.task_arg_buffer.get(tc_index, "") + args - sse.blocks.task_arg_buffer[tc_index] = buf - try: - args_json = json.loads(buf) - except Exception: - return - if args_json.get("run_in_background") is not False: - logger.info( - "NIM_INTERCEPT: Forcing run_in_background=False for Task %s", - ( - tc.get("id") - or sse.blocks.tool_ids.get(tc_index, "unknown") - ), - ) - args_json["run_in_background"] = False - sse.blocks.task_args_emitted[tc_index] = True - sse.blocks.task_arg_buffer.pop(tc_index, None) - yield sse.emit_tool_delta(tc_index, json.dumps(args_json)) + parsed = sse.blocks.buffer_task_args(tc_index, args) + if parsed is not None: + yield sse.emit_tool_delta(tc_index, json.dumps(parsed)) return yield sse.emit_tool_delta(tc_index, args) def _flush_task_arg_buffers(self, sse: Any): """Emit buffered Task args as a single JSON delta (best-effort).""" - for tool_index, buf in list(sse.blocks.task_arg_buffer.items()): - if sse.blocks.task_args_emitted.get(tool_index, False): - sse.blocks.task_arg_buffer.pop(tool_index, None) - continue - - tool_id = sse.blocks.tool_ids.get(tool_index, "unknown") - out = "{}" - try: - args_json = json.loads(buf) - if args_json.get("run_in_background") is not False: - logger.info( - "NIM_INTERCEPT: Forcing run_in_background=False for Task %s", - tool_id, - ) - args_json["run_in_background"] = False - out = json.dumps(args_json) - except Exception as e: - prefix = buf[:120] - logger.warning( - "NIM_INTERCEPT: Task args invalid JSON (id=%s len=%d prefix=%r): %s", - tool_id, - len(buf), - prefix, - e, - ) - - sse.blocks.task_args_emitted[tool_index] = True - sse.blocks.task_arg_buffer.pop(tool_index, None) + for tool_index, out in sse.blocks.flush_task_arg_buffers(): yield sse.emit_tool_delta(tool_index, out) diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index 7f376f8..3c9901a 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -1,12 +1,10 @@ """Request builder for NVIDIA NIM provider.""" -import logging from typing import Any, Dict from config.nim import NimSettings from .utils.message_converter import AnthropicToOpenAIConverter - -logger = logging.getLogger(__name__) +from loguru import logger def _set_if_not_none(body: Dict[str, Any], key: str, value: Any) -> None: diff --git a/providers/nvidia_nim/response.py b/providers/nvidia_nim/response.py deleted file mode 100644 index 3ddf653..0000000 --- a/providers/nvidia_nim/response.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Response conversion for NVIDIA NIM provider.""" - -import json -import uuid -from typing import Any - -from .utils import map_stop_reason, extract_think_content_interleaved - - -def convert_response(response_json: dict, original_request: Any) -> dict: - """Convert OpenAI response to Anthropic format.""" - choice = response_json["choices"][0] - message = choice["message"] - content = [] - - # Extract reasoning from various sources - reasoning = message.get("reasoning_content") - if not reasoning: - reasoning_details = message.get("reasoning_details") - if reasoning_details and isinstance(reasoning_details, list): - reasoning = "\n".join( - item.get("text", "") - for item in reasoning_details - if isinstance(item, dict) - ) - - if reasoning: - content.append({"type": "thinking", "thinking": reasoning}) - - # Extract text content (with think tag handling, preserving interleaving) - if message.get("content"): - raw_content = message["content"] - if isinstance(raw_content, str): - if not reasoning: - for block_type, block_content in extract_think_content_interleaved( - raw_content - ): - if block_type == "thinking": - content.append({"type": "thinking", "thinking": block_content}) - else: - content.append({"type": "text", "text": block_content}) - else: - if raw_content.strip(): - content.append({"type": "text", "text": raw_content.strip()}) - elif isinstance(raw_content, list): - for item in raw_content: - if isinstance(item, dict) and item.get("type") == "text": - content.append(item) - - # Extract tool calls - if message.get("tool_calls"): - for tc in message["tool_calls"]: - try: - args = json.loads(tc["function"]["arguments"]) - except Exception: - args = tc["function"].get("arguments", {}) - content.append( - { - "type": "tool_use", - "id": tc["id"], - "name": tc["function"]["name"], - "input": args, - } - ) - - if not content: - # NIM models (especially Mistral-based) often require non-empty content. - # Adding a single space satisfies this requirement while avoiding - # the "(no content)" display issue in Claude Code. - content.append({"type": "text", "text": " "}) - - usage = response_json.get("usage", {}) - - return { - "id": response_json.get("id", f"msg_{uuid.uuid4()}"), - "type": "message", - "role": "assistant", - "model": original_request.model, - "content": content, - "stop_reason": map_stop_reason(choice.get("finish_reason")), - "stop_sequence": None, - "usage": { - "input_tokens": usage.get("prompt_tokens", 0), - "output_tokens": usage.get("completion_tokens", 0), - "cache_creation_input_tokens": 0, - "cache_read_input_tokens": 0, - }, - } diff --git a/providers/nvidia_nim/utils/__init__.py b/providers/nvidia_nim/utils/__init__.py index 5e909b6..634b6e5 100644 --- a/providers/nvidia_nim/utils/__init__.py +++ b/providers/nvidia_nim/utils/__init__.py @@ -5,8 +5,6 @@ from .think_parser import ( ThinkTagParser, ContentType, ContentChunk, - extract_think_content, - extract_think_content_interleaved, ) from .heuristic_tool_parser import HeuristicToolParser from .message_converter import ( @@ -23,8 +21,6 @@ __all__ = [ "HeuristicToolParser", "ContentType", "ContentChunk", - "extract_think_content", - "extract_think_content_interleaved", "AnthropicToOpenAIConverter", "get_block_attr", "get_block_type", diff --git a/providers/nvidia_nim/utils/heuristic_tool_parser.py b/providers/nvidia_nim/utils/heuristic_tool_parser.py index 0c9b39c..bdc78f5 100644 --- a/providers/nvidia_nim/utils/heuristic_tool_parser.py +++ b/providers/nvidia_nim/utils/heuristic_tool_parser.py @@ -1,10 +1,9 @@ import re -import logging import uuid from enum import Enum from typing import List, Dict, Any, Tuple +from loguru import logger -logger = logging.getLogger(__name__) # Some OpenAI-compatible backends/models occasionally leak internal sentinel tokens # into `delta.content` (e.g. "<|tool_call_end|>"). These should never be shown to diff --git a/providers/nvidia_nim/utils/sse_builder.py b/providers/nvidia_nim/utils/sse_builder.py index eaf864d..80bebbb 100644 --- a/providers/nvidia_nim/utils/sse_builder.py +++ b/providers/nvidia_nim/utils/sse_builder.py @@ -1,9 +1,10 @@ """SSE event builder for Anthropic-format streaming responses.""" import json -import logging from dataclasses import dataclass, field -from typing import Optional, Dict, Any, Iterator +from typing import Optional, Dict, Any, Iterator, List, Tuple + +from loguru import logger try: import tiktoken @@ -12,7 +13,6 @@ try: except Exception: ENCODER = None -logger = logging.getLogger(__name__) # Map OpenAI finish_reason to Anthropic stop_reason STOP_REASON_MAP = { @@ -54,6 +54,76 @@ class ContentBlockManager: self.next_index += 1 return idx + def register_tool_name(self, index: int, name: str) -> None: + """Register or merge a streaming tool name fragment. + + Handles providers that stream names as fragments and those that + resend the full name on every chunk. + """ + prev = self.tool_names.get(index, "") + if not prev: + self.tool_names[index] = name + elif prev == name: + pass + elif name.startswith(prev): + self.tool_names[index] = name + elif prev.startswith(name): + pass + else: + self.tool_names[index] = prev + name + + def buffer_task_args(self, index: int, args: str) -> Optional[dict]: + """Buffer Task tool args and return parsed JSON when complete. + + Returns the parsed (and patched) args dict once the buffer forms + valid JSON, or None if still accumulating. + """ + if self.task_args_emitted.get(index, False): + return None + + buf = self.task_arg_buffer.get(index, "") + args + self.task_arg_buffer[index] = buf + try: + args_json = json.loads(buf) + except Exception: + return None + + if args_json.get("run_in_background") is not False: + args_json["run_in_background"] = False + + self.task_args_emitted[index] = True + self.task_arg_buffer.pop(index, None) + return args_json + + def flush_task_arg_buffers(self) -> List[Tuple[int, str]]: + """Flush any remaining Task arg buffers. Returns (tool_index, json_str) pairs.""" + results: List[Tuple[int, str]] = [] + for tool_index, buf in list(self.task_arg_buffer.items()): + if self.task_args_emitted.get(tool_index, False): + self.task_arg_buffer.pop(tool_index, None) + continue + + out = "{}" + try: + args_json = json.loads(buf) + if args_json.get("run_in_background") is not False: + args_json["run_in_background"] = False + out = json.dumps(args_json) + except Exception as e: + prefix = buf[:120] + logger.warning( + "Task args invalid JSON (id=%s len=%d prefix=%r): %s", + self.tool_ids.get(tool_index, "unknown"), + len(buf), + prefix, + e, + ) + + self.task_args_emitted[tool_index] = True + self.task_arg_buffer.pop(tool_index, None) + results.append((tool_index, out)) + return results + class SSEBuilder: """Builder for Anthropic SSE streaming events.""" diff --git a/providers/nvidia_nim/utils/think_parser.py b/providers/nvidia_nim/utils/think_parser.py index 7332787..923ad7d 100644 --- a/providers/nvidia_nim/utils/think_parser.py +++ b/providers/nvidia_nim/utils/think_parser.py @@ -1,8 +1,7 @@ """Think tag parser for extracting reasoning content from responses.""" -import re from dataclasses import dataclass -from typing import List, Optional, Tuple, Iterator +from typing import Optional, Iterator from enum import Enum @@ -166,44 +165,3 @@ class ThinkTagParser: """Reset parser state.""" self._buffer = "" self._in_think_tag = False - - -def extract_think_content(text: str) -> Tuple[Optional[str], str]: - """ - Extract thinking content from text (non-streaming). - - Returns: (thinking_content, remaining_text) - Merges all think blocks and strips them from text. Use extract_think_content_interleaved - when interleaved order must be preserved. - """ - think_pattern = re.compile(r"(.*?)", re.DOTALL) - matches = think_pattern.findall(text) - - if matches: - thinking = "\n".join(matches) - remaining = think_pattern.sub("", text).strip() - return thinking, remaining - - return None, text - - -def extract_think_content_interleaved(text: str) -> List[Tuple[str, str]]: - """ - Parse content and return blocks in order, preserving interleaving of - ... and text. - - Returns: [(type, content), ...] where type is "thinking" or "text". - """ - blocks: List[Tuple[str, str]] = [] - pattern = re.compile(r"(.*?)", re.DOTALL) - last_end = 0 - for m in pattern.finditer(text): - before = text[last_end : m.start()].strip() - if before: - blocks.append(("text", before)) - blocks.append(("thinking", m.group(1))) - last_end = m.end() - after = text[last_end:].strip() - if after: - blocks.append(("text", after)) - return blocks diff --git a/providers/open_router/client.py b/providers/open_router/client.py index 57bb8d5..36d8aec 100644 --- a/providers/open_router/client.py +++ b/providers/open_router/client.py @@ -1,11 +1,10 @@ """OpenRouter provider implementation.""" import json -import logging import uuid from typing import Any, AsyncIterator -from loguru import logger as loguru_logger +from loguru import logger from openai import AsyncOpenAI from providers.base import BaseProvider, ProviderConfig @@ -21,7 +20,6 @@ from providers.nvidia_nim.utils import ( from .request import build_request_body -logger = logging.getLogger(__name__) OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" @@ -56,7 +54,7 @@ class OpenRouterProvider(BaseProvider): request_id: str | None = None, ) -> AsyncIterator[str]: """Stream response in Anthropic SSE format.""" - with loguru_logger.contextualize(request_id=request_id): + with logger.contextualize(request_id=request_id): async for event in self._stream_response_impl( request, input_tokens, request_id ): @@ -277,20 +275,7 @@ class OpenRouterProvider(BaseProvider): fn_delta = tc.get("function", {}) incoming_name = fn_delta.get("name") if incoming_name is not None: - prev = sse.blocks.tool_names.get(tc_index, "") - if not prev: - sse.blocks.tool_names[tc_index] = incoming_name - elif prev == incoming_name: - pass - elif isinstance(prev, str) and isinstance(incoming_name, str): - if incoming_name.startswith(prev): - sse.blocks.tool_names[tc_index] = incoming_name - elif prev.startswith(incoming_name): - pass - else: - sse.blocks.tool_names[tc_index] = prev + incoming_name - else: - sse.blocks.tool_names[tc_index] = str(prev) + str(incoming_name) + sse.blocks.register_tool_name(tc_index, incoming_name) if tc_index not in sse.blocks.tool_indices: name = sse.blocks.tool_names.get(tc_index, "") @@ -316,55 +301,14 @@ class OpenRouterProvider(BaseProvider): current_name = sse.blocks.tool_names.get(tc_index, "") if current_name == "Task": - if not sse.blocks.task_args_emitted.get(tc_index, False): - buf = sse.blocks.task_arg_buffer.get(tc_index, "") + args - sse.blocks.task_arg_buffer[tc_index] = buf - try: - args_json = json.loads(buf) - except Exception: - return - if args_json.get("run_in_background") is not False: - logger.info( - "OPENROUTER_INTERCEPT: Forcing run_in_background=False for Task %s", - tc.get("id") - or sse.blocks.tool_ids.get(tc_index, "unknown"), - ) - args_json["run_in_background"] = False - sse.blocks.task_args_emitted[tc_index] = True - sse.blocks.task_arg_buffer.pop(tc_index, None) - yield sse.emit_tool_delta(tc_index, json.dumps(args_json)) + parsed = sse.blocks.buffer_task_args(tc_index, args) + if parsed is not None: + yield sse.emit_tool_delta(tc_index, json.dumps(parsed)) return yield sse.emit_tool_delta(tc_index, args) def _flush_task_arg_buffers(self, sse: Any): """Emit buffered Task args as a single JSON delta (best-effort).""" - for tool_index, buf in list(sse.blocks.task_arg_buffer.items()): - if sse.blocks.task_args_emitted.get(tool_index, False): - sse.blocks.task_arg_buffer.pop(tool_index, None) - continue - - tool_id = sse.blocks.tool_ids.get(tool_index, "unknown") - out = "{}" - try: - args_json = json.loads(buf) - if args_json.get("run_in_background") is not False: - logger.info( - "OPENROUTER_INTERCEPT: Forcing run_in_background=False for Task %s", - tool_id, - ) - args_json["run_in_background"] = False - out = json.dumps(args_json) - except Exception as e: - prefix = buf[:120] - logger.warning( - "OPENROUTER_INTERCEPT: Task args invalid JSON (id=%s len=%d prefix=%r): %s", - tool_id, - len(buf), - prefix, - e, - ) - - sse.blocks.task_args_emitted[tool_index] = True - sse.blocks.task_arg_buffer.pop(tool_index, None) + for tool_index, out in sse.blocks.flush_task_arg_buffers(): yield sse.emit_tool_delta(tool_index, out) diff --git a/providers/open_router/request.py b/providers/open_router/request.py index 244749b..984b316 100644 --- a/providers/open_router/request.py +++ b/providers/open_router/request.py @@ -1,11 +1,10 @@ """Request builder for OpenRouter provider.""" -import logging from typing import Any, Dict from providers.nvidia_nim.utils.message_converter import AnthropicToOpenAIConverter +from loguru import logger -logger = logging.getLogger(__name__) OPENROUTER_DEFAULT_MAX_TOKENS = 81920 diff --git a/providers/rate_limit.py b/providers/rate_limit.py index b343ee8..bcef02b 100644 --- a/providers/rate_limit.py +++ b/providers/rate_limit.py @@ -4,15 +4,13 @@ import asyncio from collections import deque import random import time -import logging from typing import Any, Callable, Optional, TypeVar import openai +from loguru import logger T = TypeVar("T") -logger = logging.getLogger(__name__) - class GlobalRateLimiter: """ diff --git a/tests/conftest.py b/tests/conftest.py index e4b2272..6f67211 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import logging import pytest import asyncio import os @@ -144,3 +145,24 @@ def incoming_message_factory(): return IncomingMessage(**filtered) return _create + + +@pytest.fixture(autouse=True) +def _propagate_loguru_to_caplog(): + """Route loguru logs to stdlib logging so pytest caplog captures them.""" + from loguru import logger as loguru_logger + + class _PropagateHandler: + def write(self, message): + record = message.record + level = record["level"].no + stdlib_level = min(level, logging.CRITICAL) + py_logger = logging.getLogger(record["name"]) + py_logger.log(stdlib_level, record["message"]) + + handler_id = loguru_logger.add(_PropagateHandler(), format="{message}") + yield + try: + loguru_logger.remove(handler_id) + except ValueError: + pass # Handler already removed (e.g. by test_logging_config tests) diff --git a/tests/test_api.py b/tests/test_api.py index 52c5f21..8c84537 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient from api.app import app from api.dependencies import get_provider -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock from providers.nvidia_nim import NvidiaNimProvider # Mock provider diff --git a/tests/test_limiter.py b/tests/test_limiter.py index 35e6e17..cacab32 100644 --- a/tests/test_limiter.py +++ b/tests/test_limiter.py @@ -3,7 +3,6 @@ import pytest_asyncio import asyncio import time import os -import logging # Set environment variables relative to test execution os.environ["MESSAGING_RATE_LIMIT"] = "1" @@ -11,8 +10,6 @@ os.environ["MESSAGING_RATE_WINDOW"] = "0.5" from messaging.limiter import MessagingRateLimiter -logger = logging.getLogger(__name__) - class TestMessagingRateLimiter: """Tests for MessagingRateLimiter.""" diff --git a/tests/test_lmstudio.py b/tests/test_lmstudio.py index dcb1c22..60d3605 100644 --- a/tests/test_lmstudio.py +++ b/tests/test_lmstudio.py @@ -539,10 +539,7 @@ class TestLMStudioProcessToolCall: flushed = list(lmstudio_provider._flush_task_arg_buffers(sse)) assert len(flushed) > 0 assert "{}" in "".join(flushed) - assert any( - "LMSTUDIO_INTERCEPT: Task args invalid JSON" in r.message - for r in caplog.records - ) + assert any("Task args invalid JSON" in r.message for r in caplog.records) def test_negative_tool_index_fallback(self, lmstudio_provider): """tc_index < 0 uses len(tool_indices) as fallback.""" diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py index 9051f80..18dcc11 100644 --- a/tests/test_logging_config.py +++ b/tests/test_logging_config.py @@ -4,7 +4,6 @@ import json import logging from pathlib import Path -import pytest from config.logging_config import configure_logging diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 1310c2d..f3c119c 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -2,7 +2,6 @@ import pytest import json -import os from datetime import datetime, timedelta, timezone from unittest.mock import patch diff --git a/tests/test_messaging_factory.py b/tests/test_messaging_factory.py index 8b630a5..1657445 100644 --- a/tests/test_messaging_factory.py +++ b/tests/test_messaging_factory.py @@ -1,6 +1,5 @@ """Tests for messaging platform factory.""" -import pytest from unittest.mock import patch, MagicMock from messaging.factory import create_messaging_platform diff --git a/tests/test_nvidia_nim.py b/tests/test_nvidia_nim.py index 4dc0c4a..14d967c 100644 --- a/tests/test_nvidia_nim.py +++ b/tests/test_nvidia_nim.py @@ -2,7 +2,6 @@ import pytest import json from unittest.mock import MagicMock, AsyncMock, patch from providers.nvidia_nim import NvidiaNimProvider -from providers.exceptions import APIError # Mock data classes diff --git a/tests/test_provider_rate_limit.py b/tests/test_provider_rate_limit.py index cd271cf..9a328d5 100644 --- a/tests/test_provider_rate_limit.py +++ b/tests/test_provider_rate_limit.py @@ -2,13 +2,9 @@ import pytest import pytest_asyncio import asyncio import time -import os -import logging from providers.rate_limit import GlobalRateLimiter -logger = logging.getLogger(__name__) - class TestProviderRateLimiter: """Tests for providers.rate_limit.GlobalRateLimiter.""" diff --git a/tests/test_reliability.py b/tests/test_reliability.py index d947b00..97847f6 100644 --- a/tests/test_reliability.py +++ b/tests/test_reliability.py @@ -2,8 +2,6 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch from messaging.telegram import TelegramPlatform from telegram.error import NetworkError, RetryAfter, TelegramError -from messaging.handler import ClaudeMessageHandler -from messaging.telegram_markdown import format_status @pytest.fixture diff --git a/tests/test_response_conversion.py b/tests/test_response_conversion.py deleted file mode 100644 index 79693ce..0000000 --- a/tests/test_response_conversion.py +++ /dev/null @@ -1,216 +0,0 @@ -"""Tests for providers/nvidia_nim/response.py response conversion.""" - -import pytest -from unittest.mock import MagicMock - -from providers.nvidia_nim.response import convert_response - - -def _make_response( - content="Hello", - finish_reason="stop", - tool_calls=None, - reasoning_content=None, - reasoning_details=None, - usage=None, -): - """Helper to build a minimal OpenAI-format response dict.""" - message = {"content": content, "role": "assistant"} - if tool_calls is not None: - message["tool_calls"] = tool_calls - if reasoning_content is not None: - message["reasoning_content"] = reasoning_content - if reasoning_details is not None: - message["reasoning_details"] = reasoning_details - - return { - "id": "chatcmpl-test", - "choices": [{"message": message, "finish_reason": finish_reason}], - "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, - } - - -def _make_request(model="test-model"): - req = MagicMock() - req.model = model - return req - - -class TestConvertResponse: - """Tests for convert_response function.""" - - def test_simple_text_response(self): - """Simple text content is preserved.""" - resp = _make_response(content="Hello world") - result = convert_response(resp, _make_request()) - assert result["content"] == [{"type": "text", "text": "Hello world"}] - assert result["stop_reason"] == "end_turn" - - def test_empty_content_gets_space(self): - """Empty content produces a single space text block.""" - resp = _make_response(content="") - result = convert_response(resp, _make_request()) - assert result["content"] == [{"type": "text", "text": " "}] - - def test_none_content_gets_space(self): - """None content produces a single space text block.""" - resp = _make_response(content=None) - result = convert_response(resp, _make_request()) - assert result["content"] == [{"type": "text", "text": " "}] - - def test_reasoning_content_field(self): - """reasoning_content field is extracted as thinking block.""" - resp = _make_response(content="Answer", reasoning_content="I need to think...") - result = convert_response(resp, _make_request()) - types = [b["type"] for b in result["content"]] - assert "thinking" in types - assert "text" in types - thinking = [b for b in result["content"] if b["type"] == "thinking"] - assert thinking[0]["thinking"] == "I need to think..." - - def test_reasoning_details_list(self): - """reasoning_details list is joined into thinking block.""" - resp = _make_response( - content="Answer", - reasoning_details=[ - {"text": "Step 1"}, - {"text": "Step 2"}, - ], - ) - result = convert_response(resp, _make_request()) - thinking = [b for b in result["content"] if b["type"] == "thinking"] - assert len(thinking) == 1 - assert "Step 1" in thinking[0]["thinking"] - assert "Step 2" in thinking[0]["thinking"] - - def test_content_with_think_tags(self): - """Think tags in content string are extracted when no reasoning field.""" - resp = _make_response(content="reasoning hereThe answer is 42") - result = convert_response(resp, _make_request()) - types = [b["type"] for b in result["content"]] - assert "thinking" in types - text_blocks = [b for b in result["content"] if b["type"] == "text"] - assert any("42" in b["text"] for b in text_blocks) - - def test_think_tags_skipped_when_reasoning_exists(self): - """When reasoning_content exists, think tags in content are NOT re-extracted.""" - resp = _make_response( - content="duplicateAnswer", - reasoning_content="Real reasoning", - ) - result = convert_response(resp, _make_request()) - thinking = [b for b in result["content"] if b["type"] == "thinking"] - # Only the reasoning_content should be in thinking, not duplicate extraction - assert len(thinking) == 1 - assert thinking[0]["thinking"] == "Real reasoning" - - def test_content_as_list(self): - """Content as list of dicts is preserved.""" - resp = _make_response( - content=[ - {"type": "text", "text": "Hello"}, - {"type": "text", "text": "World"}, - ] - ) - result = convert_response(resp, _make_request()) - text_blocks = [b for b in result["content"] if b["type"] == "text"] - assert len(text_blocks) == 2 - - def test_tool_call_valid_json(self): - """Tool calls with valid JSON arguments are parsed.""" - resp = _make_response( - content="", - tool_calls=[ - { - "id": "call_1", - "function": { - "name": "search", - "arguments": '{"query": "test"}', - }, - } - ], - ) - result = convert_response(resp, _make_request()) - tool_blocks = [b for b in result["content"] if b["type"] == "tool_use"] - assert len(tool_blocks) == 1 - assert tool_blocks[0]["input"] == {"query": "test"} - assert tool_blocks[0]["name"] == "search" - - def test_tool_call_invalid_json_fallback(self): - """Tool call with non-JSON arguments falls back to raw value.""" - resp = _make_response( - content="", - tool_calls=[ - { - "id": "call_2", - "function": { - "name": "test", - "arguments": "not valid json {", - }, - } - ], - ) - result = convert_response(resp, _make_request()) - tool_blocks = [b for b in result["content"] if b["type"] == "tool_use"] - assert len(tool_blocks) == 1 - assert tool_blocks[0]["input"] == "not valid json {" - - def test_usage_mapping(self): - """Usage tokens are mapped from OpenAI to Anthropic format.""" - resp = _make_response(usage={"prompt_tokens": 100, "completion_tokens": 50}) - result = convert_response(resp, _make_request()) - assert result["usage"]["input_tokens"] == 100 - assert result["usage"]["output_tokens"] == 50 - assert result["usage"]["cache_creation_input_tokens"] == 0 - assert result["usage"]["cache_read_input_tokens"] == 0 - - def test_missing_usage(self): - """Missing usage defaults to zeros.""" - resp = _make_response() - resp.pop("usage") - result = convert_response(resp, _make_request()) - assert result["usage"]["input_tokens"] == 0 - assert result["usage"]["output_tokens"] == 0 - - @pytest.mark.parametrize( - "finish_reason,expected_stop", - [ - ("stop", "end_turn"), - ("length", "max_tokens"), - ("tool_calls", "tool_use"), - ("content_filter", "end_turn"), - (None, "end_turn"), - ], - ids=["stop", "length", "tool_calls", "content_filter", "none"], - ) - def test_stop_reason_mapping(self, finish_reason, expected_stop): - """Finish reasons map correctly to Anthropic stop_reasons.""" - resp = _make_response(finish_reason=finish_reason) - result = convert_response(resp, _make_request()) - assert result["stop_reason"] == expected_stop - - def test_model_from_request(self): - """Response model comes from original request, not provider response.""" - resp = _make_response() - result = convert_response(resp, _make_request(model="claude-3")) - assert result["model"] == "claude-3" - - def test_interleaved_think_tags_in_content_preserved(self): - """Interleaved ... and text in content should preserve order. - - Bug: extract_think_content uses findall and joins all matches, then strips - tags from content. So "axb" becomes thinking="a\\nb", remaining="x". - Output is [thinking, text] instead of [thinking, text, thinking]. - """ - resp = _make_response(content="firstmiddlesecond") - result = convert_response(resp, _make_request()) - types = [b["type"] for b in result["content"]] - # Expected: thinking, text, thinking (interleaved order) - assert types == ["thinking", "text", "thinking"], ( - f"Interleaved order lost. Got types: {types}" - ) - thinking_blocks = [b for b in result["content"] if b["type"] == "thinking"] - assert thinking_blocks[0]["thinking"] == "first" - assert thinking_blocks[1]["thinking"] == "second" - text_blocks = [b for b in result["content"] if b["type"] == "text"] - assert text_blocks[0]["text"] == "middle" diff --git a/tests/test_response_models.py b/tests/test_response_models.py index 8f649ef..7cbe1d8 100644 --- a/tests/test_response_models.py +++ b/tests/test_response_models.py @@ -1,7 +1,5 @@ """Tests for api/models/responses.py Pydantic response models.""" -import pytest - from api.models.responses import MessagesResponse, Usage, TokenCountResponse from api.models.anthropic import ( ContentBlockText, diff --git a/tests/test_restart_reply_restore.py b/tests/test_restart_reply_restore.py index d46d143..2e20f8a 100644 --- a/tests/test_restart_reply_restore.py +++ b/tests/test_restart_reply_restore.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch from messaging.handler import ClaudeMessageHandler from messaging.session import SessionStore diff --git a/tests/test_server_module.py b/tests/test_server_module.py index 386ab32..dec7971 100644 --- a/tests/test_server_module.py +++ b/tests/test_server_module.py @@ -8,7 +8,7 @@ def test_server_module_exports_app_and_create_app(): def test_server_main_invokes_uvicorn_run(monkeypatch): import runpy from types import SimpleNamespace - from unittest.mock import MagicMock, patch + from unittest.mock import patch import config.settings as settings_mod import uvicorn as uvicorn_mod diff --git a/tests/test_session_store_edge_cases.py b/tests/test_session_store_edge_cases.py index decb00a..ab01332 100644 --- a/tests/test_session_store_edge_cases.py +++ b/tests/test_session_store_edge_cases.py @@ -1,7 +1,6 @@ """Edge case tests for messaging/session.py SessionStore.""" import json -import os import pytest from unittest.mock import patch diff --git a/tests/test_sse_builder.py b/tests/test_sse_builder.py index f54a77b..b4bbfe9 100644 --- a/tests/test_sse_builder.py +++ b/tests/test_sse_builder.py @@ -8,7 +8,6 @@ from providers.nvidia_nim.utils.sse_builder import ( SSEBuilder, ContentBlockManager, map_stop_reason, - STOP_REASON_MAP, ) diff --git a/tests/test_streaming_errors.py b/tests/test_streaming_errors.py index a48312a..50bb017 100644 --- a/tests/test_streaming_errors.py +++ b/tests/test_streaming_errors.py @@ -359,9 +359,7 @@ class TestProcessToolCall: flushed = list(provider._flush_task_arg_buffers(sse)) assert len(flushed) > 0 assert "{}" in "".join(flushed) - assert any( - "NIM_INTERCEPT: Task args invalid JSON" in r.message for r in caplog.records - ) + assert any("Task args invalid JSON" in r.message for r in caplog.records) def test_negative_tool_index_fallback(self): """tc_index < 0 uses len(tool_indices) as fallback.""" diff --git a/tests/test_subagent_interception.py b/tests/test_subagent_interception.py index f05eae6..d2fef13 100644 --- a/tests/test_subagent_interception.py +++ b/tests/test_subagent_interception.py @@ -1,8 +1,8 @@ import json -import uuid import pytest -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock from providers.nvidia_nim import NvidiaNimProvider +from providers.nvidia_nim.utils.sse_builder import ContentBlockManager from providers.base import ProviderConfig @@ -12,18 +12,12 @@ async def test_task_tool_interception(): config = ProviderConfig(api_key="test") provider = NvidiaNimProvider(config) - # Mock request and sse builder + # Mock request and sse builder with real ContentBlockManager request = MagicMock() request.model = "test-model" sse = MagicMock() - sse.blocks = MagicMock() - sse.blocks.tool_indices = {} - sse.blocks.tool_names = {} - sse.blocks.tool_started = {} - sse.blocks.task_arg_buffer = {} - sse.blocks.task_args_emitted = {} - sse.blocks.tool_ids = {} + sse.blocks = ContentBlockManager() # Tool call data (Task tool) tc = { @@ -41,21 +35,16 @@ async def test_task_tool_interception(): }, } - # Remove pre-filled tool name - _process_tool_call handles it - # sse.blocks.tool_names[0] = "Task" - # Call the method events = [] - # _process_tool_call is a synchronous generator in nvidia_nim.py for event in provider._process_tool_call(tc, sse): events.append(event) - # Find the start_tool_block call or check the modified state + # Find the emit_tool_delta call and check args calls = sse.emit_tool_delta.call_args_list assert len(calls) > 0 args_passed = json.loads(calls[0][0][1]) assert args_passed["run_in_background"] is False - print("Verification successful: run_in_background was forced to False") if __name__ == "__main__": diff --git a/tests/test_tree_concurrency.py b/tests/test_tree_concurrency.py index fb1e76d..bc63ab9 100644 --- a/tests/test_tree_concurrency.py +++ b/tests/test_tree_concurrency.py @@ -1,7 +1,6 @@ """Concurrency and race condition tests for tree data structures and queue manager.""" import pytest -import pytest_asyncio import asyncio from messaging.models import IncomingMessage