mirror of
https://github.com/QwenLM/qwen-code.git
synced 2026-04-28 11:41:04 +00:00
* Codex worktree snapshot: startup-cleanup
* Add Python SDK real smoke test — adds a repository-only real E2E smoke script for the Python SDK, plus npm and developer documentation entry points.
* fix(sdk-python): address review findings — bugs, type safety, and test coverage
  - Fix prepare_spawn_info: JS files now use "node" instead of sys.executable
  - Fix protocol.py: correct total=False misuse on 7 TypedDicts (required fields were optional)
  - Fix query.py: add _closed guard in _ensure_started, suppress exceptions in close()
  - Fix sync_query.py: prevent close() deadlock, add context manager, add timeouts
  - Fix transport.py: handle malformed JSON lines, add _closed guard in start()
  - Fix validation.py: use uuid.RFC_4122 instead of a magic UUID
  - Fix __init__.py: export TextBlock, widen query_sync signature
  - Remove dead code: ensure_not_aborted, write_json_line, _thread_error
  - Add 12 new tests (29 → 41): context managers, JSON skip, closed guards, spawn info, timeouts
* fix(sdk-python): address wenshao review — session_id, bool validation, debug stderr
  - Fix continue_session=True generating a wrong random session_id
  - Add _as_optional_bool helper for strict type validation on bool fields
  - Default debug stderr to sys.stderr when no custom callback is provided
* fix(sdk-python): address remaining wenshao review feedback
* test(cli): harden settings dialog restart prompt test
* fix(sdk-python): review fixes — UUID compat, stderr fallback, sync cleanup
  - Remove UUID version restriction to support v6/v7/v8 (RFC 9562)
  - Always write to sys.stderr when the stderr callback raises (was silent when debug=False)
  - Prevent duplicate _STOP sentinel in SyncQuery.close() via _stop_sent flag
  - Add ruff format --check to CI workflow
  - Fix smoke_real.py version guard: fail early before imports instead of NameError
  - Apply ruff format to existing files
* fix(sdk-python): remaining review fixes — exit_code attr, guard strictness, sync timeout
  - Add exit_code attribute to ProcessExitError for programmatic access
  - Strengthen is_control_response/is_control_cancel guards to require payload fields, preventing misrouting of malformed messages
  - Expose control_request_timeout property on Query so SyncQuery uses the configured timeout instead of a hardcoded 30s default
  - Use dataclasses.replace() instead of direct mutation on frozen-style QueryOptions in the query() factory
  - Add ResourceWarning in SyncQuery.__del__ when not properly closed
* fix(sdk-python): add exit_code default and guard __del__ against partial GC
  - Give ProcessExitError.exit_code a default value (-1) so user code can construct the exception with just a message string
  - Wrap SyncQuery.__del__ in try/except AttributeError to prevent crashes when the object is partially garbage-collected
* fix(sdk-python): review fixes — resource leak, type safety, CI matrix, docs
  - Fix SyncQuery.__del__ to call close() on GC instead of only warning
  - Replace hasattr duck-type check with isinstance(prompt, AsyncIterable)
  - Type-validate permission_mode/auth_type in QueryOptions.from_mapping
  - Use TypeGuard return types on all is_sdk_*/is_control_* predicates
  - Add a 5s margin to sync wrapper timeouts to prevent error type masking
  - Expand CI matrix to test Python 3.10, 3.11, 3.12
  - Change ProcessExitError.exit_code default from -1 to None
  - Add stderr to the docs QueryOptions listing
  - Update README sync example to use the context manager pattern
* fix(sdk-python): preserve iterator exhaustion state and suppress detached task warning
  - Add an _exhausted flag to Query.__anext__ and SyncQuery.__next__ so repeated iteration after end-of-stream raises Stop(Async)Iteration instead of blocking forever
  - Remove the re-raise in _initialize() to prevent the asyncio "Task exception was never retrieved" warning on detached tasks; the error is already surfaced via _finish_with_error()
* fix(sdk-python): reject mcp_servers at validation time and add iterator/init tests
  - Reject mcp_servers in validate_query_options() with a clear error instead of advertising MCP support to the CLI and then failing at runtime when mcp_message arrives
  - Remove the dead mcp_servers branch from _initialize()
  - Add tests for async/sync iterator exhaustion, detached init task warning suppression, and mcp_servers validation
* fix(sdk-python): fix ruff lint errors in new tests
  - Use ControlRequestTimeoutError instead of bare Exception (B017)
  - Fix import sorting for stdlib vs third-party (I001)
  - Break a long line to stay within the 88-char limit (E501)
* style(sdk-python): apply ruff format to new tests

Co-authored-by: Codex
Co-authored-by: jinye.djy <jinye.djy@alibaba-inc.com>
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
parent 202be6ec7d
commit e384338145
25 changed files with 4676 additions and 14 deletions
packages/sdk-python/tests/integration/conftest.py (new file, 400 lines)
@@ -0,0 +1,400 @@
from __future__ import annotations

import os
import stat
import textwrap
from pathlib import Path

import pytest


@pytest.fixture()
def fake_qwen_path(tmp_path: Path) -> str:
    script_path = tmp_path / "fake_qwen.py"
    script_path.write_text(
        textwrap.dedent(
            """
            #!/usr/bin/env python3
            import argparse
            import json
            import sys
            import uuid


            def send(message):
                sys.stdout.write(json.dumps(message, separators=(",", ":")) + "\\n")
                sys.stdout.flush()


            def parse_user_content(message):
                payload = message.get("message", {})
                content = payload.get("content", "")
                if isinstance(content, str):
                    return content
                if isinstance(content, list):
                    text_parts = []
                    for block in content:
                        if isinstance(block, dict) and block.get("type") == "text":
                            text_parts.append(str(block.get("text", "")))
                    return " ".join(text_parts)
                return str(content)


            def build_system_message():
                return {
                    "type": "system",
                    "subtype": "init",
                    "uuid": session_id,
                    "session_id": session_id,
                    "cwd": ".",
                    "tools": ["Read", "Edit", "Bash"],
                    "mcp_servers": [],
                    "model": state["model"],
                    "permission_mode": state["permission_mode"],
                    "qwen_code_version": "fake-1.0.0",
                    "capabilities": {
                        "canSetModel": True,
                        "canSetPermissionMode": True,
                    },
                }


            def build_assistant_message(text):
                return {
                    "type": "assistant",
                    "uuid": str(uuid.uuid4()),
                    "session_id": session_id,
                    "message": {
                        "id": str(uuid.uuid4()),
                        "type": "message",
                        "role": "assistant",
                        "model": state["model"],
                        "content": [
                            {
                                "type": "text",
                                "text": text,
                            }
                        ],
                        "usage": {
                            "input_tokens": 1,
                            "output_tokens": 1,
                        },
                    },
                    "parent_tool_use_id": None,
                }


            def build_result_message(result_text):
                return {
                    "type": "result",
                    "subtype": "success",
                    "uuid": str(uuid.uuid4()),
                    "session_id": session_id,
                    "is_error": False,
                    "duration_ms": 5,
                    "duration_api_ms": 1,
                    "num_turns": 1,
                    "result": result_text,
                    "usage": {
                        "input_tokens": 1,
                        "output_tokens": 1,
                    },
                    "permission_denials": [],
                }


            parser = argparse.ArgumentParser(add_help=False)
            parser.add_argument("--model")
            parser.add_argument("--approval-mode")
            parser.add_argument("--include-partial-messages", action="store_true")
            parser.add_argument("--session-id")
            parser.add_argument("--resume")
            parser.add_argument(
                "--continue",
                dest="continue_session",
                action="store_true",
            )
            args, _ = parser.parse_known_args()

            session_id = (
                args.resume
                or args.session_id
                or (
                    "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"
                    if args.continue_session
                    else str(uuid.uuid4())
                )
            )
            state = {
                "model": args.model or "coder-model",
                "permission_mode": args.approval_mode or "default",
                "include_partial": bool(args.include_partial_messages),
            }

            pending_permission = None
            pending_unknown_control = None

            for line in sys.stdin:
                line = line.strip()
                if not line:
                    continue
                message = json.loads(line)
                msg_type = message.get("type")

                if msg_type == "control_request":
                    request_id = message["request_id"]
                    request = message["request"]
                    subtype = request.get("subtype")

                    if subtype == "initialize":
                        send(
                            {
                                "type": "control_response",
                                "response": {
                                    "subtype": "success",
                                    "request_id": request_id,
                                    "response": {},
                                },
                            }
                        )
                        send(build_system_message())
                    elif subtype == "set_model":
                        state["model"] = request["model"]
                        send(
                            {
                                "type": "control_response",
                                "response": {
                                    "subtype": "success",
                                    "request_id": request_id,
                                    "response": {},
                                },
                            }
                        )
                        send(build_system_message())
                    elif subtype == "set_permission_mode":
                        state["permission_mode"] = request["mode"]
                        send(
                            {
                                "type": "control_response",
                                "response": {
                                    "subtype": "success",
                                    "request_id": request_id,
                                    "response": {},
                                },
                            }
                        )
                        send(build_system_message())
                    elif subtype == "interrupt":
                        send(
                            {
                                "type": "control_response",
                                "response": {
                                    "subtype": "success",
                                    "request_id": request_id,
                                    "response": {},
                                },
                            }
                        )
                    elif subtype == "supported_commands":
                        send(
                            {
                                "type": "control_response",
                                "response": {
                                    "subtype": "success",
                                    "request_id": request_id,
                                    "response": {
                                        "commands": [
                                            "initialize",
                                            "interrupt",
                                            "set_model",
                                            "set_permission_mode",
                                        ]
                                    },
                                },
                            }
                        )
                    elif subtype == "mcp_server_status":
                        send(
                            {
                                "type": "control_response",
                                "response": {
                                    "subtype": "success",
                                    "request_id": request_id,
                                    "response": {"servers": []},
                                },
                            }
                        )
                    else:
                        send(
                            {
                                "type": "control_response",
                                "response": {
                                    "subtype": "error",
                                    "request_id": request_id,
                                    "error": f"unsupported request: {subtype}",
                                },
                            }
                        )

                elif msg_type == "user":
                    prompt = parse_user_content(message)

                    if "exit nonzero" in prompt:
                        sys.exit(9)

                    if "request unknown control" in prompt:
                        request_id = str(uuid.uuid4())
                        pending_unknown_control = {
                            "request_id": request_id,
                            "prompt": prompt,
                        }
                        send(
                            {
                                "type": "control_request",
                                "request_id": request_id,
                                "request": {
                                    "subtype": "something_new",
                                    "payload": {},
                                },
                            }
                        )
                        continue

                    if "use tool" in prompt or "create file" in prompt:
                        tool_use_id = str(uuid.uuid4())
                        send(
                            {
                                "type": "assistant",
                                "uuid": str(uuid.uuid4()),
                                "session_id": session_id,
                                "message": {
                                    "id": str(uuid.uuid4()),
                                    "type": "message",
                                    "role": "assistant",
                                    "model": state["model"],
                                    "content": [
                                        {
                                            "type": "tool_use",
                                            "id": tool_use_id,
                                            "name": "write_file",
                                            "input": {
                                                "path": "demo.txt",
                                                "content": "hello",
                                            },
                                        }
                                    ],
                                    "usage": {"input_tokens": 1, "output_tokens": 1},
                                },
                                "parent_tool_use_id": None,
                            }
                        )
                        request_id = str(uuid.uuid4())
                        pending_permission = {
                            "request_id": request_id,
                            "tool_use_id": tool_use_id,
                            "prompt": prompt,
                        }
                        send(
                            {
                                "type": "control_request",
                                "request_id": request_id,
                                "request": {
                                    "subtype": "can_use_tool",
                                    "tool_name": "write_file",
                                    "tool_use_id": tool_use_id,
                                    "input": {"path": "demo.txt", "content": "hello"},
                                    "permission_suggestions": [
                                        {"type": "allow", "label": "Allow write"}
                                    ],
                                    "blocked_path": None,
                                },
                            }
                        )
                        continue

                    if state["include_partial"]:
                        send(
                            {
                                "type": "stream_event",
                                "uuid": str(uuid.uuid4()),
                                "session_id": session_id,
                                "event": {
                                    "type": "content_block_delta",
                                    "index": 0,
                                    "delta": {"type": "text_delta", "text": "partial"},
                                },
                                "parent_tool_use_id": None,
                            }
                        )

                    send(build_assistant_message(f"Echo: {prompt}"))
                    send(build_result_message(f"done: {prompt}"))

                elif msg_type == "control_response":
                    payload = message.get("response", {})
                    request_id = payload.get("request_id")

                    if (
                        pending_unknown_control
                        and request_id == pending_unknown_control["request_id"]
                    ):
                        if payload.get("subtype") != "error":
                            sys.exit(3)
                        prompt = pending_unknown_control["prompt"]
                        pending_unknown_control = None
                        send(
                            build_assistant_message(
                                f"Unknown control handled for: {prompt}"
                            )
                        )
                        send(build_result_message(f"unknown-control: {prompt}"))
                        continue

                    if (
                        pending_permission
                        and request_id == pending_permission["request_id"]
                    ):
                        prompt = pending_permission["prompt"]
                        tool_use_id = pending_permission["tool_use_id"]
                        pending_permission = None

                        behavior = "deny"
                        if payload.get("subtype") == "success":
                            response_payload = payload.get("response") or {}
                            behavior = response_payload.get("behavior", "deny")

                        is_allowed = behavior == "allow"
                        send(
                            {
                                "type": "user",
                                "session_id": session_id,
                                "message": {
                                    "role": "user",
                                    "content": [
                                        {
                                            "type": "tool_result",
                                            "tool_use_id": tool_use_id,
                                            "is_error": not is_allowed,
                                            "content": "ok" if is_allowed else "denied",
                                        }
                                    ],
                                },
                                "parent_tool_use_id": tool_use_id,
                            }
                        )
                        send(build_assistant_message(f"tool handled: {prompt}"))
                        send(build_result_message(f"tool-result: {prompt}"))
                        continue
            """
        ).strip()
        + "\n",
        encoding="utf-8",
    )
    script_path.chmod(script_path.stat().st_mode | stat.S_IEXEC)
    return str(script_path)


@pytest.fixture(autouse=True)
def disable_history_expansion() -> None:
    # No-op fixture used as explicit marker for deterministic test env.
    os.environ.setdefault("PYTHONUTF8", "1")
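The fake CLI in conftest.py speaks a line-delimited JSON protocol: each message is one compact JSON object terminated by a newline, and blank stdin lines are skipped. A small sketch of that framing (helper names here are illustrative, not part of the SDK):

```python
import json
from typing import Any, Iterable, Iterator


def encode_line(message: dict[str, Any]) -> str:
    # Compact separators mirror the fake CLI's json.dumps(..., separators=(",", ":")).
    return json.dumps(message, separators=(",", ":")) + "\n"


def decode_lines(stream: Iterable[str]) -> Iterator[dict[str, Any]]:
    # Skip blank lines, as the fake CLI's stdin loop does.
    for line in stream:
        line = line.strip()
        if line:
            yield json.loads(line)
```

One object per line keeps the transport trivially incremental: a reader never needs lookahead beyond the next newline to know where a message ends.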
packages/sdk-python/tests/integration/test_async_query.py (new file, 276 lines)
@@ -0,0 +1,276 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Callable
from typing import Any

import pytest
from qwen_code_sdk import (
    ProcessExitError,
    SDKUserMessage,
    is_sdk_assistant_message,
    is_sdk_partial_assistant_message,
    is_sdk_result_message,
    is_sdk_system_message,
    is_sdk_user_message,
    query,
)

CONTINUED_SESSION_ID = "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
RESUME_UUID = "223e4567-e89b-12d3-a456-426614174000"


async def _collect_messages(result: Any) -> list[dict[str, Any]]:
    messages: list[dict[str, Any]] = []
    async for message in result:
        messages.append(message)
    return messages


async def _wait_for(predicate: Callable[[], bool], timeout: float = 2.0) -> None:
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        if predicate():
            return
        await asyncio.sleep(0.01)
    raise AssertionError("timed out waiting for expected SDK state")


def _tool_result_error_flag(message: dict[str, Any]) -> bool:
    content = message["message"]["content"]
    assert isinstance(content, list)
    return bool(content[0]["is_error"])


@pytest.mark.asyncio
async def test_single_turn_query(fake_qwen_path: str) -> None:
    result = query(
        "hello world",
        {
            "path_to_qwen_executable": fake_qwen_path,
        },
    )
    messages = await _collect_messages(result)

    assistant = next(
        message for message in messages if is_sdk_assistant_message(message)
    )
    final = next(message for message in messages if is_sdk_result_message(message))

    assert assistant["message"]["content"][0]["text"] == "Echo: hello world"
    assert final["result"] == "done: hello world"
    await result.close()


@pytest.mark.asyncio
async def test_include_partial_messages(fake_qwen_path: str) -> None:
    result = query(
        "stream partial",
        {
            "path_to_qwen_executable": fake_qwen_path,
            "include_partial_messages": True,
        },
    )
    messages = await _collect_messages(result)

    partial = next(
        message for message in messages if is_sdk_partial_assistant_message(message)
    )
    assert partial["event"]["type"] == "content_block_delta"
    await result.close()


@pytest.mark.asyncio
async def test_default_permission_callback_denies_tool_use(fake_qwen_path: str) -> None:
    result = query(
        "use tool now",
        {
            "path_to_qwen_executable": fake_qwen_path,
        },
    )
    messages = await _collect_messages(result)

    tool_result = next(
        message
        for message in messages
        if is_sdk_user_message(message)
        and isinstance(message["message"]["content"], list)
    )
    assert _tool_result_error_flag(tool_result) is True
    await result.close()


@pytest.mark.asyncio
async def test_permission_callback_can_allow_tool_use(fake_qwen_path: str) -> None:
    async def can_use_tool(
        tool_name: str,
        tool_input: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any]:
        assert tool_name == "write_file"
        assert tool_input["path"] == "demo.txt"
        assert context["suggestions"][0]["type"] == "allow"
        return {"behavior": "allow", "updatedInput": tool_input}

    result = query(
        "create file with use tool",
        {
            "path_to_qwen_executable": fake_qwen_path,
            "can_use_tool": can_use_tool,
        },
    )
    messages = await _collect_messages(result)

    tool_result = next(
        message
        for message in messages
        if is_sdk_user_message(message)
        and isinstance(message["message"]["content"], list)
    )
    assert _tool_result_error_flag(tool_result) is False
    await result.close()


@pytest.mark.asyncio
async def test_unknown_control_requests_are_rejected(fake_qwen_path: str) -> None:
    result = query(
        "request unknown control",
        {
            "path_to_qwen_executable": fake_qwen_path,
        },
    )
    messages = await _collect_messages(result)

    final = next(message for message in messages if is_sdk_result_message(message))
    assert final["result"] == "unknown-control: request unknown control"
    await result.close()


@pytest.mark.asyncio
async def test_dynamic_controls_and_status(fake_qwen_path: str) -> None:
    release_input = asyncio.Event()

    async def prompts() -> AsyncIterator[SDKUserMessage]:
        yield {
            "type": "user",
            "session_id": VALID_UUID,
            "message": {
                "role": "user",
                "content": "first turn",
            },
            "parent_tool_use_id": None,
        }
        await release_input.wait()

    result = query(
        prompts(),
        {
            "path_to_qwen_executable": fake_qwen_path,
            "session_id": VALID_UUID,
        },
    )

    messages: list[dict[str, Any]] = []

    async def consume() -> list[dict[str, Any]]:
        async for message in result:
            messages.append(message)
        return messages

    collector = asyncio.create_task(consume())
    await _wait_for(lambda: any(is_sdk_result_message(message) for message in messages))

    assert await result.supported_commands() == {
        "commands": [
            "initialize",
            "interrupt",
            "set_model",
            "set_permission_mode",
        ]
    }
    assert await result.mcp_server_status() == {"servers": []}

    await result.set_model("new-model")
    await result.set_permission_mode("plan")
    release_input.set()
    await collector

    system_messages = [
        message for message in messages if is_sdk_system_message(message)
    ]
    assert any(message["model"] == "new-model" for message in system_messages)
    assert any(message["permission_mode"] == "plan" for message in system_messages)
    await result.close()


@pytest.mark.asyncio
async def test_session_id_resume_and_continue(fake_qwen_path: str) -> None:
    explicit = query(
        "hello explicit",
        {
            "path_to_qwen_executable": fake_qwen_path,
            "session_id": VALID_UUID,
        },
    )
    explicit_messages = await _collect_messages(explicit)
    assert explicit.get_session_id() == VALID_UUID
    assert all(message["session_id"] == VALID_UUID for message in explicit_messages)
    await explicit.close()

    resumed = query(
        "hello resume",
        {
            "path_to_qwen_executable": fake_qwen_path,
            "resume": RESUME_UUID,
        },
    )
    resumed_messages = await _collect_messages(resumed)
    assert resumed.get_session_id() == RESUME_UUID
    assert all(message["session_id"] == RESUME_UUID for message in resumed_messages)
    await resumed.close()

    continued = query(
        "hello continue",
        {
            "path_to_qwen_executable": fake_qwen_path,
            "continue_session": True,
        },
    )
    continued_messages = await _collect_messages(continued)
    assert continued.get_session_id() == CONTINUED_SESSION_ID
    assert any(
        message["session_id"] == CONTINUED_SESSION_ID for message in continued_messages
    )
    await continued.close()


@pytest.mark.asyncio
async def test_non_zero_process_exit_is_propagated(fake_qwen_path: str) -> None:
    result = query(
        "please exit nonzero",
        {
            "path_to_qwen_executable": fake_qwen_path,
        },
    )

    with pytest.raises(ProcessExitError, match="code 9"):
        await _collect_messages(result)

    await result.close()


@pytest.mark.asyncio
async def test_async_context_manager(fake_qwen_path: str) -> None:
    async with query(
        "hello context",
        {
            "path_to_qwen_executable": fake_qwen_path,
        },
    ) as result:
        messages = await _collect_messages(result)

    assert result.is_closed()
    final = next(m for m in messages if is_sdk_result_message(m))
    assert final["result"] == "done: hello context"
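The commit log above describes adding an `_exhausted` flag so that iterating a finished query raises `StopIteration` (or `StopAsyncIteration`) instead of blocking forever. The general pattern, sketched generically rather than as the SDK's actual `Query` class:

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


class ExhaustibleIterator(Iterator[T]):
    """Once the underlying stream ends, stay ended instead of blocking."""

    def __init__(self, items: Iterable[T]) -> None:
        self._items = iter(items)
        self._exhausted = False

    def __next__(self) -> T:
        if self._exhausted:
            # Repeated iteration after end-of-stream terminates immediately
            # rather than waiting on a queue that will never produce again.
            raise StopIteration
        try:
            return next(self._items)
        except StopIteration:
            self._exhausted = True
            raise
```

Without the flag, a queue-backed iterator would re-enter its blocking `get()` after the terminal sentinel was already consumed, which is exactly the hang the fix targets.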
packages/sdk-python/tests/integration/test_sync_query.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from __future__ import annotations

import threading
import time

import pytest
import qwen_code_sdk.sync_query as sync_query_module
from qwen_code_sdk import is_sdk_result_message, query_sync
from qwen_code_sdk.sync_query import SyncQuery


def test_sync_query_single_turn(fake_qwen_path: str) -> None:
    result = query_sync(
        "hello sync",
        {
            "path_to_qwen_executable": fake_qwen_path,
        },
    )

    commands = result.supported_commands()
    messages = list(result)

    assert commands["commands"][0] == "initialize"
    assert any(
        is_sdk_result_message(message) and message["result"] == "done: hello sync"
        for message in messages
    )

    result.close()
    result.close()


def test_sync_query_bootstrap_failure_cleans_up_loop_thread(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    def raising_query(*args: object, **kwargs: object) -> object:
        raise RuntimeError("bootstrap failed")

    monkeypatch.setattr(sync_query_module, "query", raising_query)

    baseline_threads = {
        thread.ident
        for thread in threading.enumerate()
        if thread.name == "qwen-sdk-sync-loop"
    }

    with pytest.raises(RuntimeError, match="bootstrap failed"):
        SyncQuery("hello")

    deadline = time.time() + 1.0
    while time.time() < deadline:
        active_threads = {
            thread.ident
            for thread in threading.enumerate()
            if thread.name == "qwen-sdk-sync-loop"
        }
        if active_threads == baseline_threads:
            break
        time.sleep(0.01)

    active_threads = {
        thread.ident
        for thread in threading.enumerate()
        if thread.name == "qwen-sdk-sync-loop"
    }
    assert active_threads == baseline_threads


def test_sync_query_context_manager(fake_qwen_path: str) -> None:
    with query_sync(
        "hello context",
        {
            "path_to_qwen_executable": fake_qwen_path,
        },
    ) as result:
        messages = list(result)
        assert any(
            is_sdk_result_message(m) and m["result"] == "done: hello context"
            for m in messages
        )

    assert result.is_closed()
packages/sdk-python/tests/unit/test_query_core.py (new file, 612 lines; listing truncated below)
@@ -0,0 +1,612 @@
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from qwen_code_sdk.errors import AbortError, ControlRequestTimeoutError
|
||||
from qwen_code_sdk.json_lines import parse_json_line
|
||||
from qwen_code_sdk.query import Query
|
||||
from qwen_code_sdk.types import QueryOptions, TimeoutOptions
|
||||
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
_EOF = object()
|
||||
|
||||
|
||||
class FakeTransport:
|
||||
def __init__(self) -> None:
|
||||
self.writes: list[dict[str, Any]] = []
|
||||
self.exit_error: Exception | None = None
|
||||
self.closed = False
|
||||
self.close_calls = 0
|
||||
self.input_closed = False
|
||||
self._queue: asyncio.Queue[dict[str, Any] | object] = asyncio.Queue()
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
def write(self, data: str) -> None:
|
||||
self.writes.append(parse_json_line(data))
|
||||
|
||||
async def drain(self) -> None:
|
||||
return None
|
||||
|
||||
def end_input(self) -> None:
|
||||
self.input_closed = True
|
||||
|
||||
async def read_messages(self): # type: ignore[no-untyped-def]
|
||||
while True:
|
||||
item = await self._queue.get()
|
||||
if item is _EOF:
|
||||
break
|
||||
yield item
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
self.close_calls += 1
|
||||
self.input_closed = True
|
||||
self._queue.put_nowait(_EOF)
|
||||
|
||||
def push(self, payload: dict[str, Any]) -> None:
|
||||
self._queue.put_nowait(payload)
|
||||
|
||||
|
||||
async def _wait_for(predicate: Callable[[], bool], timeout: float = 1.0) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + timeout
|
||||
while loop.time() < deadline:
|
||||
if predicate():
|
||||
return
|
||||
await asyncio.sleep(0.01)
|
||||
raise AssertionError("timed out waiting for test condition")
|
||||
|
||||
|
||||
async def _wait_for_request(
|
||||
transport: FakeTransport,
|
||||
subtype: str,
|
||||
timeout: float = 1.0,
|
||||
) -> dict[str, Any]:
|
||||
await _wait_for(
|
||||
lambda: any(
|
||||
payload.get("type") == "control_request"
|
||||
and payload.get("request", {}).get("subtype") == subtype
|
||||
for payload in transport.writes
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
for payload in transport.writes:
|
||||
if (
|
||||
payload.get("type") == "control_request"
|
||||
and payload.get("request", {}).get("subtype") == subtype
|
||||
):
|
||||
return payload
|
||||
raise AssertionError(f"missing control request: {subtype}")
|
||||
|
||||
|
||||
async def _wait_for_control_response(
|
||||
transport: FakeTransport,
|
||||
request_id: str,
|
||||
timeout: float = 1.0,
|
||||
) -> dict[str, Any]:
|
||||
await _wait_for(
|
||||
lambda: any(
|
||||
payload.get("type") == "control_response"
|
||||
and payload.get("response", {}).get("request_id") == request_id
|
||||
for payload in transport.writes
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
for payload in transport.writes:
|
||||
if (
|
||||
payload.get("type") == "control_response"
|
||||
and payload.get("response", {}).get("request_id") == request_id
|
||||
):
|
||||
return payload
|
||||
raise AssertionError(f"missing control response: {request_id}")
|
||||
|
||||
|
||||
async def _start_query(transport: FakeTransport) -> Query:
|
||||
query = Query(
|
||||
transport=transport, # type: ignore[arg-type]
|
||||
options=QueryOptions(
|
||||
timeout=TimeoutOptions(
|
||||
can_use_tool=0.05,
|
||||
control_request=0.05,
|
||||
stream_close=0.05,
|
||||
)
|
||||
),
|
||||
prompt="hello",
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
await query._ensure_started()
|
||||
|
||||
init_request = await _wait_for_request(transport, "initialize")
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": init_request["request_id"],
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
await _wait_for(
|
||||
lambda: any(payload.get("type") == "user" for payload in transport.writes)
|
||||
)
|
||||
return query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_unknown_control_request_returns_error_response() -> None:
    transport = FakeTransport()
    query = await _start_query(transport)

    transport.push(
        {
            "type": "control_request",
            "request_id": "unknown-1",
            "request": {
                "subtype": "something_new",
            },
        }
    )

    response = await _wait_for_control_response(transport, "unknown-1")

    assert response["response"]["subtype"] == "error"
    assert "Unknown control request subtype" in response["response"]["error"]
    await query.close()

@pytest.mark.asyncio
async def test_control_request_times_out() -> None:
    transport = FakeTransport()
    query = await _start_query(transport)

    with pytest.raises(ControlRequestTimeoutError, match="supported_commands"):
        await query.supported_commands()

    await query.close()

@pytest.mark.asyncio
async def test_control_request_cancel_propagates_abort_error() -> None:
    transport = FakeTransport()
    query = await _start_query(transport)

    task = asyncio.create_task(query.supported_commands())
    request = await _wait_for_request(transport, "supported_commands")
    transport.push(
        {
            "type": "control_cancel_request",
            "request_id": request["request_id"],
        }
    )

    with pytest.raises(AbortError, match="Control request cancelled"):
        await task

    await query.close()

@pytest.mark.asyncio
async def test_incoming_control_request_cancel_does_not_block_router() -> None:
    transport = FakeTransport()
    started = asyncio.Event()
    cancelled = asyncio.Event()
    captured_cancel_events: list[asyncio.Event] = []

    async def can_use_tool(
        tool_name: str,
        tool_input: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any]:
        assert tool_name == "write_file"
        assert tool_input["path"] == "demo.txt"
        cancel_event = cast(asyncio.Event, context["cancel_event"])
        captured_cancel_events.append(cancel_event)
        started.set()
        try:
            await cancel_event.wait()
            cancelled.set()
            return {"behavior": "deny", "message": "Cancelled"}
        except asyncio.CancelledError:
            if cancel_event.is_set():
                cancelled.set()
            raise

    query = Query(
        transport=transport,  # type: ignore[arg-type]
        options=QueryOptions(
            can_use_tool=can_use_tool,
            timeout=TimeoutOptions(
                can_use_tool=1.0,
                control_request=0.2,
                stream_close=0.05,
            ),
        ),
        prompt="hello",
        session_id=VALID_UUID,
    )
    await query._ensure_started()

    init_request = await _wait_for_request(transport, "initialize")
    transport.push(
        {
            "type": "control_response",
            "response": {
                "subtype": "success",
                "request_id": init_request["request_id"],
                "response": {},
            },
        }
    )
    await _wait_for(
        lambda: any(payload.get("type") == "user" for payload in transport.writes)
    )

    transport.push(
        {
            "type": "control_request",
            "request_id": "incoming-1",
            "request": {
                "subtype": "can_use_tool",
                "tool_name": "write_file",
                "tool_use_id": "tool-1",
                "input": {"path": "demo.txt", "content": "hello"},
                "permission_suggestions": [],
                "blocked_path": None,
            },
        }
    )

    await _wait_for(lambda: started.is_set())
    assert captured_cancel_events[0] is not query._cancel_event

    supported_commands_task = asyncio.create_task(query.supported_commands())
    supported_request = await _wait_for_request(transport, "supported_commands")

    transport.push(
        {
            "type": "control_cancel_request",
            "request_id": "incoming-1",
        }
    )
    await _wait_for(lambda: cancelled.is_set())

    transport.push(
        {
            "type": "control_response",
            "response": {
                "subtype": "success",
                "request_id": supported_request["request_id"],
                "response": {"commands": ["supported_commands"]},
            },
        }
    )

    assert await supported_commands_task == {"commands": ["supported_commands"]}
    assert all(
        not (
            payload.get("type") == "control_response"
            and payload.get("response", {}).get("request_id") == "incoming-1"
        )
        for payload in transport.writes
    )
    await query.close()

@pytest.mark.asyncio
async def test_permission_request_passes_blocked_path_to_callback() -> None:
    transport = FakeTransport()
    captured_context: dict[str, Any] | None = None

    async def can_use_tool(
        tool_name: str,
        tool_input: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any]:
        nonlocal captured_context
        assert tool_name == "write_file"
        assert tool_input["path"] == "demo.txt"
        captured_context = context
        return {"behavior": "deny", "message": "blocked"}

    query = Query(
        transport=transport,  # type: ignore[arg-type]
        options=QueryOptions(
            can_use_tool=can_use_tool,
            timeout=TimeoutOptions(
                can_use_tool=1.0,
                control_request=0.2,
                stream_close=0.05,
            ),
        ),
        prompt="hello",
        session_id=VALID_UUID,
    )
    await query._ensure_started()

    init_request = await _wait_for_request(transport, "initialize")
    transport.push(
        {
            "type": "control_response",
            "response": {
                "subtype": "success",
                "request_id": init_request["request_id"],
                "response": {},
            },
        }
    )
    await _wait_for(
        lambda: any(payload.get("type") == "user" for payload in transport.writes)
    )

    transport.push(
        {
            "type": "control_request",
            "request_id": "incoming-2",
            "request": {
                "subtype": "can_use_tool",
                "tool_name": "write_file",
                "tool_use_id": "tool-2",
                "input": {"path": "demo.txt", "content": "hello"},
                "permission_suggestions": [],
                "blocked_path": "/tmp/demo.txt",
            },
        }
    )

    response = await _wait_for_control_response(transport, "incoming-2")

    assert captured_context is not None
    assert isinstance(captured_context["cancel_event"], asyncio.Event)
    assert captured_context["suggestions"] == []
    assert captured_context["blocked_path"] == "/tmp/demo.txt"
    assert response["response"]["subtype"] == "success"
    assert response["response"]["response"] == {
        "behavior": "deny",
        "message": "blocked",
    }
    await query.close()

@pytest.mark.asyncio
async def test_permission_request_cancelled_callback_returns_deny() -> None:
    transport = FakeTransport()

    async def can_use_tool(
        tool_name: str,
        tool_input: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any]:
        assert tool_name == "write_file"
        assert tool_input["path"] == "demo.txt"
        assert isinstance(context["cancel_event"], asyncio.Event)
        raise asyncio.CancelledError()

    query = Query(
        transport=transport,  # type: ignore[arg-type]
        options=QueryOptions(
            can_use_tool=can_use_tool,
            timeout=TimeoutOptions(
                can_use_tool=1.0,
                control_request=0.2,
                stream_close=0.05,
            ),
        ),
        prompt="hello",
        session_id=VALID_UUID,
    )
    await query._ensure_started()

    init_request = await _wait_for_request(transport, "initialize")
    transport.push(
        {
            "type": "control_response",
            "response": {
                "subtype": "success",
                "request_id": init_request["request_id"],
                "response": {},
            },
        }
    )
    await _wait_for(
        lambda: any(payload.get("type") == "user" for payload in transport.writes)
    )

    transport.push(
        {
            "type": "control_request",
            "request_id": "incoming-3",
            "request": {
                "subtype": "can_use_tool",
                "tool_name": "write_file",
                "tool_use_id": "tool-3",
                "input": {"path": "demo.txt", "content": "hello"},
                "permission_suggestions": [],
                "blocked_path": None,
            },
        }
    )

    response = await _wait_for_control_response(transport, "incoming-3")

    assert response["response"]["subtype"] == "success"
    assert response["response"]["response"] == {
        "behavior": "deny",
        "message": "Permission check failed: callback cancelled",
    }
    await query.close()

@pytest.mark.asyncio
async def test_finish_with_error_closes_transport_and_fails_pending_requests() -> None:
    transport = FakeTransport()
    query = await _start_query(transport)

    supported_commands_task = asyncio.create_task(query.supported_commands())
    await _wait_for_request(transport, "supported_commands")

    await query._finish_with_error(RuntimeError("boom"))

    with pytest.raises(RuntimeError, match="boom"):
        await supported_commands_task

    assert query.is_closed() is True
    assert transport.closed is True

@pytest.mark.asyncio
async def test_ensure_started_raises_after_close() -> None:
    transport = FakeTransport()
    query = await _start_query(transport)

    await query.close()

    with pytest.raises(RuntimeError, match="Query is closed"):
        await query.supported_commands()

@pytest.mark.asyncio
async def test_anext_after_exhaustion_raises_stop_async_iteration() -> None:
    """After the async iterator is exhausted, subsequent __anext__ calls must
    raise StopAsyncIteration immediately instead of blocking."""
    transport = FakeTransport()
    query = await _start_query(transport)

    # Deliver one assistant message, then a result to end the turn.
    transport.push(
        {
            "type": "assistant",
            "session_id": VALID_UUID,
            "message": {
                "role": "assistant",
                "content": [{"type": "text", "text": "hi"}],
            },
        }
    )
    transport.push(
        {
            "type": "result",
            "session_id": VALID_UUID,
            "result": "done",
            "is_error": False,
            "duration_ms": 10,
            "duration_api_ms": 5,
            "num_turns": 1,
        }
    )
    # Signal end of transport stream so the router finishes naturally.
    await transport.close()

    # Consume all messages until exhaustion.
    messages: list[Any] = []
    with pytest.raises(StopAsyncIteration):
        while True:
            messages.append(await query.__anext__())

    assert len(messages) >= 1

    # The iterator is now exhausted — a second call must raise immediately.
    with pytest.raises(StopAsyncIteration):
        await query.__anext__()

@pytest.mark.asyncio
async def test_initialize_failure_no_unhandled_task_exception(
    recwarn: pytest.WarningsChecker,
) -> None:
    """When _initialize fails, no 'Task exception was never retrieved' warning
    should appear — _finish_with_error already surfaces the error."""
    transport = FakeTransport()
    query = Query(
        transport=transport,  # type: ignore[arg-type]
        options=QueryOptions(
            timeout=TimeoutOptions(
                can_use_tool=0.05,
                control_request=0.05,
                stream_close=0.05,
            )
        ),
        prompt="hello",
        session_id=VALID_UUID,
    )
    await query._ensure_started()

    # Let the initialize request time out — this triggers _finish_with_error
    # inside _initialize.
    init_request = await _wait_for_request(transport, "initialize")
    assert init_request is not None  # init was sent

    # Don't respond to initialize — let the control-request timeout fire.
    # The error propagates through _message_queue.
    with pytest.raises(ControlRequestTimeoutError):
        await query.__anext__()

    await query.close()

    # Give the event loop a moment to report any unhandled task exceptions.
    await asyncio.sleep(0.1)

    # No "Task exception was never retrieved" warnings should have appeared.
    task_warnings = [w for w in recwarn.list if "never retrieved" in str(w.message)]
    assert task_warnings == []

@pytest.mark.asyncio
async def test_async_context_manager_closes_on_exit() -> None:
    transport = FakeTransport()
    query = Query(
        transport=transport,  # type: ignore[arg-type]
        options=QueryOptions(
            timeout=TimeoutOptions(
                can_use_tool=0.05,
                control_request=0.05,
                stream_close=0.05,
            )
        ),
        prompt="hello",
        session_id=VALID_UUID,
    )

    async with query as q:
        assert q is query
        assert not q.is_closed()

    assert query.is_closed() is True

def test_sync_next_after_exhaustion_raises_stop_iteration() -> None:
    """After the sync iterator is exhausted, subsequent __next__ calls must
    raise StopIteration immediately instead of blocking on queue.get()."""
    from queue import Queue

    from qwen_code_sdk.sync_query import _STOP, SyncQuery

    # Build a minimal SyncQuery without spawning the real event-loop thread.
    sq = object.__new__(SyncQuery)
    sq._queue = Queue()
    sq._exhausted = False

    # Put one message then the sentinel.
    msg_payload = {
        "type": "assistant",
        "message": {"role": "assistant", "content": []},
    }
    sq._queue.put(msg_payload)
    sq._queue.put(_STOP)

    # First call returns the message.
    msg = next(sq)
    assert msg["type"] == "assistant"

    # Second call should exhaust.
    with pytest.raises(StopIteration):
        next(sq)

    # Third call must raise immediately, not block.
    with pytest.raises(StopIteration):
        next(sq)
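The sentinel-queue pattern the test above exercises is worth seeing in isolation. The sketch below is not the SDK's `SyncQuery` implementation — `SentinelIterator` and `_STOP` here are stand-alone illustrations of the same contract: a unique sentinel object terminates the queue, and an `_exhausted` flag makes every later `next()` call raise immediately instead of blocking on `queue.get()`.

```python
from queue import Queue
from typing import Any, Iterator

_STOP = object()  # unique sentinel; identity comparison avoids false matches


class SentinelIterator:
    """Minimal sketch of a queue-backed iterator terminated by a sentinel."""

    def __init__(self) -> None:
        self._queue: Queue = Queue()
        self._exhausted = False

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        if self._exhausted:
            raise StopIteration  # never touch the queue again
        item = self._queue.get()  # would block if no sentinel were queued
        if item is _STOP:
            self._exhausted = True
            raise StopIteration
        return item


it = SentinelIterator()
it._queue.put({"type": "assistant"})
it._queue.put(_STOP)
print(next(it)["type"])  # → assistant
```

Without the `_exhausted` guard, a second `next()` after the sentinel would block forever on an empty queue — exactly the failure mode the test pins down.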

291  packages/sdk-python/tests/unit/test_transport.py  Normal file
@@ -0,0 +1,291 @@
from __future__ import annotations

import asyncio
import subprocess
import sys
from pathlib import Path
from typing import Any

import pytest
from qwen_code_sdk.transport import build_cli_arguments, prepare_spawn_info
from qwen_code_sdk.types import QueryOptions, TimeoutOptions

VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"


class DummyProcess:
    def __init__(self) -> None:
        self.stdin = None
        self.stdout = None
        self.stderr = None
        self.returncode = 0

def test_build_cli_arguments_maps_supported_options() -> None:
    args = build_cli_arguments(
        QueryOptions(
            model="qwen3-coder",
            system_prompt="system prompt",
            append_system_prompt="append prompt",
            permission_mode="auto-edit",
            max_session_turns=7,
            core_tools=["Read", "Edit"],
            exclude_tools=["Bash(rm *)"],
            allowed_tools=["Bash(git status)"],
            auth_type="openai",
            include_partial_messages=True,
            session_id=VALID_UUID,
        )
    )

    assert args == [
        "--input-format",
        "stream-json",
        "--output-format",
        "stream-json",
        "--channel=SDK",
        "--model",
        "qwen3-coder",
        "--system-prompt",
        "system prompt",
        "--append-system-prompt",
        "append prompt",
        "--approval-mode",
        "auto-edit",
        "--max-session-turns",
        "7",
        "--core-tools",
        "Read,Edit",
        "--exclude-tools",
        "Bash(rm *)",
        "--allowed-tools",
        "Bash(git status)",
        "--auth-type",
        "openai",
        "--include-partial-messages",
        "--session-id",
        VALID_UUID,
    ]


def test_cli_argument_precedence_prefers_resume_then_continue_then_session_id() -> None:
    args = build_cli_arguments(
        QueryOptions(
            resume=VALID_UUID,
            continue_session=True,
            session_id="223e4567-e89b-12d3-a456-426614174000",
        )
    )

    assert "--resume" in args
    assert "--continue" not in args
    assert "--session-id" not in args

def test_prepare_spawn_info_uses_runtime_for_python_scripts(tmp_path: Path) -> None:
    script_path = tmp_path / "fake-qwen.py"
    script_path.write_text("print('ok')\n", encoding="utf-8")

    spawn_info = prepare_spawn_info(str(script_path))

    assert spawn_info.command == sys.executable
    assert spawn_info.args == [str(script_path.resolve())]


def test_prepare_spawn_info_uses_node_for_javascript_files(tmp_path: Path) -> None:
    script_path = tmp_path / "fake-qwen.js"
    script_path.write_text("console.log('ok');\n", encoding="utf-8")

    spawn_info = prepare_spawn_info(str(script_path))

    assert spawn_info.command == "node"
    assert spawn_info.args == [str(script_path.resolve())]


def test_prepare_spawn_info_keeps_plain_command_names() -> None:
    spawn_info = prepare_spawn_info("qwen-custom")

    assert spawn_info.command == "qwen-custom"
    assert spawn_info.args == []

@pytest.mark.asyncio
async def test_transport_discards_stderr_when_debug_is_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    captured: dict[str, Any] = {}

    async def fake_create_subprocess_exec(*args: Any, **kwargs: Any) -> DummyProcess:
        captured["args"] = args
        captured["kwargs"] = kwargs
        return DummyProcess()

    monkeypatch.setattr(
        asyncio,
        "create_subprocess_exec",
        fake_create_subprocess_exec,
    )

    transport_module = __import__(
        "qwen_code_sdk.transport",
        fromlist=["ProcessTransport"],
    )
    transport = transport_module.ProcessTransport(
        QueryOptions(timeout=TimeoutOptions())
    )

    await transport.start()

    assert captured["kwargs"]["stderr"] is subprocess.DEVNULL


def test_prepare_spawn_info_defaults_to_qwen_when_none() -> None:
    spawn_info = prepare_spawn_info(None)

    assert spawn_info.command == "qwen"
    assert spawn_info.args == []


def test_prepare_spawn_info_uses_node_for_mjs_files(tmp_path: Path) -> None:
    script_path = tmp_path / "cli.mjs"
    script_path.write_text("export default {};\n", encoding="utf-8")

    spawn_info = prepare_spawn_info(str(script_path))

    assert spawn_info.command == "node"
    assert spawn_info.args == [str(script_path.resolve())]


def test_prepare_spawn_info_uses_node_for_cjs_files(tmp_path: Path) -> None:
    script_path = tmp_path / "cli.cjs"
    script_path.write_text("module.exports = {};\n", encoding="utf-8")

    spawn_info = prepare_spawn_info(str(script_path))

    assert spawn_info.command == "node"
    assert spawn_info.args == [str(script_path.resolve())]

@pytest.mark.asyncio
async def test_transport_start_raises_after_close(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    async def fake_create_subprocess_exec(*args: Any, **kwargs: Any) -> DummyProcess:
        return DummyProcess()

    monkeypatch.setattr(
        asyncio,
        "create_subprocess_exec",
        fake_create_subprocess_exec,
    )

    transport_module = __import__(
        "qwen_code_sdk.transport",
        fromlist=["ProcessTransport"],
    )
    transport = transport_module.ProcessTransport(
        QueryOptions(timeout=TimeoutOptions())
    )
    transport._closed = True

    with pytest.raises(RuntimeError, match="Transport is closed"):
        await transport.start()


@pytest.mark.asyncio
async def test_read_messages_skips_malformed_json_lines() -> None:
    """Malformed JSON lines should be skipped, not crash the stream."""

    class FakeStdout:
        def __init__(self, lines: list[bytes]) -> None:
            self._lines = iter(lines)

        async def readline(self) -> bytes:
            return next(self._lines, b"")

    transport_module = __import__(
        "qwen_code_sdk.transport",
        fromlist=["ProcessTransport"],
    )
    transport = transport_module.ProcessTransport(
        QueryOptions(timeout=TimeoutOptions())
    )

    class FakeProcess:
        returncode = 0
        stdin = None
        stderr = None

        def __init__(self) -> None:
            self.stdout = FakeStdout(
                [
                    b"not valid json\n",
                    b'{"type":"system","subtype":"init","uuid":"u","session_id":"s"}\n',
                    b"also bad\n",
                    b"",
                ]
            )

        async def wait(self) -> int:
            return 0

    transport._process = FakeProcess()

    messages: list[Any] = []
    async for msg in transport.read_messages():
        messages.append(msg)

    assert len(messages) == 1
    assert messages[0]["type"] == "system"

@pytest.mark.asyncio
async def test_stderr_callback_exceptions_do_not_fail_transport() -> None:
    class FakeStdout:
        async def readline(self) -> bytes:
            return b""

    class FakeStderr:
        def __init__(self) -> None:
            self._lines = iter([b"error message\n", b""])

        async def readline(self) -> bytes:
            return next(self._lines, b"")

    transport_module = __import__(
        "qwen_code_sdk.transport",
        fromlist=["ProcessTransport"],
    )

    callback_calls = 0

    def stderr_callback(text: str) -> None:
        nonlocal callback_calls
        callback_calls += 1
        assert text == "error message"
        raise RuntimeError("sink failed")

    transport = transport_module.ProcessTransport(
        QueryOptions(
            stderr=stderr_callback,
            timeout=TimeoutOptions(),
        )
    )

    class FakeProcess:
        returncode = 0
        stdin = None

        def __init__(self) -> None:
            self.stdout = FakeStdout()
            self.stderr = FakeStderr()

        async def wait(self) -> int:
            return 0

    transport._process = FakeProcess()
    transport._stderr_task = asyncio.create_task(transport._forward_stderr())

    await transport.wait_for_exit()

    assert callback_calls == 1
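The `prepare_spawn_info` tests above collectively pin down a runtime-dispatch contract: `.py` scripts run under the current Python interpreter, `.js`/`.mjs`/`.cjs` files under `node`, a bare command name passes through unchanged, and `None` falls back to the `qwen` CLI. The sketch below illustrates that contract only — `pick_runtime` is a hypothetical name, not the SDK's internal implementation.

```python
from __future__ import annotations

import sys
from pathlib import Path

_JS_SUFFIXES = {".js", ".mjs", ".cjs"}  # JS flavours the tests cover


def pick_runtime(target: str | None) -> tuple[str, list[str]]:
    """Return (command, args) for spawning, dispatching on the file suffix."""
    if target is None:
        return "qwen", []  # default CLI entry point
    path = Path(target)
    if path.suffix == ".py":
        # Reuse the interpreter running the SDK, not whatever "python" is on PATH.
        return sys.executable, [str(path.resolve())]
    if path.suffix in _JS_SUFFIXES:
        return "node", [str(path.resolve())]
    return target, []  # plain command name: spawn as-is


command, args = pick_runtime("cli.mjs")
print(command)  # → node
```

Resolving the script to an absolute path before spawning (as the tests assert via `script_path.resolve()`) keeps the child process independent of the parent's working directory.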

174  packages/sdk-python/tests/unit/test_validation.py  Normal file
@@ -0,0 +1,174 @@
from __future__ import annotations

from typing import Any, cast

import pytest
from qwen_code_sdk.errors import ValidationError
from qwen_code_sdk.types import QueryOptions, TimeoutOptions
from qwen_code_sdk.validation import validate_query_options

VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"

def test_rejects_resume_with_continue_session() -> None:
    with pytest.raises(ValidationError, match="resume together with continue_session"):
        validate_query_options(
            QueryOptions(
                resume=VALID_UUID,
                continue_session=True,
            )
        )


def test_rejects_session_id_with_resume() -> None:
    with pytest.raises(ValidationError, match="Cannot use session_id with resume"):
        validate_query_options(
            QueryOptions(
                session_id=VALID_UUID,
                resume="223e4567-e89b-12d3-a456-426614174000",
            )
        )


def test_rejects_invalid_session_id() -> None:
    with pytest.raises(ValidationError, match="Invalid session_id"):
        validate_query_options(QueryOptions(session_id="not-a-uuid"))


def test_rejects_invalid_resume() -> None:
    with pytest.raises(ValidationError, match="Invalid resume"):
        validate_query_options(QueryOptions(resume="not-a-uuid"))


def test_rejects_invalid_permission_mode() -> None:
    with pytest.raises(ValidationError, match="Invalid permission_mode"):
        validate_query_options(
            QueryOptions.from_mapping({"permission_mode": "unsafe-mode"})
        )


def test_rejects_invalid_auth_type() -> None:
    with pytest.raises(ValidationError, match="Invalid auth_type"):
        validate_query_options(QueryOptions.from_mapping({"auth_type": "custom"}))

def test_from_mapping_rejects_non_callable_can_use_tool() -> None:
    with pytest.raises(TypeError, match="can_use_tool must be callable"):
        QueryOptions.from_mapping({"can_use_tool": "bad"})


def test_from_mapping_rejects_non_callable_stderr() -> None:
    with pytest.raises(TypeError, match="stderr must be callable"):
        QueryOptions.from_mapping({"stderr": "bad"})


def test_validation_rejects_non_callable_can_use_tool() -> None:
    with pytest.raises(ValidationError, match="can_use_tool must be callable"):
        validate_query_options(QueryOptions(can_use_tool=cast(Any, "bad")))


def test_validation_rejects_non_callable_stderr() -> None:
    with pytest.raises(ValidationError, match="stderr must be callable"):
        validate_query_options(QueryOptions(stderr=cast(Any, "bad")))

def test_from_mapping_rejects_sync_can_use_tool() -> None:
    def can_use_tool(  # type: ignore[no-untyped-def]
        tool_name, tool_input, context
    ):
        return {"behavior": "deny", "message": "bad"}

    with pytest.raises(TypeError, match="can_use_tool must be an async callable"):
        QueryOptions.from_mapping({"can_use_tool": can_use_tool})


def test_validation_rejects_sync_can_use_tool() -> None:
    def can_use_tool(  # type: ignore[no-untyped-def]
        tool_name, tool_input, context
    ):
        return {"behavior": "deny", "message": "bad"}

    with pytest.raises(ValidationError, match="can_use_tool must be an async callable"):
        validate_query_options(QueryOptions(can_use_tool=cast(Any, can_use_tool)))

def test_from_mapping_rejects_can_use_tool_with_wrong_arity() -> None:
    async def can_use_tool(
        tool_name: str,
        tool_input: dict[str, Any],
    ) -> dict[str, str]:
        return {"behavior": "deny"}

    with pytest.raises(
        TypeError, match="can_use_tool must accept exactly 3 positional arguments"
    ):
        QueryOptions.from_mapping({"can_use_tool": can_use_tool})


def test_validation_rejects_can_use_tool_with_wrong_arity() -> None:
    async def can_use_tool(
        tool_name: str,
        tool_input: dict[str, Any],
    ) -> dict[str, str]:
        return {"behavior": "deny"}

    with pytest.raises(
        ValidationError,
        match="can_use_tool must accept exactly 3 positional arguments",
    ):
        validate_query_options(QueryOptions(can_use_tool=cast(Any, can_use_tool)))

def test_from_mapping_rejects_stderr_with_wrong_arity() -> None:
    def stderr() -> None:
        return None

    with pytest.raises(
        TypeError, match="stderr must accept exactly 1 positional argument"
    ):
        QueryOptions.from_mapping({"stderr": stderr})


def test_validation_rejects_stderr_with_wrong_arity() -> None:
    def stderr() -> None:
        return None

    with pytest.raises(
        ValidationError, match="stderr must accept exactly 1 positional argument"
    ):
        validate_query_options(QueryOptions(stderr=cast(Any, stderr)))

def test_rejects_invalid_max_session_turns() -> None:
    with pytest.raises(ValidationError, match="max_session_turns"):
        validate_query_options(QueryOptions(max_session_turns=-2))


def test_rejects_empty_qwen_executable_path() -> None:
    with pytest.raises(
        ValidationError, match="path_to_qwen_executable cannot be empty"
    ):
        validate_query_options(QueryOptions(path_to_qwen_executable=" "))

def test_timeout_rejects_non_numeric_value() -> None:
    with pytest.raises(TypeError, match=r"timeout\.can_use_tool must be a positive"):
        TimeoutOptions.from_mapping({"can_use_tool": "fast"})


def test_timeout_rejects_negative_value() -> None:
    pattern = r"timeout\.control_request must be a positive"
    with pytest.raises(ValueError, match=pattern):
        TimeoutOptions.from_mapping({"control_request": -1})


def test_timeout_rejects_boolean_value() -> None:
    with pytest.raises(TypeError, match=r"timeout\.stream_close must be a positive"):
        TimeoutOptions.from_mapping({"stream_close": True})


def test_rejects_mcp_servers() -> None:
    with pytest.raises(ValidationError, match="mcp_servers is not supported"):
        validate_query_options(
            QueryOptions(mcp_servers={"my-server": {"command": "node", "args": []}})
        )
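The callable-validation tests above encode a three-part contract for `can_use_tool`: it must be callable, it must be an async function, and it must take exactly 3 positional arguments. A hedged sketch of how such a check can be written (the helper name `check_can_use_tool` is hypothetical, not the SDK's actual validator):

```python
import asyncio
import inspect


def check_can_use_tool(candidate) -> None:
    """Raise TypeError unless candidate is an async callable of 3 positional args."""
    if not callable(candidate):
        raise TypeError("can_use_tool must be callable")
    if not asyncio.iscoroutinefunction(candidate):
        raise TypeError("can_use_tool must be an async callable")
    # Count parameters that can be filled positionally.
    positional = [
        p
        for p in inspect.signature(candidate).parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    if len(positional) != 3:
        raise TypeError("can_use_tool must accept exactly 3 positional arguments")


async def ok(tool_name, tool_input, context):
    return {"behavior": "deny", "message": "no"}


check_can_use_tool(ok)  # passes silently
```

Ordering matters: the callable check must come first, since `inspect.signature` itself raises on non-callables, and `asyncio.iscoroutinefunction` distinguishes the sync-function case before arity is inspected.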