diff --git a/api/models/anthropic.py b/api/models/anthropic.py index af4601f..f38a1ba 100644 --- a/api/models/anthropic.py +++ b/api/models/anthropic.py @@ -39,7 +39,7 @@ class ContentBlockToolUse(BaseModel): class ContentBlockToolResult(BaseModel): type: Literal["tool_result"] tool_use_id: str - content: str | list[dict[str, Any]] | dict[str, Any] | list[Any] | Any + content: str | list[Any] | dict[str, Any] class ContentBlockThinking(BaseModel): diff --git a/api/request_utils.py b/api/request_utils.py index c571eb5..4553963 100644 --- a/api/request_utils.py +++ b/api/request_utils.py @@ -9,18 +9,13 @@ from typing import Any import tiktoken from loguru import logger +from providers.common import get_block_attr + ENCODER = tiktoken.get_encoding("cl100k_base") __all__ = ["get_token_count"] -def _get_block_attr(block: Any, key: str, default: Any = "") -> Any: - """Get attribute from block (object or dict).""" - if isinstance(block, dict): - return block.get(key, default) - return getattr(block, key, default) - - def get_token_count( messages: list, system: str | list | None = None, @@ -38,7 +33,7 @@ def get_token_count( total_tokens += len(ENCODER.encode(system)) elif isinstance(system, list): for block in system: - text = _get_block_attr(block, "text", "") + text = get_block_attr(block, "text", "") if text: total_tokens += len(ENCODER.encode(str(text))) total_tokens += 4 # System block formatting overhead @@ -48,24 +43,24 @@ def get_token_count( total_tokens += len(ENCODER.encode(msg.content)) elif isinstance(msg.content, list): for block in msg.content: - b_type = _get_block_attr(block, "type") or None + b_type = get_block_attr(block, "type") or None if b_type == "text": - text = _get_block_attr(block, "text", "") + text = get_block_attr(block, "text", "") total_tokens += len(ENCODER.encode(str(text))) elif b_type == "thinking": - thinking = _get_block_attr(block, "thinking", "") + thinking = get_block_attr(block, "thinking", "") total_tokens += len(ENCODER.encode(str(thinking))) elif b_type == "tool_use": - name = _get_block_attr(block, "name", "") - inp = _get_block_attr(block, "input", {}) - block_id = _get_block_attr(block, "id", "") + name = get_block_attr(block, "name", "") + inp = get_block_attr(block, "input", {}) + block_id = get_block_attr(block, "id", "") total_tokens += len(ENCODER.encode(str(name))) total_tokens += len(ENCODER.encode(json.dumps(inp))) total_tokens += len(ENCODER.encode(str(block_id))) total_tokens += 15 elif b_type == "image": - source = _get_block_attr(block, "source") + source = get_block_attr(block, "source") if isinstance(source, dict): data = source.get("data") or source.get("base64") or "" if data: @@ -75,8 +70,8 @@ def get_token_count( else: total_tokens += 765 elif b_type == "tool_result": - content = _get_block_attr(block, "content", "") - tool_use_id = _get_block_attr(block, "tool_use_id", "") + content = get_block_attr(block, "content", "") + tool_use_id = get_block_attr(block, "tool_use_id", "") if isinstance(content, str): total_tokens += len(ENCODER.encode(content)) else: diff --git a/config/logging_config.py b/config/logging_config.py index 568ce5e..e1a31e4 100644 --- a/config/logging_config.py +++ b/config/logging_config.py @@ -8,6 +8,7 @@ included at top level for easy grep/filter. import json import logging +from pathlib import Path from loguru import logger @@ -71,7 +72,7 @@ def configure_logging(log_file: str, *, force: bool = False) -> None: logger.remove() # Truncate log file on fresh start for clean debugging - open(log_file, "w", encoding="utf-8").close() + Path(log_file).write_text("") # Add file sink: JSON lines, DEBUG level, context vars at top level logger.add( diff --git a/messaging/commands.py b/messaging/commands.py new file mode 100644 index 0000000..10c5932 --- /dev/null +++ b/messaging/commands.py @@ -0,0 +1,283 @@ +"""Command handlers for messaging platform commands (/stop, /stats, /clear). + +Extracted from ClaudeMessageHandler to keep handler.py focused on +core message processing logic. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from loguru import logger + +if TYPE_CHECKING: + from messaging.handler import ClaudeMessageHandler + from messaging.models import IncomingMessage + + +async def handle_stop_command( + handler: ClaudeMessageHandler, incoming: IncomingMessage +) -> None: + """Handle /stop command from messaging platform.""" + # Reply-scoped stop: reply "/stop" to stop only that task. + if incoming.is_reply() and incoming.reply_to_message_id: + reply_id = incoming.reply_to_message_id + tree = handler.tree_queue.get_tree_for_node(reply_id) + node_id = handler.tree_queue.resolve_parent_node_id(reply_id) if tree else None + + if not node_id: + msg_id = await handler.platform.queue_send_message( + incoming.chat_id, + handler._format_status( + "⏹", "Stopped.", "Nothing to stop for that message." + ), + fire_and_forget=False, + message_thread_id=incoming.message_thread_id, + ) + handler._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) + return + + count = await handler.stop_task(node_id) + noun = "request" if count == 1 else "requests" + msg_id = await handler.platform.queue_send_message( + incoming.chat_id, + handler._format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."), + fire_and_forget=False, + message_thread_id=incoming.message_thread_id, + ) + handler._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) + return + + # Global stop: legacy behavior (stop everything) + count = await handler.stop_all_tasks() + msg_id = await handler.platform.queue_send_message( + incoming.chat_id, + handler._format_status( + "⏹", "Stopped.", f"Cancelled {count} pending or active requests." + ), + fire_and_forget=False, + message_thread_id=incoming.message_thread_id, + ) + handler._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) + + +async def handle_stats_command( + handler: ClaudeMessageHandler, incoming: IncomingMessage +) -> None: + """Handle /stats command.""" + stats = handler.cli_manager.get_stats() + tree_count = handler.tree_queue.get_tree_count() + ctx = handler._get_render_ctx() + msg_id = await handler.platform.queue_send_message( + incoming.chat_id, + "📊 " + + ctx.bold("Stats") + + "\n" + + ctx.escape_text(f"• Active CLI: {stats['active_sessions']}") + + "\n" + + ctx.escape_text(f"• Message Trees: {tree_count}"), + fire_and_forget=False, + message_thread_id=incoming.message_thread_id, + ) + handler._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) + + +async def _delete_message_ids( + handler: ClaudeMessageHandler, chat_id: str, msg_ids: set[str] +) -> None: + """Best-effort delete messages by ID. Sorts numeric IDs descending.""" + if not msg_ids: + return + + def _as_int(s: str) -> int | None: + try: + return int(str(s)) + except Exception: + return None + + numeric: list[tuple[int, str]] = [] + non_numeric: list[str] = [] + for mid in msg_ids: + n = _as_int(mid) + if n is None: + non_numeric.append(mid) + else: + numeric.append((n, mid)) + numeric.sort(reverse=True) + ordered = [mid for _, mid in numeric] + non_numeric + + batch_fn = getattr(handler.platform, "queue_delete_messages", None) + if callable(batch_fn): + try: + CHUNK = 100 + for i in range(0, len(ordered), CHUNK): + chunk = ordered[i : i + CHUNK] + await batch_fn(chat_id, chunk, fire_and_forget=False) + except Exception as e: + logger.debug(f"Batch delete failed: {type(e).__name__}: {e}") + else: + for mid in ordered: + try: + await handler.platform.queue_delete_message( + chat_id, mid, fire_and_forget=False + ) + except Exception as e: + logger.debug(f"Delete failed for msg {mid}: {type(e).__name__}: {e}") + + +async def _handle_clear_branch( + handler: ClaudeMessageHandler, + incoming: IncomingMessage, + branch_root_id: str, +) -> None: + """ + Clear a branch (replied-to node + all descendants). + + Order: cancel tasks, delete messages, remove branch, update session store. + """ + tree = handler.tree_queue.get_tree_for_node(branch_root_id) + if not tree: + return + + # 1) Cancel branch tasks (no stop_all) + cancelled = await handler.tree_queue.cancel_branch(branch_root_id) + handler._update_cancelled_nodes_ui(cancelled) + + # 2) Collect message IDs from branch nodes only + msg_ids: set[str] = set() + branch_ids = tree.get_descendants(branch_root_id) + for nid in branch_ids: + node = tree.get_node(nid) + if node: + if node.incoming.message_id: + msg_ids.add(str(node.incoming.message_id)) + if node.status_message_id: + msg_ids.add(str(node.status_message_id)) + if incoming.message_id: + msg_ids.add(str(incoming.message_id)) + + # 3) Delete messages (best-effort) + await _delete_message_ids(handler, incoming.chat_id, msg_ids) + + # 4) Remove branch from tree + removed, root_id, removed_entire_tree = await handler.tree_queue.remove_branch( + branch_root_id + ) + + # 5) Update session store + try: + handler.session_store.remove_node_mappings([n.node_id for n in removed]) + if removed_entire_tree: + handler.session_store.remove_tree(root_id) + else: + updated_tree = handler.tree_queue.get_tree(root_id) + if updated_tree: + handler.session_store.save_tree(root_id, updated_tree.to_dict()) + except Exception as e: + logger.warning(f"Failed to update session store after branch clear: {e}") + + +async def handle_clear_command( + handler: ClaudeMessageHandler, incoming: IncomingMessage +) -> None: + """ + Handle /clear command. + + Reply-scoped: reply to a message to clear that branch (node + descendants). + Standalone: global clear (stop all, delete all chat messages, reset store). + """ + from messaging.trees import TreeQueueManager + + if incoming.is_reply() and incoming.reply_to_message_id: + reply_id = incoming.reply_to_message_id + tree = handler.tree_queue.get_tree_for_node(reply_id) + branch_root_id = ( + handler.tree_queue.resolve_parent_node_id(reply_id) if tree else None + ) + if not branch_root_id: + cancel_fn = getattr(handler.platform, "cancel_pending_voice", None) + if cancel_fn is not None: + cancelled = await cancel_fn(incoming.chat_id, reply_id) + if cancelled is not None: + voice_msg_id, status_msg_id = cancelled + msg_ids_to_del: set[str] = {voice_msg_id, status_msg_id} + if incoming.message_id is not None: + msg_ids_to_del.add(str(incoming.message_id)) + await _delete_message_ids(handler, incoming.chat_id, msg_ids_to_del) + msg_id = await handler.platform.queue_send_message( + incoming.chat_id, + handler._format_status( + "🗑", "Cleared.", "Voice note cancelled." + ), + fire_and_forget=False, + message_thread_id=incoming.message_thread_id, + ) + handler._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) + return + msg_id = await handler.platform.queue_send_message( + incoming.chat_id, + handler._format_status( + "🗑", "Cleared.", "Nothing to clear for that message." + ), + fire_and_forget=False, + message_thread_id=incoming.message_thread_id, + ) + handler._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) + return + await _handle_clear_branch(handler, incoming, branch_root_id) + return + + # Global clear + # 1) Stop tasks first (ensures no more work is running). + await handler.stop_all_tasks() + + # 2) Clear chat: best-effort delete messages we can identify. + msg_ids: set[str] = set() + + # Add any recorded message IDs for this chat (commands, command replies, etc). + try: + for mid in handler.session_store.get_message_ids_for_chat( + incoming.platform, incoming.chat_id + ): + if mid is not None: + msg_ids.add(str(mid)) + except Exception as e: + logger.debug(f"Failed to read message log for /clear: {e}") + + try: + msg_ids.update( + handler.tree_queue.get_message_ids_for_chat( + incoming.platform, incoming.chat_id + ) + ) + except Exception as e: + logger.warning(f"Failed to gather messages for /clear: {e}") + + # Also delete the command message itself. + if incoming.message_id is not None: + msg_ids.add(str(incoming.message_id)) + + await _delete_message_ids(handler, incoming.chat_id, msg_ids) + + # 3) Clear persistent state and reset in-memory queue/tree state. + try: + handler.session_store.clear_all() + except Exception as e: + logger.warning(f"Failed to clear session store: {e}") + + handler.tree_queue = TreeQueueManager( + queue_update_callback=handler._update_queue_positions, + node_started_callback=handler._mark_node_processing, + ) diff --git a/messaging/handler.py b/messaging/handler.py index 98d4b4b..19a6e8d 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -12,6 +12,11 @@ import time from loguru import logger +from .commands import ( + handle_clear_command, + handle_stats_command, + handle_stop_command, +) from .event_parser import parse_cli_event from .models import IncomingMessage from .platforms.base import MessagingPlatform, SessionManagerInterface @@ -738,253 +743,12 @@ class ClaudeMessageHandler: async def _handle_stop_command(self, incoming: IncomingMessage) -> None: """Handle /stop command from messaging platform.""" - # Reply-scoped stop: reply "/stop" to stop only that task. - if incoming.is_reply() and incoming.reply_to_message_id: - reply_id = incoming.reply_to_message_id - tree = self.tree_queue.get_tree_for_node(reply_id) - node_id = self.tree_queue.resolve_parent_node_id(reply_id) if tree else None - - if not node_id: - msg_id = await self.platform.queue_send_message( - incoming.chat_id, - self._format_status( - "⏹", "Stopped.", "Nothing to stop for that message." - ), - fire_and_forget=False, - message_thread_id=incoming.message_thread_id, - ) - self._record_outgoing_message( - incoming.platform, incoming.chat_id, msg_id, "command" - ) - return - - count = await self.stop_task(node_id) - noun = "request" if count == 1 else "requests" - msg_id = await self.platform.queue_send_message( - incoming.chat_id, - self._format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."), - fire_and_forget=False, - message_thread_id=incoming.message_thread_id, - ) - self._record_outgoing_message( - incoming.platform, incoming.chat_id, msg_id, "command" - ) - return - - # Global stop: legacy behavior (stop everything) - count = await self.stop_all_tasks() - msg_id = await self.platform.queue_send_message( - incoming.chat_id, - self._format_status( - "⏹", "Stopped.", f"Cancelled {count} pending or active requests." - ), - fire_and_forget=False, - message_thread_id=incoming.message_thread_id, - ) - self._record_outgoing_message( - incoming.platform, incoming.chat_id, msg_id, "command" - ) + await handle_stop_command(self, incoming) async def _handle_stats_command(self, incoming: IncomingMessage) -> None: """Handle /stats command.""" - stats = self.cli_manager.get_stats() - tree_count = self.tree_queue.get_tree_count() - ctx = self._get_render_ctx() - msg_id = await self.platform.queue_send_message( - incoming.chat_id, - "📊 " - + ctx.bold("Stats") - + "\n" - + ctx.escape_text(f"• Active CLI: {stats['active_sessions']}") - + "\n" - + ctx.escape_text(f"• Message Trees: {tree_count}"), - fire_and_forget=False, - message_thread_id=incoming.message_thread_id, - ) - self._record_outgoing_message( - incoming.platform, incoming.chat_id, msg_id, "command" - ) - - async def _handle_clear_branch( - self, incoming: IncomingMessage, branch_root_id: str - ) -> None: - """ - Clear a branch (replied-to node + all descendants). - - Order: cancel tasks, delete messages, remove branch, update session store. - """ - tree = self.tree_queue.get_tree_for_node(branch_root_id) - if not tree: - return - - # 1) Cancel branch tasks (no stop_all) - cancelled = await self.tree_queue.cancel_branch(branch_root_id) - self._update_cancelled_nodes_ui(cancelled) - - # 2) Collect message IDs from branch nodes only - msg_ids: set[str] = set() - branch_ids = tree.get_descendants(branch_root_id) - for nid in branch_ids: - node = tree.get_node(nid) - if node: - if node.incoming.message_id: - msg_ids.add(str(node.incoming.message_id)) - if node.status_message_id: - msg_ids.add(str(node.status_message_id)) - if incoming.message_id: - msg_ids.add(str(incoming.message_id)) - - # 3) Delete messages (best-effort) - await self._delete_message_ids(incoming.chat_id, msg_ids) - - # 4) Remove branch from tree - removed, root_id, removed_entire_tree = await self.tree_queue.remove_branch( - branch_root_id - ) - - # 5) Update session store - try: - self.session_store.remove_node_mappings([n.node_id for n in removed]) - if removed_entire_tree: - self.session_store.remove_tree(root_id) - else: - updated_tree = self.tree_queue.get_tree(root_id) - if updated_tree: - self.session_store.save_tree(root_id, updated_tree.to_dict()) - except Exception as e: - logger.warning(f"Failed to update session store after branch clear: {e}") - - async def _delete_message_ids(self, chat_id: str, msg_ids: set[str]) -> None: - """Best-effort delete messages by ID. Sorts numeric IDs descending.""" - if not msg_ids: - return - - def _as_int(s: str) -> int | None: - try: - return int(str(s)) - except Exception: - return None - - numeric: list[tuple[int, str]] = [] - non_numeric: list[str] = [] - for mid in msg_ids: - n = _as_int(mid) - if n is None: - non_numeric.append(mid) - else: - numeric.append((n, mid)) - numeric.sort(reverse=True) - ordered = [mid for _, mid in numeric] + non_numeric - - batch_fn = getattr(self.platform, "queue_delete_messages", None) - if callable(batch_fn): - try: - CHUNK = 100 - for i in range(0, len(ordered), CHUNK): - chunk = ordered[i : i + CHUNK] - await batch_fn(chat_id, chunk, fire_and_forget=False) - except Exception as e: - logger.debug(f"Batch delete failed: {type(e).__name__}: {e}") - else: - for mid in ordered: - try: - await self.platform.queue_delete_message( - chat_id, mid, fire_and_forget=False - ) - except Exception as e: - logger.debug( - f"Delete failed for msg {mid}: {type(e).__name__}: {e}" - ) + await handle_stats_command(self, incoming) async def _handle_clear_command(self, incoming: IncomingMessage) -> None: - """ - Handle /clear command. - - Reply-scoped: reply to a message to clear that branch (node + descendants). - Standalone: global clear (stop all, delete all chat messages, reset store). - """ - if incoming.is_reply() and incoming.reply_to_message_id: - reply_id = incoming.reply_to_message_id - tree = self.tree_queue.get_tree_for_node(reply_id) - branch_root_id = ( - self.tree_queue.resolve_parent_node_id(reply_id) if tree else None - ) - if not branch_root_id: - cancel_fn = getattr(self.platform, "cancel_pending_voice", None) - if cancel_fn is not None: - cancelled = await cancel_fn(incoming.chat_id, reply_id) - if cancelled is not None: - voice_msg_id, status_msg_id = cancelled - msg_ids_to_del: set[str] = {voice_msg_id, status_msg_id} - if incoming.message_id is not None: - msg_ids_to_del.add(str(incoming.message_id)) - await self._delete_message_ids(incoming.chat_id, msg_ids_to_del) - msg_id = await self.platform.queue_send_message( - incoming.chat_id, - self._format_status( - "🗑", "Cleared.", "Voice note cancelled." - ), - fire_and_forget=False, - message_thread_id=incoming.message_thread_id, - ) - self._record_outgoing_message( - incoming.platform, incoming.chat_id, msg_id, "command" - ) - return - msg_id = await self.platform.queue_send_message( - incoming.chat_id, - self._format_status( - "🗑", "Cleared.", "Nothing to clear for that message." - ), - fire_and_forget=False, - message_thread_id=incoming.message_thread_id, - ) - self._record_outgoing_message( - incoming.platform, incoming.chat_id, msg_id, "command" - ) - return - await self._handle_clear_branch(incoming, branch_root_id) - return - - # Global clear - # 1) Stop tasks first (ensures no more work is running). - await self.stop_all_tasks() - - # 2) Clear chat: best-effort delete messages we can identify. - msg_ids: set[str] = set() - - # Add any recorded message IDs for this chat (commands, command replies, etc). - try: - for mid in self.session_store.get_message_ids_for_chat( - incoming.platform, incoming.chat_id - ): - if mid is not None: - msg_ids.add(str(mid)) - except Exception as e: - logger.debug(f"Failed to read message log for /clear: {e}") - - try: - msg_ids.update( - self.tree_queue.get_message_ids_for_chat( - incoming.platform, incoming.chat_id - ) - ) - except Exception as e: - logger.warning(f"Failed to gather messages for /clear: {e}") - - # Also delete the command message itself. - if incoming.message_id is not None: - msg_ids.add(str(incoming.message_id)) - - await self._delete_message_ids(incoming.chat_id, msg_ids) - - # 3) Clear persistent state and reset in-memory queue/tree state. - try: - self.session_store.clear_all() - except Exception as e: - logger.warning(f"Failed to clear session store: {e}") - - self.tree_queue = TreeQueueManager( - queue_update_callback=self._update_queue_positions, - node_started_callback=self._mark_node_processing, - ) + """Handle /clear command.""" + await handle_clear_command(self, incoming) diff --git a/providers/common/__init__.py b/providers/common/__init__.py index 3c43bfe..14c6df3 100644 --- a/providers/common/__init__.py +++ b/providers/common/__init__.py @@ -4,6 +4,7 @@ from .error_mapping import map_error from .heuristic_tool_parser import HeuristicToolParser from .message_converter import ( AnthropicToOpenAIConverter, + build_base_request_body, get_block_attr, get_block_type, ) @@ -19,6 +20,7 @@ __all__ = [ "HeuristicToolParser", "SSEBuilder", "ThinkTagParser", + "build_base_request_body", "get_block_attr", "get_block_type", "map_error", diff --git a/providers/common/message_converter.py b/providers/common/message_converter.py index 6b249a4..b07fa3b 100644 --- a/providers/common/message_converter.py +++ b/providers/common/message_converter.py @@ -178,3 +178,49 @@ class AnthropicToOpenAIConverter: if text_parts: return {"role": "system", "content": "\n\n".join(text_parts).strip()} return None + + +def build_base_request_body( + request_data: Any, + *, + default_max_tokens: int | None = None, + include_reasoning_for_openrouter: bool = False, +) -> dict[str, Any]: + """Build the common parts of an OpenAI-format request body. + + Handles message conversion, system prompt, max_tokens, temperature, + top_p, stop sequences, tools, and tool_choice. Provider-specific + parameters (extra_body, penalties, NIM settings) are added by callers. + """ + from providers.common.utils import set_if_not_none + + messages = AnthropicToOpenAIConverter.convert_messages( + request_data.messages, + include_reasoning_for_openrouter=include_reasoning_for_openrouter, + ) + + system = getattr(request_data, "system", None) + if system: + system_msg = AnthropicToOpenAIConverter.convert_system_prompt(system) + if system_msg: + messages.insert(0, system_msg) + + body: dict[str, Any] = {"model": request_data.model, "messages": messages} + + max_tokens = getattr(request_data, "max_tokens", None) + set_if_not_none(body, "max_tokens", max_tokens or default_max_tokens) + set_if_not_none(body, "temperature", getattr(request_data, "temperature", None)) + set_if_not_none(body, "top_p", getattr(request_data, "top_p", None)) + + stop_sequences = getattr(request_data, "stop_sequences", None) + if stop_sequences: + body["stop"] = stop_sequences + + tools = getattr(request_data, "tools", None) + if tools: + body["tools"] = AnthropicToOpenAIConverter.convert_tools(tools) + tool_choice = getattr(request_data, "tool_choice", None) + if tool_choice: + body["tool_choice"] = tool_choice + + return body diff --git a/providers/common/sse_builder.py b/providers/common/sse_builder.py index d95efce..201f224 100644 --- a/providers/common/sse_builder.py +++ b/providers/common/sse_builder.py @@ -187,10 +187,6 @@ class SSEBuilder: """Generate message_stop event.""" return self._format_event("message_stop", {"type": "message_stop"}) - def done(self) -> str: - """Generate [DONE] marker.""" - return "[DONE]\n\n" - # Content block events def content_block_start(self, index: int, block_type: str, **kwargs) -> str: """Generate content_block_start event.""" diff --git a/providers/lmstudio/request.py b/providers/lmstudio/request.py index 6b45b1c..8edb4c1 100644 --- a/providers/lmstudio/request.py +++ b/providers/lmstudio/request.py @@ -4,8 +4,7 @@ from typing import Any from loguru import logger -from providers.common.message_converter import AnthropicToOpenAIConverter -from providers.common.utils import set_if_not_none +from providers.common.message_converter import build_base_request_body LMSTUDIO_DEFAULT_MAX_TOKENS = 81920 @@ -17,39 +16,10 @@ def build_request_body(request_data: Any) -> dict: getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - messages = AnthropicToOpenAIConverter.convert_messages( - request_data.messages, include_reasoning_for_openrouter=False + body = build_base_request_body( + request_data, default_max_tokens=LMSTUDIO_DEFAULT_MAX_TOKENS ) - # Add system prompt - system = getattr(request_data, "system", None) - if system: - system_msg = AnthropicToOpenAIConverter.convert_system_prompt(system) - if system_msg: - messages.insert(0, system_msg) - - body: dict[str, Any] = { - "model": request_data.model, - "messages": messages, - } - - max_tokens = getattr(request_data, "max_tokens", None) - set_if_not_none(body, "max_tokens", max_tokens or LMSTUDIO_DEFAULT_MAX_TOKENS) - - set_if_not_none(body, "temperature", getattr(request_data, "temperature", None)) - set_if_not_none(body, "top_p", getattr(request_data, "top_p", None)) - - stop_sequences = getattr(request_data, "stop_sequences", None) - if stop_sequences: - body["stop"] = stop_sequences - - tools = getattr(request_data, "tools", None) - if tools: - body["tools"] = AnthropicToOpenAIConverter.convert_tools(tools) - tool_choice = getattr(request_data, "tool_choice", None) - if tool_choice: - body["tool_choice"] = tool_choice - logger.debug( "LMSTUDIO_REQUEST: conversion done model=%s msgs=%d tools=%d", body.get("model"), diff --git a/providers/nvidia_nim/errors.py b/providers/nvidia_nim/errors.py deleted file mode 100644 index 4c70b57..0000000 --- a/providers/nvidia_nim/errors.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Error mapping for NVIDIA NIM provider (re-exports from providers.common).""" - -from providers.common.error_mapping import map_error - -__all__ = ["map_error"] diff --git a/providers/nvidia_nim/request.py b/providers/nvidia_nim/request.py index 7f61d6c..c7d7069 100644 --- a/providers/nvidia_nim/request.py +++ b/providers/nvidia_nim/request.py @@ -5,7 +5,7 @@ from typing import Any from loguru import logger from config.nim import NimSettings -from providers.common.message_converter import AnthropicToOpenAIConverter +from providers.common.message_converter import build_base_request_body from providers.common.utils import set_if_not_none @@ -28,49 +28,26 @@ def build_request_body(request_data: Any, nim: NimSettings) -> dict: getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - messages = AnthropicToOpenAIConverter.convert_messages(request_data.messages) + body = build_base_request_body(request_data) - # Add system prompt - system = getattr(request_data, "system", None) - if system: - system_msg = AnthropicToOpenAIConverter.convert_system_prompt(system) - if system_msg: - messages.insert(0, system_msg) - - body: dict[str, Any] = { - "model": request_data.model, - "messages": messages, - } - - # max_tokens with optional cap - max_tokens = getattr(request_data, "max_tokens", None) + # NIM-specific max_tokens: cap against nim.max_tokens + max_tokens = body.get("max_tokens") or getattr(request_data, "max_tokens", None) if max_tokens is None: max_tokens = nim.max_tokens elif nim.max_tokens: max_tokens = min(max_tokens, nim.max_tokens) set_if_not_none(body, "max_tokens", max_tokens) - req_temperature = getattr(request_data, "temperature", None) - temperature = req_temperature if req_temperature is not None else nim.temperature - set_if_not_none(body, "temperature", temperature) + # NIM-specific temperature/top_p: fall back to NIM defaults if request didn't set + if body.get("temperature") is None and nim.temperature is not None: + body["temperature"] = nim.temperature + if body.get("top_p") is None and nim.top_p is not None: + body["top_p"] = nim.top_p - req_top_p = getattr(request_data, "top_p", None) - top_p = req_top_p if req_top_p is not None else nim.top_p - set_if_not_none(body, "top_p", top_p) - - stop_sequences = getattr(request_data, "stop_sequences", None) - if stop_sequences: - body["stop"] = stop_sequences - elif nim.stop: + # NIM-specific stop sequences fallback + if "stop" not in body and nim.stop: body["stop"] = nim.stop - tools = getattr(request_data, "tools", None) - if tools: - body["tools"] = AnthropicToOpenAIConverter.convert_tools(tools) - tool_choice = getattr(request_data, "tool_choice", None) - if tool_choice: - body["tool_choice"] = tool_choice - if nim.presence_penalty != 0.0: body["presence_penalty"] = nim.presence_penalty if nim.frequency_penalty != 0.0: diff --git a/providers/nvidia_nim/utils/__init__.py b/providers/nvidia_nim/utils/__init__.py deleted file mode 100644 index bbca7f4..0000000 --- a/providers/nvidia_nim/utils/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Utility modules for providers (re-exports from providers.common).""" - -from providers.common import ( - AnthropicToOpenAIConverter, - ContentBlockManager, - ContentChunk, - ContentType, - HeuristicToolParser, - SSEBuilder, - ThinkTagParser, - get_block_attr, - get_block_type, - map_stop_reason, -) - -__all__ = [ - "AnthropicToOpenAIConverter", - "ContentBlockManager", - "ContentChunk", - "ContentType", - "HeuristicToolParser", - "SSEBuilder", - "ThinkTagParser", - "get_block_attr", - "get_block_type", - "map_stop_reason", -] diff --git a/providers/open_router/request.py b/providers/open_router/request.py index 6d64b1f..a2ba9eb 100644 --- a/providers/open_router/request.py +++ b/providers/open_router/request.py @@ -4,8 +4,7 @@ from typing import Any from loguru import logger -from providers.common.message_converter import AnthropicToOpenAIConverter -from providers.common.utils import set_if_not_none +from providers.common.message_converter import build_base_request_body OPENROUTER_DEFAULT_MAX_TOKENS = 81920 @@ -17,39 +16,12 @@ def build_request_body(request_data: Any) -> dict: getattr(request_data, "model", "?"), len(getattr(request_data, "messages", [])), ) - messages = AnthropicToOpenAIConverter.convert_messages( - request_data.messages, include_reasoning_for_openrouter=True + body = build_base_request_body( + request_data, + default_max_tokens=OPENROUTER_DEFAULT_MAX_TOKENS, + include_reasoning_for_openrouter=True, ) - # Add system prompt - system = getattr(request_data, "system", None) - if system: - system_msg = AnthropicToOpenAIConverter.convert_system_prompt(system) - if system_msg: - messages.insert(0, system_msg) - - body: dict[str, Any] = { - "model": request_data.model, - "messages": messages, - } - - max_tokens = getattr(request_data, "max_tokens", None) - set_if_not_none(body, "max_tokens", max_tokens or OPENROUTER_DEFAULT_MAX_TOKENS) - - set_if_not_none(body, "temperature", getattr(request_data, "temperature", None)) - set_if_not_none(body, "top_p", getattr(request_data, "top_p", None)) - - stop_sequences = getattr(request_data, "stop_sequences", None) - if stop_sequences: - body["stop"] = stop_sequences - - tools = getattr(request_data, "tools", None) - if tools: - body["tools"] = AnthropicToOpenAIConverter.convert_tools(tools) - tool_choice = getattr(request_data, "tool_choice", None) - if tool_choice: - body["tool_choice"] = tool_choice - # OpenRouter reasoning: extra_body={"reasoning": {"enabled": True}} extra_body: dict[str, Any] = {} request_extra = getattr(request_data, "extra_body", None) diff --git a/tests/providers/test_sse_builder.py b/tests/providers/test_sse_builder.py index 7bdb37f..5c2400f 100644 --- a/tests/providers/test_sse_builder.py +++ b/tests/providers/test_sse_builder.py @@ -67,7 +67,7 @@ class TestContentBlockManager: class TestSSEBuilderMessageLifecycle: - """Tests for message_start, message_delta, message_stop, done.""" + """Tests for message_start, message_delta, message_stop.""" def test_message_start(self): builder = SSEBuilder("msg_123", "test-model", input_tokens=50) @@ -102,10 +102,6 @@ class TestSSEBuilderMessageLifecycle: data = _parse_sse(sse) assert data["type"] == "message_stop" - def test_done(self): - builder = SSEBuilder("msg_1", "model") - assert builder.done() == "[DONE]\n\n" - class TestSSEBuilderContentBlocks: """Tests for content block start/delta/stop events."""