feat(SDK) Add Python SDK implementation for #3010 (#3494)

* Codex worktree snapshot: startup-cleanup

Co-authored-by: Codex

* Add Python SDK real smoke test

Adds a repository-only real E2E smoke script for the Python SDK, plus npm and developer documentation entry points.

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): address review findings — bugs, type safety, and test coverage

- Fix prepare_spawn_info: JS files now use "node" instead of sys.executable
- Fix protocol.py: correct total=False misuse on 7 TypedDicts (required fields were optional)
- Fix query.py: add _closed guard in _ensure_started, suppress exceptions in close()
- Fix sync_query.py: prevent close() deadlock, add context manager, add timeouts
- Fix transport.py: handle malformed JSON lines, add _closed guard in start()
- Fix validation.py: use uuid.RFC_4122 instead of magic UUID
- Fix __init__.py: export TextBlock, widen query_sync signature
- Remove dead code: ensure_not_aborted, write_json_line, _thread_error
- Add 12 new tests (29 → 41): context managers, JSON skip, closed guards, spawn info, timeouts

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): address wenshao review — session_id, bool validation, debug stderr

- Fix continue_session=True generating a wrong random session_id
- Add _as_optional_bool helper for strict type validation on bool fields
- Default debug stderr to sys.stderr when no custom callback is provided

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): address remaining wenshao review feedback

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* test(cli): harden settings dialog restart prompt test

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): review fixes — UUID compat, stderr fallback, sync cleanup

- Remove UUID version restriction to support v6/v7/v8 (RFC 9562)
- Always write to sys.stderr when stderr callback raises (was silent when debug=False)
- Prevent duplicate _STOP sentinel in SyncQuery.close() via _stop_sent flag
- Add ruff format --check to CI workflow
- Fix smoke_real.py version guard: fail early before imports instead of NameError
- Apply ruff format to existing files

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): remaining review fixes — exit_code attr, guard strictness, sync timeout

- Add exit_code attribute to ProcessExitError for programmatic access
- Strengthen is_control_response/is_control_cancel guards to require
  payload fields, preventing misrouting of malformed messages
- Expose control_request_timeout property on Query so SyncQuery uses
  the configured timeout instead of a hardcoded 30s default
- Use dataclasses.replace() instead of direct mutation on frozen-style
  QueryOptions in query() factory
- Add ResourceWarning in SyncQuery.__del__ when not properly closed

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): add exit_code default and guard __del__ against partial GC

- Give ProcessExitError.exit_code a default value (-1) so user code can
  construct the exception with just a message string
- Wrap SyncQuery.__del__ in try/except AttributeError to prevent crashes
  when the object is partially garbage-collected

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): review fixes — resource leak, type safety, CI matrix, docs

- Fix SyncQuery.__del__ to call close() on GC instead of only warning
- Replace hasattr duck-type check with isinstance(prompt, AsyncIterable)
- Type-validate permission_mode/auth_type in QueryOptions.from_mapping
- Use TypeGuard return types on all is_sdk_*/is_control_* predicates
- Add 5s margin to sync wrapper timeouts to prevent error type masking
- Expand CI matrix to test Python 3.10, 3.11, 3.12
- Change ProcessExitError.exit_code default from -1 to None
- Add stderr to docs QueryOptions listing
- Update README sync example to use context manager pattern

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): preserve iterator exhaustion state and suppress detached task warning

- Add _exhausted flag to Query.__anext__ and SyncQuery.__next__ so
  repeated iteration after end-of-stream raises Stop(Async)Iteration
  instead of blocking forever.
- Remove re-raise in _initialize() to prevent asyncio
  "Task exception was never retrieved" warning on detached tasks;
  the error is already surfaced via _finish_with_error().

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): reject mcp_servers at validation time and add iterator/init tests

- Reject mcp_servers in validate_query_options() with a clear error
  instead of advertising MCP support to the CLI and then failing at
  runtime when mcp_message arrives.
- Remove dead mcp_servers branch from _initialize().
- Add tests for async/sync iterator exhaustion, detached init task
  warning suppression, and mcp_servers validation.

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* fix(sdk-python): fix ruff lint errors in new tests

- Use ControlRequestTimeoutError instead of bare Exception (B017)
- Fix import sorting for stdlib vs third-party (I001)
- Break long line to stay within 88-char limit (E501)

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

* style(sdk-python): apply ruff format to new tests

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>

---------

Co-authored-by: jinye.djy <jinye.djy@alibaba-inc.com>
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
jinye 2026-04-25 07:02:58 +08:00 committed by GitHub
parent 202be6ec7d
commit e384338145
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 4676 additions and 14 deletions

View file

@ -0,0 +1,108 @@
"""qwen_code_sdk package exports."""
from __future__ import annotations
from collections.abc import AsyncIterable, Iterable, Mapping
from typing import Any
from .errors import (
AbortError,
ControlRequestTimeoutError,
ProcessExitError,
QwenSDKError,
ValidationError,
)
from .protocol import (
APIAssistantMessage,
APIUserMessage,
ContentBlock,
SDKAssistantMessage,
SDKMessage,
SDKPartialAssistantMessage,
SDKResultMessage,
SDKSystemMessage,
SDKUserMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
Usage,
is_control_cancel,
is_control_request,
is_control_response,
is_sdk_assistant_message,
is_sdk_partial_assistant_message,
is_sdk_result_message,
is_sdk_system_message,
is_sdk_user_message,
)
from .query import Query, query
from .sync_query import SyncQuery
from .types import (
AuthType,
CanUseTool,
CanUseToolContext,
PermissionAllowResult,
PermissionDenyResult,
PermissionMode,
PermissionResult,
PermissionSuggestion,
QueryOptions,
QueryOptionsDict,
TimeoutOptions,
TimeoutOptionsDict,
)
def query_sync(
prompt: str | Iterable[SDKUserMessage] | AsyncIterable[SDKUserMessage],
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None,
) -> SyncQuery:
return SyncQuery(prompt=prompt, options=options)
__all__ = [
"APIAssistantMessage",
"APIUserMessage",
"AbortError",
"AuthType",
"CanUseTool",
"CanUseToolContext",
"ContentBlock",
"ControlRequestTimeoutError",
"PermissionAllowResult",
"PermissionDenyResult",
"PermissionMode",
"PermissionResult",
"PermissionSuggestion",
"ProcessExitError",
"Query",
"QueryOptions",
"QueryOptionsDict",
"QwenSDKError",
"SDKAssistantMessage",
"SDKMessage",
"SDKPartialAssistantMessage",
"SDKResultMessage",
"SDKSystemMessage",
"SDKUserMessage",
"SyncQuery",
"TextBlock",
"ThinkingBlock",
"TimeoutOptions",
"TimeoutOptionsDict",
"ToolResultBlock",
"ToolUseBlock",
"Usage",
"ValidationError",
"is_control_cancel",
"is_control_request",
"is_control_response",
"is_sdk_assistant_message",
"is_sdk_partial_assistant_message",
"is_sdk_result_message",
"is_sdk_system_message",
"is_sdk_user_message",
"query",
"query_sync",
]

View file

@ -0,0 +1,27 @@
"""Error types for qwen_code_sdk."""
from __future__ import annotations
class QwenSDKError(Exception):
"""Base error for all SDK failures."""
class ValidationError(QwenSDKError):
"""Raised when query options are invalid."""
class AbortError(QwenSDKError):
"""Raised when an operation is aborted by caller or transport."""
class ProcessExitError(QwenSDKError):
"""Raised when qwen CLI exits with non-zero status or signal."""
def __init__(self, message: str, exit_code: int | None = None) -> None:
super().__init__(message)
self.exit_code = exit_code
class ControlRequestTimeoutError(QwenSDKError):
"""Raised when a control request times out waiting for response."""

View file

@ -0,0 +1,14 @@
"""JSON lines utilities."""
from __future__ import annotations
import json
from typing import Any
def serialize_json_line(payload: Any) -> str:
return json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n"
def parse_json_line(line: str) -> Any:
return json.loads(line)

View file

@ -0,0 +1,357 @@
"""Protocol message types and helpers for qwen stream-json."""
from __future__ import annotations
from typing import Any, Literal, TypeAlias, TypeGuard
from typing_extensions import NotRequired, TypedDict
from .types import PermissionMode, PermissionSuggestion
class Annotation(TypedDict):
type: str
value: str
class Usage(TypedDict):
input_tokens: int
output_tokens: int
cache_creation_input_tokens: NotRequired[int]
cache_read_input_tokens: NotRequired[int]
total_tokens: NotRequired[int]
class ExtendedUsage(Usage, total=False):
server_tool_use: dict[str, int]
service_tier: str
cache_creation: dict[str, int]
class CLIPermissionDenial(TypedDict):
tool_name: str
tool_use_id: str
tool_input: Any
class TextBlock(TypedDict):
type: Literal["text"]
text: str
annotations: NotRequired[list[Annotation]]
class ThinkingBlock(TypedDict):
type: Literal["thinking"]
thinking: str
signature: NotRequired[str]
annotations: NotRequired[list[Annotation]]
class ToolUseBlock(TypedDict):
type: Literal["tool_use"]
id: str
name: str
input: Any
annotations: NotRequired[list[Annotation]]
class ToolResultBlock(TypedDict):
type: Literal["tool_result"]
tool_use_id: str
content: NotRequired[str | list[ContentBlock]]
is_error: NotRequired[bool]
annotations: NotRequired[list[Annotation]]
ContentBlock: TypeAlias = TextBlock | ThinkingBlock | ToolUseBlock | ToolResultBlock
class APIUserMessage(TypedDict):
role: Literal["user"]
content: str | list[ContentBlock]
class APIAssistantMessage(TypedDict):
role: Literal["assistant"]
content: list[ContentBlock]
id: NotRequired[str]
type: NotRequired[Literal["message"]]
model: NotRequired[str]
stop_reason: NotRequired[str | None]
usage: NotRequired[Usage]
class SDKUserMessage(TypedDict):
type: Literal["user"]
session_id: str
message: APIUserMessage
parent_tool_use_id: str | None
uuid: NotRequired[str]
options: NotRequired[dict[str, Any]]
class SDKAssistantMessage(TypedDict):
type: Literal["assistant"]
uuid: str
session_id: str
message: APIAssistantMessage
parent_tool_use_id: str | None
class MCPServerState(TypedDict):
name: str
status: str
class SDKSystemMessage(TypedDict):
type: Literal["system"]
subtype: str
uuid: str
session_id: str
data: NotRequired[Any]
cwd: NotRequired[str]
tools: NotRequired[list[str]]
mcp_servers: NotRequired[list[MCPServerState]]
model: NotRequired[str]
permission_mode: NotRequired[str]
slash_commands: NotRequired[list[str]]
qwen_code_version: NotRequired[str]
output_style: NotRequired[str]
agents: NotRequired[list[str]]
skills: NotRequired[list[str]]
capabilities: NotRequired[dict[str, Any]]
class SDKResultMessageSuccess(TypedDict):
type: Literal["result"]
subtype: Literal["success"]
uuid: str
session_id: str
is_error: Literal[False]
duration_ms: int
duration_api_ms: int
num_turns: int
result: str
usage: ExtendedUsage
permission_denials: list[CLIPermissionDenial]
class ResultErrorObject(TypedDict):
message: str
type: NotRequired[str]
class SDKResultMessageError(TypedDict):
type: Literal["result"]
subtype: Literal["error_max_turns", "error_during_execution"]
uuid: str
session_id: str
is_error: Literal[True]
duration_ms: int
duration_api_ms: int
num_turns: int
usage: ExtendedUsage
permission_denials: list[CLIPermissionDenial]
error: NotRequired[ResultErrorObject]
SDKResultMessage: TypeAlias = SDKResultMessageSuccess | SDKResultMessageError
class MessageStartStreamEvent(TypedDict):
type: Literal["message_start"]
message: dict[str, Any]
class ContentBlockStartEvent(TypedDict):
type: Literal["content_block_start"]
index: int
content_block: ContentBlock
class ContentBlockDeltaEvent(TypedDict):
type: Literal["content_block_delta"]
index: int
delta: dict[str, Any]
class ContentBlockStopEvent(TypedDict):
type: Literal["content_block_stop"]
index: int
class MessageStopStreamEvent(TypedDict):
type: Literal["message_stop"]
StreamEvent: TypeAlias = (
MessageStartStreamEvent
| ContentBlockStartEvent
| ContentBlockDeltaEvent
| ContentBlockStopEvent
| MessageStopStreamEvent
)
class SDKPartialAssistantMessage(TypedDict):
type: Literal["stream_event"]
uuid: str
session_id: str
event: StreamEvent
parent_tool_use_id: str | None
class CLIControlInterruptRequest(TypedDict):
subtype: Literal["interrupt"]
class CLIControlPermissionRequest(TypedDict):
subtype: Literal["can_use_tool"]
tool_name: str
tool_use_id: str
input: Any
permission_suggestions: list[PermissionSuggestion] | None
blocked_path: str | None
class CLIControlInitializeRequest(TypedDict):
subtype: Literal["initialize"]
hooks: NotRequired[Any]
mcpServers: NotRequired[dict[str, dict[str, Any]]]
class CLIControlSetPermissionModeRequest(TypedDict):
subtype: Literal["set_permission_mode"]
mode: PermissionMode
class CLIControlSetModelRequest(TypedDict):
subtype: Literal["set_model"]
model: str
class CLIControlMcpStatusRequest(TypedDict):
subtype: Literal["mcp_server_status"]
class CLIControlSupportedCommandsRequest(TypedDict):
subtype: Literal["supported_commands"]
ControlRequestPayload: TypeAlias = (
CLIControlInterruptRequest
| CLIControlPermissionRequest
| CLIControlInitializeRequest
| CLIControlSetPermissionModeRequest
| CLIControlSetModelRequest
| CLIControlMcpStatusRequest
| CLIControlSupportedCommandsRequest
| dict[str, Any]
)
class CLIControlRequest(TypedDict):
type: Literal["control_request"]
request_id: str
request: ControlRequestPayload
class ControlResponseSuccess(TypedDict):
subtype: Literal["success"]
request_id: str
response: Any
class ControlResponseError(TypedDict):
subtype: Literal["error"]
request_id: str
error: str | dict[str, Any]
class CLIControlResponse(TypedDict):
type: Literal["control_response"]
response: ControlResponseSuccess | ControlResponseError
class ControlCancelRequest(TypedDict):
type: Literal["control_cancel_request"]
request_id: NotRequired[str]
SDKMessage: TypeAlias = (
SDKUserMessage
| SDKAssistantMessage
| SDKSystemMessage
| SDKResultMessage
| SDKPartialAssistantMessage
)
ControlMessage: TypeAlias = (
CLIControlRequest | CLIControlResponse | ControlCancelRequest
)
def is_sdk_user_message(msg: Any) -> TypeGuard[SDKUserMessage]:
return isinstance(msg, dict) and msg.get("type") == "user" and "message" in msg
def is_sdk_assistant_message(msg: Any) -> TypeGuard[SDKAssistantMessage]:
return (
isinstance(msg, dict)
and msg.get("type") == "assistant"
and "session_id" in msg
and "message" in msg
)
def is_sdk_system_message(msg: Any) -> TypeGuard[SDKSystemMessage]:
return (
isinstance(msg, dict)
and msg.get("type") == "system"
and "subtype" in msg
and "session_id" in msg
)
def is_sdk_result_message(msg: Any) -> TypeGuard[SDKResultMessage]:
return (
isinstance(msg, dict)
and msg.get("type") == "result"
and "subtype" in msg
and "session_id" in msg
)
def is_sdk_partial_assistant_message(msg: Any) -> TypeGuard[SDKPartialAssistantMessage]:
return (
isinstance(msg, dict)
and msg.get("type") == "stream_event"
and "session_id" in msg
and "event" in msg
)
def is_control_request(msg: Any) -> TypeGuard[CLIControlRequest]:
return (
isinstance(msg, dict)
and msg.get("type") == "control_request"
and "request_id" in msg
and "request" in msg
)
def is_control_response(msg: Any) -> TypeGuard[CLIControlResponse]:
return (
isinstance(msg, dict)
and msg.get("type") == "control_response"
and "response" in msg
)
def is_control_cancel(msg: Any) -> TypeGuard[ControlCancelRequest]:
return (
isinstance(msg, dict)
and msg.get("type") == "control_cancel_request"
and "request_id" in msg
)

View file

@ -0,0 +1,607 @@
"""Async Query implementation for qwen_code_sdk."""
from __future__ import annotations
import asyncio
import contextlib
from collections.abc import AsyncIterable, Mapping, MutableMapping
from dataclasses import dataclass, replace
from types import TracebackType
from typing import Any, cast
from uuid import uuid4
from .errors import AbortError, ControlRequestTimeoutError
from .json_lines import serialize_json_line
from .protocol import (
CLIControlRequest,
CLIControlResponse,
SDKMessage,
SDKUserMessage,
is_control_cancel,
is_control_request,
is_control_response,
is_sdk_assistant_message,
is_sdk_partial_assistant_message,
is_sdk_result_message,
is_sdk_system_message,
is_sdk_user_message,
)
from .transport import ProcessTransport
from .types import (
CanUseToolContext,
PermissionDenyResult,
QueryOptions,
QueryOptionsDict,
)
from .validation import validate_query_options
_DONE = object()
@dataclass
class _PendingControlRequest:
future: asyncio.Future[dict[str, Any] | None]
cancel_event: asyncio.Event
timeout_handle: asyncio.TimerHandle
@dataclass
class _IncomingControlRequest:
task: asyncio.Task[None]
cancel_event: asyncio.Event
class Query:
def __init__(
self,
transport: ProcessTransport,
options: QueryOptions,
prompt: str | AsyncIterable[SDKUserMessage],
session_id: str,
) -> None:
self._transport = transport
self._options = options
self._prompt = prompt
self._single_turn = isinstance(prompt, str)
self._session_id = session_id
self._session_id_locked = bool(options.resume or options.session_id)
self._message_queue: asyncio.Queue[SDKMessage | Exception | object] = (
asyncio.Queue()
)
self._closed = False
self._started = False
self._start_lock = asyncio.Lock()
self._cancel_event = asyncio.Event()
self._router_task: asyncio.Task[None] | None = None
self._input_task: asyncio.Task[None] | None = None
self._initialize_task: asyncio.Task[None] | None = None
self._first_result_event = asyncio.Event()
self._terminal_event_sent = False
self._exhausted = False
self._pending_control_requests: dict[str, _PendingControlRequest] = {}
self._incoming_control_requests: dict[str, _IncomingControlRequest] = {}
async def _ensure_started(self) -> None:
if self._closed:
raise RuntimeError("Query is closed")
if self._started:
return
async with self._start_lock:
if self._closed:
raise RuntimeError("Query is closed")
if self._started:
return
await self._transport.start()
self._router_task = asyncio.create_task(self._message_router())
self._initialize_task = asyncio.create_task(self._initialize())
if self._single_turn:
self._input_task = asyncio.create_task(self._send_single_turn_prompt())
else:
self._input_task = asyncio.create_task(
self.stream_input(self._prompt) # type: ignore[arg-type]
)
self._started = True
async def _initialize(self) -> None:
try:
payload: dict[str, Any] = {"hooks": None}
await self._send_control_request("initialize", payload)
except Exception as exc:
await self._finish_with_error(exc)
async def _send_single_turn_prompt(self) -> None:
try:
assert isinstance(self._prompt, str)
await self._wait_initialized()
message: SDKUserMessage = {
"type": "user",
"session_id": self._session_id,
"message": {
"role": "user",
"content": self._prompt,
},
"parent_tool_use_id": None,
}
await self._write_payload(message)
except Exception as exc:
await self._finish_with_error(exc)
raise
async def _wait_initialized(self) -> None:
if self._initialize_task is None:
return
await self._initialize_task
async def _message_router(self) -> None:
try:
async for message in self._transport.read_messages():
await self._route_message(message)
if self._closed:
break
if self._closed:
return
if self._transport.exit_error is not None:
await self._finish_with_error(self._transport.exit_error)
return
await self._finish()
except Exception as exc: # pragma: no cover - critical propagation path
await self._finish_with_error(exc)
async def _route_message(self, message: Any) -> None:
self._maybe_update_session_id(message)
if is_control_request(message):
self._start_incoming_control_request(message)
return
if is_control_response(message):
self._handle_control_response(message)
return
if is_control_cancel(message):
self._handle_control_cancel_request(message)
return
if is_sdk_result_message(message):
self._first_result_event.set()
if self._single_turn:
self._transport.end_input()
await self._message_queue.put(message)
return
if (
is_sdk_system_message(message)
or is_sdk_assistant_message(message)
or is_sdk_user_message(message)
or is_sdk_partial_assistant_message(message)
):
await self._message_queue.put(message)
return
def _maybe_update_session_id(self, message: Any) -> None:
if self._session_id_locked or not isinstance(message, Mapping):
return
session_id = message.get("session_id")
if isinstance(session_id, str) and session_id:
self._session_id = session_id
self._session_id_locked = True
def _start_incoming_control_request(self, request: CLIControlRequest) -> None:
request_id = request["request_id"]
cancel_event = asyncio.Event()
async def runner() -> None:
try:
await self._handle_control_request(request, cancel_event)
except asyncio.CancelledError:
pass
except Exception as exc: # pragma: no cover - fatal background path
await self._finish_with_error(exc)
finally:
self._incoming_control_requests.pop(request_id, None)
task = asyncio.create_task(runner())
self._incoming_control_requests[request_id] = _IncomingControlRequest(
task=task,
cancel_event=cancel_event,
)
async def _handle_control_request(
self,
request: CLIControlRequest,
cancel_event: asyncio.Event,
) -> None:
request_id = request["request_id"]
payload = request["request"]
subtype = payload.get("subtype")
try:
if subtype == "can_use_tool":
response = await self._handle_permission_request(
cast(MutableMapping[str, Any], payload),
cancel_event,
)
elif subtype == "mcp_message":
raise RuntimeError("mcp_message is unsupported in python sdk v1")
else:
raise RuntimeError(f"Unknown control request subtype: {subtype}")
if cancel_event.is_set():
return
await self._send_control_response(
request_id, success=True, response=response
)
except Exception as exc:
if cancel_event.is_set():
return
await self._send_control_response(
request_id,
success=False,
response=str(exc),
)
async def _handle_permission_request(
self,
payload: MutableMapping[str, Any],
cancel_event: asyncio.Event,
) -> dict[str, Any]:
tool_name = str(payload.get("tool_name", ""))
tool_input = payload.get("input")
if not isinstance(tool_input, dict):
tool_input = {}
if self._options.can_use_tool is None:
return {"behavior": "deny", "message": "Denied"}
context: CanUseToolContext = {
"cancel_event": cancel_event,
"suggestions": payload.get("permission_suggestions"),
"blocked_path": payload.get("blocked_path"),
}
try:
result = await asyncio.wait_for(
self._options.can_use_tool(tool_name, tool_input, context),
timeout=self._options.timeout.can_use_tool,
)
except asyncio.TimeoutError:
return {
"behavior": "deny",
"message": "Permission request timed out",
}
except asyncio.CancelledError:
if cancel_event.is_set():
raise
return {
"behavior": "deny",
"message": "Permission check failed: callback cancelled",
}
except Exception as exc:
return {
"behavior": "deny",
"message": f"Permission check failed: {exc}",
}
behavior = result.get("behavior")
if behavior == "allow":
return {
"behavior": "allow",
"updatedInput": result.get("updatedInput", tool_input),
}
deny_result = cast(PermissionDenyResult, result)
return {
"behavior": "deny",
"message": deny_result.get("message", "Denied"),
**(
{"interrupt": deny_result["interrupt"]}
if "interrupt" in deny_result
else {}
),
}
def _handle_control_response(self, response: CLIControlResponse) -> None:
payload = response["response"]
request_id = payload["request_id"]
pending = self._pending_control_requests.pop(request_id, None)
if pending is None:
return
pending.timeout_handle.cancel()
if payload["subtype"] == "success":
if not pending.future.done():
pending.future.set_result(payload.get("response"))
else:
error = payload.get("error", "Unknown control error")
if isinstance(error, dict):
error_message = str(error.get("message", "Unknown control error"))
else:
error_message = str(error)
if not pending.future.done():
pending.future.set_exception(RuntimeError(error_message))
def _handle_control_cancel_request(self, message: Mapping[str, Any]) -> None:
request_id = message.get("request_id")
if not isinstance(request_id, str):
return
pending = self._pending_control_requests.pop(request_id, None)
if pending is not None:
pending.timeout_handle.cancel()
pending.cancel_event.set()
if not pending.future.done():
pending.future.set_exception(AbortError("Control request cancelled"))
incoming = self._incoming_control_requests.get(request_id)
if incoming is None:
return
incoming.cancel_event.set()
incoming.task.cancel()
async def _send_control_request(
self,
subtype: str,
data: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
if self._closed:
raise RuntimeError("Query is closed")
if subtype != "initialize":
await self._wait_initialized()
request_id = str(uuid4())
loop = asyncio.get_running_loop()
future: asyncio.Future[dict[str, Any] | None] = loop.create_future()
cancel_event = asyncio.Event()
def on_timeout() -> None:
pending = self._pending_control_requests.pop(request_id, None)
if pending is None:
return
pending.cancel_event.set()
if not pending.future.done():
pending.future.set_exception(
ControlRequestTimeoutError(f"Control request timeout: {subtype}")
)
timeout_handle = loop.call_later(
self._options.timeout.control_request,
on_timeout,
)
self._pending_control_requests[request_id] = _PendingControlRequest(
future=future,
cancel_event=cancel_event,
timeout_handle=timeout_handle,
)
request_payload: dict[str, Any] = {"subtype": subtype}
if data:
request_payload.update(data)
payload: CLIControlRequest = {
"type": "control_request",
"request_id": request_id,
"request": request_payload,
}
await self._write_payload(payload)
return await future
async def _send_control_response(
self,
request_id: str,
*,
success: bool,
response: Any,
) -> None:
payload: CLIControlResponse
if success:
payload = {
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": response,
},
}
else:
payload = {
"type": "control_response",
"response": {
"subtype": "error",
"request_id": request_id,
"error": str(response),
},
}
await self._write_payload(payload)
async def _write_payload(self, payload: Any) -> None:
self._transport.write(serialize_json_line(payload))
await self._transport.drain()
async def stream_input(self, messages: AsyncIterable[SDKUserMessage]) -> None:
try:
if self._closed:
raise RuntimeError("Query is closed")
await self._wait_initialized()
async for message in messages:
if self._cancel_event.is_set() or self._closed:
break
await self._write_payload(message)
if not self._single_turn:
try:
await asyncio.wait_for(
self._first_result_event.wait(),
timeout=self._options.timeout.stream_close,
)
except asyncio.TimeoutError:
pass
self._transport.end_input()
except Exception as exc:
await self._finish_with_error(exc)
raise
async def interrupt(self) -> None:
await self._ensure_started()
await self._send_control_request("interrupt")
async def set_permission_mode(self, mode: str) -> None:
await self._ensure_started()
await self._send_control_request("set_permission_mode", {"mode": mode})
async def set_model(self, model: str) -> None:
await self._ensure_started()
await self._send_control_request("set_model", {"model": model})
async def supported_commands(self) -> dict[str, Any] | None:
await self._ensure_started()
return await self._send_control_request("supported_commands")
async def mcp_server_status(self) -> dict[str, Any] | None:
await self._ensure_started()
return await self._send_control_request("mcp_server_status")
@property
def control_request_timeout(self) -> float:
return self._options.timeout.control_request
def get_session_id(self) -> str:
return self._session_id
def is_closed(self) -> bool:
return self._closed
def _fail_pending_control_requests(self, error: Exception) -> None:
for request_id, pending in list(self._pending_control_requests.items()):
pending.timeout_handle.cancel()
pending.cancel_event.set()
if not pending.future.done():
pending.future.set_exception(error)
self._pending_control_requests.pop(request_id, None)
async def _cancel_incoming_control_requests(self) -> None:
current_task = asyncio.current_task()
tasks: list[asyncio.Task[None]] = []
for incoming in list(self._incoming_control_requests.values()):
incoming.cancel_event.set()
if incoming.task is current_task:
continue
incoming.task.cancel()
tasks.append(incoming.task)
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def close(self) -> None:
if self._closed:
return
self._closed = True
self._cancel_event.set()
error = RuntimeError("Query is closed")
self._fail_pending_control_requests(error)
await self._cancel_incoming_control_requests()
await self._transport.close()
if self._input_task is not None:
self._input_task.cancel()
with contextlib.suppress(asyncio.CancelledError, Exception):
await self._input_task
if self._router_task is not None:
with contextlib.suppress(Exception):
await self._router_task
await self._finish()
async def _finish(self) -> None:
if self._terminal_event_sent:
return
self._terminal_event_sent = True
await self._message_queue.put(_DONE)
async def _finish_with_error(self, exc: Exception) -> None:
if self._terminal_event_sent:
return
self._closed = True
self._terminal_event_sent = True
self._cancel_event.set()
self._fail_pending_control_requests(exc)
await self._cancel_incoming_control_requests()
await self._transport.close()
await self._message_queue.put(exc)
await self._message_queue.put(_DONE)
def __aiter__(self) -> Query:
return self
async def __anext__(self) -> SDKMessage:
if self._exhausted:
raise StopAsyncIteration
await self._ensure_started()
item = await self._message_queue.get()
if item is _DONE:
self._exhausted = True
raise StopAsyncIteration
if isinstance(item, Exception):
raise item
return cast(SDKMessage, item)
async def __aenter__(self) -> Query:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
def query(
prompt: str | AsyncIterable[SDKUserMessage],
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None,
) -> Query:
if isinstance(options, QueryOptions):
parsed_options = replace(options)
else:
parsed_options = QueryOptions.from_mapping(options)
validate_query_options(parsed_options)
session_id = parsed_options.resume or parsed_options.session_id
if session_id is None and not parsed_options.continue_session:
session_id = str(uuid4())
if parsed_options.resume is None and not parsed_options.continue_session:
parsed_options = replace(parsed_options, session_id=session_id)
transport = ProcessTransport(parsed_options)
return Query(transport, parsed_options, prompt, session_id or "")

View file

@ -0,0 +1,217 @@
"""Synchronous wrapper around the async Query API."""
from __future__ import annotations
import asyncio
import threading
import warnings
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Mapping
from queue import Queue
from typing import Any, cast
from .protocol import SDKMessage, SDKUserMessage
from .query import Query, query
from .types import QueryOptions, QueryOptionsDict
_STOP = object()
_SYNC_TIMEOUT_MARGIN = 5.0
class SyncQuery:
def __init__(
self,
prompt: str | Iterable[SDKUserMessage] | AsyncIterable[SDKUserMessage],
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None,
) -> None:
self._queue: Queue[SDKMessage | Exception | object] = Queue()
self._ready = threading.Event()
self._shutdown = threading.Event()
self._stop_sent = threading.Event()
self._exhausted = False
self._query: Query | None = None
self._consumer_task: asyncio.Task[None] | None = None
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(
target=self._run_loop,
name="qwen-sdk-sync-loop",
daemon=True,
)
self._thread.start()
if isinstance(prompt, str) or isinstance(prompt, AsyncIterable):
source_prompt: str | AsyncIterable[SDKUserMessage] = prompt
else:
source_prompt = _iterable_to_async(prompt)
future = asyncio.run_coroutine_threadsafe(
self._bootstrap(source_prompt, options),
self._loop,
)
try:
future.result()
except Exception:
self._stop_loop()
raise
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
async def _bootstrap(
self,
prompt: str | AsyncIterable[SDKUserMessage],
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None,
) -> None:
self._query = query(prompt=prompt, options=options)
self._ready.set()
self._consumer_task = asyncio.create_task(self._consume())
async def _consume(self) -> None:
assert self._query is not None
try:
async for message in self._query:
self._queue.put(message)
except Exception as exc:
self._queue.put(exc)
finally:
if not self._stop_sent.is_set():
self._stop_sent.set()
self._queue.put(_STOP)
def _require_query(self) -> Query:
self._ready.wait(timeout=30)
if self._query is None:
raise RuntimeError("SyncQuery failed to initialize")
return self._query
def __iter__(self) -> SyncQuery:
return self
def __next__(self) -> SDKMessage:
if self._exhausted:
raise StopIteration
item = self._queue.get()
if item is _STOP:
self._exhausted = True
raise StopIteration
if isinstance(item, Exception):
raise item
return cast(SDKMessage, item)
def __enter__(self) -> SyncQuery:
return self
def __exit__(self, *_args: object) -> None:
self.close()
def interrupt(self) -> None:
q = self._require_query()
asyncio.run_coroutine_threadsafe(q.interrupt(), self._loop).result(
timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN
)
def set_model(self, model: str) -> None:
q = self._require_query()
asyncio.run_coroutine_threadsafe(q.set_model(model), self._loop).result(
timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN
)
def set_permission_mode(self, mode: str) -> None:
q = self._require_query()
asyncio.run_coroutine_threadsafe(
q.set_permission_mode(mode),
self._loop,
).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN)
def supported_commands(self) -> Any:
q = self._require_query()
return asyncio.run_coroutine_threadsafe(
q.supported_commands(),
self._loop,
).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN)
def mcp_server_status(self) -> Any:
q = self._require_query()
return asyncio.run_coroutine_threadsafe(
q.mcp_server_status(),
self._loop,
).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN)
def get_session_id(self) -> str:
q = self._require_query()
return q.get_session_id()
def is_closed(self) -> bool:
q = self._require_query()
return q.is_closed()
def close(self) -> None:
if self._shutdown.is_set():
return
self._shutdown.set()
q = self._query
if q is not None:
try:
asyncio.run_coroutine_threadsafe(q.close(), self._loop).result(
timeout=30
)
except Exception:
pass
# Wait for _consume() to put _STOP before stopping the loop,
# otherwise consumers blocked on queue.get() will deadlock.
if self._consumer_task is not None:
try:
asyncio.run_coroutine_threadsafe(
self._await_consumer(), self._loop
).result(timeout=5)
except Exception:
pass
if not self._stop_sent.is_set():
self._stop_sent.set()
self._queue.put(_STOP)
self._stop_loop()
async def _await_consumer(self) -> None:
if self._consumer_task is not None:
try:
await asyncio.wait_for(self._consumer_task, timeout=5.0)
except Exception:
pass
def _stop_loop(self) -> None:
if self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join(timeout=5)
if not self._loop.is_closed():
self._loop.close()
def __del__(self) -> None:
try:
if not self._shutdown.is_set():
warnings.warn(
"SyncQuery was not closed. "
"Use 'with SyncQuery(...) as q:' or call q.close() explicitly.",
ResourceWarning,
stacklevel=1,
)
try:
self.close()
except Exception:
pass
except AttributeError:
pass
async def _iterable_to_async(
messages: Iterable[SDKUserMessage],
) -> AsyncIterator[SDKUserMessage]:
for message in messages:
yield message

View file

@ -0,0 +1,243 @@
"""Process transport for qwen CLI stream-json protocol."""
from __future__ import annotations
import asyncio
import json
import os
import subprocess
import sys
from collections.abc import AsyncIterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .errors import ProcessExitError
from .json_lines import parse_json_line
from .types import QueryOptions
@dataclass(frozen=True)
class SpawnInfo:
command: str
args: list[str]
def prepare_spawn_info(path_to_qwen_executable: str | None) -> SpawnInfo:
if path_to_qwen_executable is None:
return SpawnInfo(command="qwen", args=[])
spec = path_to_qwen_executable
if os.path.sep not in spec and (
os.path.altsep is None or os.path.altsep not in spec
):
return SpawnInfo(command=spec, args=[])
path = Path(spec).expanduser().resolve()
suffix = path.suffix.lower()
if suffix == ".py":
return SpawnInfo(command=sys.executable, args=[str(path)])
if suffix in {".js", ".mjs", ".cjs"}:
return SpawnInfo(command="node", args=[str(path)])
return SpawnInfo(command=str(path), args=[])
class ProcessTransport:
def __init__(self, options: QueryOptions):
self._options = options
self._process: asyncio.subprocess.Process | None = None
self._stderr_task: asyncio.Task[None] | None = None
self._closed = False
self._input_closed = False
self._exit_error: Exception | None = None
@property
def exit_error(self) -> Exception | None:
return self._exit_error
@property
def is_closed(self) -> bool:
return self._closed
async def start(self) -> None:
if self._closed:
raise RuntimeError("Transport is closed")
if self._process is not None:
return
spawn_info = prepare_spawn_info(self._options.path_to_qwen_executable)
args = [*spawn_info.args, *build_cli_arguments(self._options)]
stderr_target = (
asyncio.subprocess.PIPE
if self._options.debug or self._options.stderr is not None
else subprocess.DEVNULL
)
self._process = await asyncio.create_subprocess_exec(
spawn_info.command,
*args,
cwd=self._options.cwd,
env={**os.environ, **(self._options.env or {})},
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=stderr_target,
)
if self._options.debug or self._options.stderr is not None:
self._stderr_task = asyncio.create_task(self._forward_stderr())
async def _forward_stderr(self) -> None:
if self._process is None or self._process.stderr is None:
return
while True:
chunk = await self._process.stderr.readline()
if not chunk:
return
text = chunk.decode("utf-8", errors="replace").rstrip("\n")
try:
if self._options.stderr is not None:
self._options.stderr(text)
elif self._options.debug:
print(text, file=sys.stderr)
except Exception:
print(text, file=sys.stderr)
def write(self, data: str) -> None:
if self._closed:
raise RuntimeError("Transport is closed")
if self._process is None or self._process.stdin is None:
raise RuntimeError("Transport is not started")
if self._input_closed:
raise RuntimeError("Transport input is already closed")
self._process.stdin.write(data.encode("utf-8"))
async def drain(self) -> None:
if self._process is None or self._process.stdin is None:
return
await self._process.stdin.drain()
def end_input(self) -> None:
if self._closed or self._input_closed:
return
if self._process is None or self._process.stdin is None:
return
self._process.stdin.close()
self._input_closed = True
async def read_messages(self) -> AsyncIterator[Any]:
if self._process is None or self._process.stdout is None:
raise RuntimeError("Transport is not started")
while True:
line = await self._process.stdout.readline()
if not line:
break
raw = line.decode("utf-8", errors="replace").strip()
if not raw:
continue
try:
yield parse_json_line(raw)
except json.JSONDecodeError:
continue
await self._finalize_exit()
async def wait_for_exit(self) -> None:
if self._process is None:
return
await self._finalize_exit()
async def _finalize_exit(self) -> None:
if self._process is None:
return
return_code = self._process.returncode
if return_code is None:
return_code = await self._process.wait()
if return_code != 0 and self._exit_error is None:
self._exit_error = ProcessExitError(
f"CLI process exited with code {return_code}",
exit_code=return_code,
)
if self._stderr_task is not None:
await self._stderr_task
self._stderr_task = None
async def close(self) -> None:
if self._closed:
return
self._closed = True
if self._process is None:
return
if self._process.stdin is not None and not self._input_closed:
self._process.stdin.close()
self._input_closed = True
if self._process.returncode is None:
self._process.terminate()
try:
await asyncio.wait_for(self._process.wait(), timeout=5.0)
except asyncio.TimeoutError:
self._process.kill()
await self._process.wait()
await self._finalize_exit()
def build_cli_arguments(options: QueryOptions) -> list[str]:
args: list[str] = [
"--input-format",
"stream-json",
"--output-format",
"stream-json",
"--channel=SDK",
]
if options.model:
args.extend(["--model", options.model])
if options.system_prompt:
args.extend(["--system-prompt", options.system_prompt])
if options.append_system_prompt:
args.extend(["--append-system-prompt", options.append_system_prompt])
if options.permission_mode:
args.extend(["--approval-mode", options.permission_mode])
if options.max_session_turns is not None:
args.extend(["--max-session-turns", str(options.max_session_turns)])
if options.core_tools:
args.extend(["--core-tools", ",".join(options.core_tools)])
if options.exclude_tools:
args.extend(["--exclude-tools", ",".join(options.exclude_tools)])
if options.allowed_tools:
args.extend(["--allowed-tools", ",".join(options.allowed_tools)])
if options.auth_type:
args.extend(["--auth-type", options.auth_type])
if options.include_partial_messages:
args.append("--include-partial-messages")
if options.resume:
args.extend(["--resume", options.resume])
elif options.continue_session:
args.append("--continue")
elif options.session_id:
args.extend(["--session-id", options.session_id])
return args

View file

@ -0,0 +1,323 @@
"""Public type definitions for qwen_code_sdk."""
from __future__ import annotations
from collections.abc import Awaitable, Callable, Mapping, MutableMapping
from dataclasses import dataclass
from inspect import Parameter, Signature, iscoroutinefunction, signature
from typing import (
Any,
Literal,
TypeAlias,
TypedDict,
cast,
)
from typing_extensions import NotRequired
PermissionMode: TypeAlias = Literal["default", "plan", "auto-edit", "yolo"]
AuthType: TypeAlias = Literal[
"openai",
"anthropic",
"qwen-oauth",
"gemini",
"vertex-ai",
]
class PermissionSuggestion(TypedDict):
type: Literal["allow", "deny", "modify"]
label: str
description: NotRequired[str]
modifiedInput: NotRequired[Any]
class PermissionAllowResult(TypedDict):
behavior: Literal["allow"]
updatedInput: NotRequired[dict[str, Any]]
class PermissionDenyResult(TypedDict):
behavior: Literal["deny"]
message: NotRequired[str]
interrupt: NotRequired[bool]
PermissionResult: TypeAlias = PermissionAllowResult | PermissionDenyResult
class CanUseToolContext(TypedDict):
cancel_event: Any
suggestions: list[PermissionSuggestion] | None
blocked_path: str | None
CanUseTool: TypeAlias = Callable[
[str, dict[str, Any], CanUseToolContext],
Awaitable[PermissionResult],
]
class TimeoutOptionsDict(TypedDict, total=False):
"""Timeout configuration. All values are in seconds."""
can_use_tool: float
control_request: float
stream_close: float
@dataclass(frozen=True)
class TimeoutOptions:
can_use_tool: float = 60.0
control_request: float = 60.0
stream_close: float = 60.0
@classmethod
def from_mapping(cls, value: Mapping[str, Any] | None) -> TimeoutOptions:
if value is None:
return cls()
def _read(name: str, default: float) -> float:
raw = value.get(name, default)
if isinstance(raw, bool) or not isinstance(raw, (int, float)):
raise TypeError(f"timeout.{name} must be a positive number")
if raw <= 0:
raise ValueError(f"timeout.{name} must be a positive number")
return float(raw)
return cls(
can_use_tool=_read("can_use_tool", 60.0),
control_request=_read("control_request", 60.0),
stream_close=_read("stream_close", 60.0),
)
class QueryOptionsDict(TypedDict, total=False):
cwd: str
model: str
path_to_qwen_executable: str
permission_mode: PermissionMode
can_use_tool: CanUseTool
env: dict[str, str]
system_prompt: str
append_system_prompt: str
debug: bool
max_session_turns: int
core_tools: list[str]
exclude_tools: list[str]
allowed_tools: list[str]
auth_type: AuthType
include_partial_messages: bool
resume: str
continue_session: bool
session_id: str
timeout: TimeoutOptionsDict
mcp_servers: dict[str, dict[str, Any]]
stderr: Callable[[str], None]
@dataclass
class QueryOptions:
cwd: str | None = None
model: str | None = None
path_to_qwen_executable: str | None = None
permission_mode: PermissionMode | None = None
can_use_tool: CanUseTool | None = None
env: dict[str, str] | None = None
system_prompt: str | None = None
append_system_prompt: str | None = None
debug: bool = False
max_session_turns: int | None = None
core_tools: list[str] | None = None
exclude_tools: list[str] | None = None
allowed_tools: list[str] | None = None
auth_type: AuthType | None = None
include_partial_messages: bool = False
resume: str | None = None
continue_session: bool = False
session_id: str | None = None
timeout: TimeoutOptions = TimeoutOptions()
mcp_servers: dict[str, dict[str, Any]] | None = None
stderr: Callable[[str], None] | None = None
@classmethod
def from_mapping(cls, value: Mapping[str, Any] | None) -> QueryOptions:
if value is None:
return cls()
data: MutableMapping[str, Any] = dict(value)
timeout = TimeoutOptions.from_mapping(data.get("timeout"))
return cls(
cwd=_as_optional_str(data, "cwd"),
model=_as_optional_str(data, "model"),
path_to_qwen_executable=_as_optional_str(data, "path_to_qwen_executable"),
permission_mode=cast(
PermissionMode | None,
_as_optional_str(data, "permission_mode"),
),
can_use_tool=cast(
CanUseTool | None,
_as_optional_callable(data, "can_use_tool"),
),
env=_as_optional_str_dict(data, "env"),
system_prompt=_as_optional_str(data, "system_prompt"),
append_system_prompt=_as_optional_str(data, "append_system_prompt"),
debug=_as_optional_bool(data, "debug") or False,
max_session_turns=_as_optional_int(data, "max_session_turns"),
core_tools=_as_optional_str_list(data, "core_tools"),
exclude_tools=_as_optional_str_list(data, "exclude_tools"),
allowed_tools=_as_optional_str_list(data, "allowed_tools"),
auth_type=cast(
AuthType | None,
_as_optional_str(data, "auth_type"),
),
include_partial_messages=_as_optional_bool(data, "include_partial_messages")
or False,
resume=_as_optional_str(data, "resume"),
continue_session=_as_optional_bool(data, "continue_session") or False,
session_id=_as_optional_str(data, "session_id"),
timeout=timeout,
mcp_servers=_as_optional_nested_dict(data, "mcp_servers"),
stderr=cast(
Callable[[str], None] | None,
_as_optional_callable(data, "stderr"),
),
)
def _as_optional_str(data: Mapping[str, Any], key: str) -> str | None:
raw = data.get(key)
if raw is None:
return None
if not isinstance(raw, str):
raise TypeError(f"{key} must be a string")
return raw
def _as_optional_int(data: Mapping[str, Any], key: str) -> int | None:
raw = data.get(key)
if raw is None:
return None
if isinstance(raw, bool) or not isinstance(raw, int):
raise TypeError(f"{key} must be an integer")
return int(raw)
def _as_optional_bool(data: Mapping[str, Any], key: str) -> bool | None:
raw = data.get(key)
if raw is None:
return None
if not isinstance(raw, bool):
raise TypeError(f"{key} must be a boolean")
return raw
def _as_optional_callable(
data: Mapping[str, Any], key: str
) -> Callable[..., Any] | None:
raw = data.get(key)
if raw is None:
return None
if not callable(raw):
raise TypeError(f"{key} must be callable")
if key == "can_use_tool":
_validate_can_use_tool_callable(raw, error_type=TypeError)
elif key == "stderr":
_validate_stderr_callable(raw, error_type=TypeError)
return cast(Callable[..., Any], raw)
def _validate_can_use_tool_callable(value: object, error_type: type[Exception]) -> None:
if not callable(value):
raise error_type("can_use_tool must be callable")
if not iscoroutinefunction(value):
raise error_type("can_use_tool must be an async callable")
try:
sig = signature(value)
except (TypeError, ValueError):
return
if not _supports_argument_count(sig, 3):
raise error_type("can_use_tool must accept exactly 3 positional arguments")
def _validate_stderr_callable(value: object, error_type: type[Exception]) -> None:
if not callable(value):
raise error_type("stderr must be callable")
try:
sig = signature(value)
except (TypeError, ValueError):
return
if not _supports_argument_count(sig, 1):
raise error_type("stderr must accept exactly 1 positional argument")
def _supports_argument_count(sig: Signature, count: int) -> bool:
params = list(sig.parameters.values())
positional_params = [
param
for param in params
if param.kind
in (
Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD,
)
]
required_positional = [
param for param in positional_params if param.default is Parameter.empty
]
has_var_positional = any(param.kind is Parameter.VAR_POSITIONAL for param in params)
if len(required_positional) > count:
return False
if has_var_positional:
return True
return len(positional_params) >= count
def _as_optional_str_dict(data: Mapping[str, Any], key: str) -> dict[str, str] | None:
raw = data.get(key)
if raw is None:
return None
if not isinstance(raw, Mapping):
raise TypeError(f"{key} must be a mapping of string to string")
parsed: dict[str, str] = {}
for k, v in raw.items():
if not isinstance(k, str) or not isinstance(v, str):
raise TypeError(f"{key} must be a mapping of string to string")
parsed[k] = v
return parsed
def _as_optional_str_list(data: Mapping[str, Any], key: str) -> list[str] | None:
raw = data.get(key)
if raw is None:
return None
if not isinstance(raw, list):
raise TypeError(f"{key} must be a list of strings")
if any(not isinstance(item, str) for item in raw):
raise TypeError(f"{key} must be a list of strings")
return list(raw)
def _as_optional_nested_dict(
data: Mapping[str, Any], key: str
) -> dict[str, dict[str, Any]] | None:
raw = data.get(key)
if raw is None:
return None
if not isinstance(raw, Mapping):
raise TypeError(f"{key} must be a mapping")
parsed: dict[str, dict[str, Any]] = {}
for k, v in raw.items():
if not isinstance(k, str) or not isinstance(v, Mapping):
raise TypeError(f"{key} must be a mapping of string to mapping")
parsed[k] = dict(v)
return parsed

View file

@ -0,0 +1,94 @@
"""Validation helpers for query options."""
from __future__ import annotations
from collections.abc import Callable
from uuid import RFC_4122, UUID
from .errors import ValidationError
from .types import (
QueryOptions,
_validate_can_use_tool_callable,
_validate_stderr_callable,
)
_VALID_PERMISSION_MODES = {"default", "plan", "auto-edit", "yolo"}
_VALID_AUTH_TYPES = {"openai", "anthropic", "qwen-oauth", "gemini", "vertex-ai"}
def validate_query_options(options: QueryOptions) -> None:
if (
options.permission_mode
and options.permission_mode not in _VALID_PERMISSION_MODES
):
raise ValidationError(
f"Invalid permission_mode: {options.permission_mode!r}. "
"Expected one of: default, plan, auto-edit, yolo."
)
if options.auth_type and options.auth_type not in _VALID_AUTH_TYPES:
raise ValidationError(
f"Invalid auth_type: {options.auth_type!r}. "
"Expected one of: openai, anthropic, qwen-oauth, gemini, vertex-ai."
)
_validate_optional_callable(options.can_use_tool, _validate_can_use_tool_callable)
_validate_optional_callable(options.stderr, _validate_stderr_callable)
if options.resume and options.continue_session:
raise ValidationError(
"Cannot use resume together with continue_session. "
"Use continue_session for latest session "
"or resume for a specific session ID."
)
if options.session_id and (options.resume or options.continue_session):
raise ValidationError(
"Cannot use session_id with resume or continue_session. "
"session_id starts a new session, "
"resume/continue_session restore existing sessions."
)
if options.session_id:
validate_session_id(options.session_id, "session_id")
if options.resume:
validate_session_id(options.resume, "resume")
if options.max_session_turns is not None and options.max_session_turns < -1:
raise ValidationError("max_session_turns must be -1 or a non-negative integer")
if (
options.path_to_qwen_executable is not None
and not options.path_to_qwen_executable.strip()
):
raise ValidationError("path_to_qwen_executable cannot be empty")
if options.mcp_servers:
raise ValidationError(
"mcp_servers is not supported in Python SDK v1. "
"Remove the mcp_servers option or use the TypeScript SDK."
)
def _validate_optional_callable(
value: object,
validator: Callable[[object, type[ValidationError]], None],
) -> None:
if value is None:
return
validator(value, ValidationError)
def validate_session_id(value: str, param_name: str) -> None:
try:
parsed = UUID(value)
except ValueError as exc:
raise ValidationError(
f"Invalid {param_name}: {value!r}. Must be a valid UUID."
) from exc
if parsed.variant != RFC_4122:
raise ValidationError(
f"Invalid {param_name}: {value!r}. UUID variant must be RFC 4122."
)