diff --git a/.github/workflows/sdk-python.yml b/.github/workflows/sdk-python.yml new file mode 100644 index 000000000..43e76dedb --- /dev/null +++ b/.github/workflows/sdk-python.yml @@ -0,0 +1,59 @@ +name: 'SDK Python' + +on: + pull_request: + branches: + - 'main' + - 'release/**' + paths: + - 'packages/sdk-python/**' + - 'docs/developers/sdk-python.md' + - 'docs/developers/_meta.ts' + - 'README.md' + - 'package.json' + - '.github/workflows/sdk-python.yml' + push: + branches: + - 'main' + - 'release/**' + paths: + - 'packages/sdk-python/**' + - 'docs/developers/sdk-python.md' + - 'docs/developers/_meta.ts' + - 'README.md' + - 'package.json' + - '.github/workflows/sdk-python.yml' + +jobs: + sdk-python: + name: 'SDK Python (${{ matrix.python-version }})' + runs-on: 'ubuntu-latest' + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + steps: + - name: 'Checkout' + uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' # ratchet:actions/checkout@v5 + + - name: 'Set up Python' + uses: 'actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065' # ratchet:actions/setup-python@v5 + with: + python-version: '${{ matrix.python-version }}' + + - name: 'Install SDK test dependencies' + run: | + python -m pip install --upgrade pip + python -m pip install -e 'packages/sdk-python[dev]' + + - name: 'Run Ruff' + run: 'python -m ruff check --config packages/sdk-python/pyproject.toml packages/sdk-python' + + - name: 'Run Ruff Format' + run: 'python -m ruff format --check --config packages/sdk-python/pyproject.toml packages/sdk-python' + + - name: 'Run Mypy' + run: 'python -m mypy --config-file packages/sdk-python/pyproject.toml packages/sdk-python/src' + + - name: 'Run Pytest' + run: 'python -m pytest -c packages/sdk-python/pyproject.toml packages/sdk-python/tests -q' diff --git a/README.md b/README.md index b56c72120..5358f9d0e 100644 --- a/README.md +++ b/README.md @@ -424,7 +424,7 @@ As an open-source terminal agent, you can use Qwen Code in four primary ways: 1. Interactive mode (terminal UI) 2. Headless mode (scripts, CI) 3. IDE integration (VS Code, Zed) -4. TypeScript SDK +4. SDKs (TypeScript, Python, Java) #### Interactive mode @@ -452,11 +452,38 @@ Use Qwen Code inside your editor (VS Code, Zed, and JetBrains IDEs): - [Use in Zed](https://qwenlm.github.io/qwen-code-docs/en/users/integration-zed/) - [Use in JetBrains IDEs](https://qwenlm.github.io/qwen-code-docs/en/users/integration-jetbrains/) -#### TypeScript SDK +#### SDKs -Build on top of Qwen Code with the TypeScript SDK: +Build on top of Qwen Code with the available SDKs: -- [Use the Qwen Code SDK](./packages/sdk-typescript/README.md) +- TypeScript: [Use the Qwen Code SDK](./packages/sdk-typescript/README.md) +- Python: [Use the Python SDK](./packages/sdk-python/README.md) +- Java: [Use the Java SDK](./packages/sdk-java/qwencode/README.md) + +Python SDK example: + +```python +import asyncio + +from qwen_code_sdk import is_sdk_result_message, query + + +async def main() -> None: + result = query( + "Summarize the repository layout.", + { + "cwd": "/path/to/project", + "path_to_qwen_executable": "qwen", + }, + ) + + async for message in result: + if is_sdk_result_message(message): + print(message["result"]) + + +asyncio.run(main()) +``` ## Commands & Shortcuts diff --git a/docs/developers/_meta.ts b/docs/developers/_meta.ts index 42c938c9f..ed6597257 100644 --- a/docs/developers/_meta.ts +++ b/docs/developers/_meta.ts @@ -11,7 +11,8 @@ export default { type: 'separator', }, 'sdk-typescript': 'Typescript SDK', - 'sdk-java': 'Java SDK(alpha)', + 'sdk-python': 'Python SDK (alpha)', + 'sdk-java': 'Java SDK (alpha)', 'Dive Into Qwen Code': { title: 'Dive Into Qwen Code', type: 'separator', diff --git a/docs/developers/sdk-python.md b/docs/developers/sdk-python.md new file mode 100644 index 000000000..7ef2c0ab6 --- /dev/null +++ b/docs/developers/sdk-python.md @@ -0,0 +1,168 @@ +# Python SDK + +## `qwen-code-sdk` + +`qwen-code-sdk` is an experimental Python SDK for Qwen Code. v1 targets the +existing `stream-json` CLI protocol and keeps the transport surface small and +testable. + +## Scope + +- Package name: `qwen-code-sdk` +- Import path: `qwen_code_sdk` +- Runtime requirement: Python `>=3.10` +- CLI dependency: external `qwen` executable is required in v1 +- Transport scope: process transport only +- Not included in v1: ACP transport, SDK-embedded MCP servers + +## Install + +```bash +pip install qwen-code-sdk +``` + +If `qwen` is not on `PATH`, pass `path_to_qwen_executable` explicitly. + +## Quick Start + +```python +import asyncio + +from qwen_code_sdk import is_sdk_result_message, query + + +async def main() -> None: + result = query( + "Explain the repository structure.", + { + "cwd": "/path/to/project", + "path_to_qwen_executable": "qwen", + }, + ) + + async for message in result: + if is_sdk_result_message(message): + print(message["result"]) + + +asyncio.run(main()) +``` + +## API Surface + +### Top-level entry points + +- `query(prompt, options=None) -> Query` +- `query_sync(prompt, options=None) -> SyncQuery` + +`prompt` supports either: + +- `str` for single-turn requests +- `AsyncIterable[SDKUserMessage]` for multi-turn streams + +### `Query` + +- Async iterable over SDK messages +- `close()` +- `interrupt()` +- `set_model(model)` +- `set_permission_mode(mode)` +- `supported_commands()` +- `mcp_server_status()` +- `get_session_id()` +- `is_closed()` + +### `QueryOptions` + +Supported options in v1: + +- `cwd` +- `model` +- `path_to_qwen_executable` +- `permission_mode` +- `can_use_tool` +- `env` +- `system_prompt` +- `append_system_prompt` +- `debug` +- `max_session_turns` +- `core_tools` +- `exclude_tools` +- `allowed_tools` +- `auth_type` +- `include_partial_messages` +- `resume` +- `continue_session` +- `session_id` +- `timeout` +- `mcp_servers` +- `stderr` + +Session argument priority is fixed as: + +1. `resume` +2. `continue_session` +3. `session_id` + +## Permission Handling + +When the CLI emits a `can_use_tool` control request, the SDK routes it through +`can_use_tool(tool_name, tool_input, context)`. + +- Default behavior: deny +- Default timeout: 60 seconds +- Timeout fallback: deny +- Callback exceptions: converted to deny with an error message +- Callback context: `cancel_event`, `suggestions`, and `blocked_path` +- Callback contract: `can_use_tool` must be async with 3 positional arguments; + `stderr` must accept 1 positional string argument + +## Error Model + +- `ValidationError`: invalid options, invalid UUIDs, unsupported combinations +- `ControlRequestTimeoutError`: initialize, interrupt, or other control request + timed out +- `ProcessExitError`: CLI exited non-zero +- `AbortError`: control request or session was cancelled + +## Troubleshooting + +If the SDK cannot start the CLI: + +- Verify `qwen --version` works in the target environment +- Pass `path_to_qwen_executable` if your shell uses `nvm`, `pyenv`, or other + non-standard PATH setup +- Use `debug=True` or `stderr=print` to surface CLI stderr while debugging + +If session control calls time out: + +- Check that the target `qwen` version supports `--input-format stream-json` +- Increase `timeout.control_request` +- Verify that no wrapper script is swallowing stdout/stderr + +## Repository Integration + +Repository-level helper commands: + +- `npm run test:sdk:python` +- `npm run lint:sdk:python` +- `npm run typecheck:sdk:python` +- `npm run smoke:sdk:python -- --qwen qwen` + +## Real E2E Smoke + +For a real runtime check (actual `qwen` process + real model call), run from +the repository root. The npm helper uses `python3`, so ensure it resolves to a +Python `>=3.10` interpreter: + +```bash +npm run smoke:sdk:python -- --qwen qwen +``` + +This script runs: + +- async single-turn query +- async control flow (`supported_commands`, permission mode updates) +- sync `query_sync` query + +It prints JSON and returns non-zero on failure. diff --git a/package.json b/package.json index 5611c59ea..507ea5da2 100644 --- a/package.json +++ b/package.json @@ -43,6 +43,7 @@ "test:integration:sandbox:podman": "cross-env QWEN_SANDBOX=podman vitest run --root ./integration-tests", "test:integration:sdk:sandbox:none": "cross-env QWEN_SANDBOX=false vitest run --root ./integration-tests --poolOptions.threads.maxThreads 2 sdk-typescript", "test:integration:sdk:sandbox:docker": "cross-env QWEN_SANDBOX=docker npm run build:sandbox && QWEN_SANDBOX=docker vitest run --root ./integration-tests --poolOptions.threads.maxThreads 2 sdk-typescript", + "test:sdk:python": "python3 -m pytest -c packages/sdk-python/pyproject.toml packages/sdk-python/tests -q", "test:integration:cli:sandbox:none": "cross-env QWEN_SANDBOX=false vitest run --root ./integration-tests cli", "test:integration:cli:sandbox:docker": "cross-env QWEN_SANDBOX=docker npm run build:sandbox && QWEN_SANDBOX=docker vitest run --root ./integration-tests cli", "test:integration:interactive:sandbox:none": "cross-env QWEN_SANDBOX=false vitest run --root ./integration-tests interactive", @@ -53,9 +54,12 @@ "lint": "eslint . --ext .ts,.tsx && eslint integration-tests", "lint:fix": "eslint . --fix && eslint integration-tests --fix", "lint:ci": "eslint . --ext .ts,.tsx --max-warnings 0 && eslint integration-tests --max-warnings 0", + "lint:sdk:python": "python3 -m ruff check --config packages/sdk-python/pyproject.toml packages/sdk-python", "lint:all": "node scripts/lint.js", "format": "prettier --experimental-cli --write .", "typecheck": "npm run typecheck --workspaces --if-present", + "typecheck:sdk:python": "python3 -m mypy --config-file packages/sdk-python/pyproject.toml packages/sdk-python/src", + "smoke:sdk:python": "python3 packages/sdk-python/scripts/smoke_real.py", "check-i18n": "npm run check-i18n --workspace=packages/cli", "preflight": "npm run clean && npm ci && npm run format && npm run lint:ci && npm run build && npm run typecheck && npm run test:ci", "prepare": "husky && npm run build && npm run bundle", diff --git a/packages/cli/src/ui/components/SettingsDialog.test.tsx b/packages/cli/src/ui/components/SettingsDialog.test.tsx index c72750737..c1b75f755 100644 --- a/packages/cli/src/ui/components/SettingsDialog.test.tsx +++ b/packages/cli/src/ui/components/SettingsDialog.test.tsx @@ -963,11 +963,25 @@ describe('SettingsDialog', () => { , ); - // Trigger a restart-required setting change: navigate to "Language: UI" (2nd item) and toggle it. - stdin.write(TerminalKeys.DOWN_ARROW as string); - await wait(); - stdin.write(TerminalKeys.ENTER as string); - await wait(); + await waitFor(() => { + expect(lastFrame()).toContain('Tool Approval Mode'); + }); + + const languageIndex = getDialogSettingKeys().indexOf('general.language'); + expect(languageIndex).toBeGreaterThanOrEqual(0); + + const press = async (key: string) => { + act(() => { + stdin.write(key); + }); + await wait(); + }; + + // Trigger a restart-required setting change by toggling the UI language setting. + for (let i = 0; i < languageIndex; i++) { + await press(TerminalKeys.DOWN_ARROW as string); + } + await press(TerminalKeys.ENTER as string); await waitFor(() => { expect(lastFrame()).toContain( @@ -976,10 +990,8 @@ describe('SettingsDialog', () => { }); // Switch scopes; restart prompt should remain visible. - stdin.write(TerminalKeys.TAB as string); - await wait(); - stdin.write('2'); - await wait(); + await press(TerminalKeys.TAB as string); + await press('2'); await waitFor(() => { expect(lastFrame()).toContain( diff --git a/packages/sdk-python/README.md b/packages/sdk-python/README.md new file mode 100644 index 000000000..67cf3ae34 --- /dev/null +++ b/packages/sdk-python/README.md @@ -0,0 +1,117 @@ +# qwen-code-sdk + +Experimental Python SDK for programmatic access to Qwen Code through the +`stream-json` protocol. + +## Installation + +```bash +pip install qwen-code-sdk +``` + +## Requirements + +- Python `>=3.10` +- External `qwen` CLI installed and available in `PATH` + +You can also point the SDK at an explicit CLI binary or script with +`path_to_qwen_executable`. + +## Quick Start + +```python +import asyncio + +from qwen_code_sdk import is_sdk_result_message, query + + +async def main() -> None: + result = query( + "List the top-level packages in this repository.", + { + "cwd": "/path/to/project", + "path_to_qwen_executable": "qwen", + }, + ) + + async for message in result: + if is_sdk_result_message(message): + print(message["result"]) + + +asyncio.run(main()) +``` + +## Sync API + +```python +from qwen_code_sdk import query_sync + + +with query_sync( + "Say hello", + { + "path_to_qwen_executable": "qwen", + }, +) as result: + for message in result: + print(message) +``` + +## Main APIs + +- `query(prompt, options=None) -> Query` +- `query_sync(prompt, options=None) -> SyncQuery` +- `Query.close()`, `interrupt()`, `set_model()`, `set_permission_mode()` +- `Query.supported_commands()`, `mcp_server_status()`, `get_session_id()` + +`prompt` accepts either a single `str` or an `AsyncIterable[SDKUserMessage]` +for multi-turn sessions. + +## Permission Callback + +```python +from qwen_code_sdk import query + + +async def can_use_tool(tool_name, tool_input, context): + if tool_name == "write_file": + return {"behavior": "deny", "message": "Writes disabled in this app"} + return {"behavior": "allow", "updatedInput": tool_input} + + +result = query( + "Create hello.txt", + { + "path_to_qwen_executable": "qwen", + "can_use_tool": can_use_tool, + }, +) +``` + +The callback defaults to deny. If it does not return within 60 seconds, the SDK +auto-denies the tool request. + +The `context` argument includes `cancel_event`, `suggestions`, and +`blocked_path` when the CLI provides a path-specific permission target. +`can_use_tool` must be an `async def` callback accepting +`(tool_name, tool_input, context)`. `stderr` must accept a single `str`. + +## Errors + +- `ValidationError`: invalid query options or malformed session identifiers +- `ControlRequestTimeoutError`: CLI control operation exceeded timeout +- `ProcessExitError`: `qwen` exited with a non-zero code +- `AbortError`: query or control request was cancelled + +## Current Scope + +`0.1.x` is intentionally narrow: + +- Uses external `qwen` CLI via process transport +- Targets `stream-json` parity with the TypeScript SDK core flow +- Does not yet implement ACP transport +- Does not yet embed MCP servers inside the SDK process + +See [developer documentation](../../docs/developers/sdk-python.md) for more +detail. diff --git a/packages/sdk-python/pyproject.toml b/packages/sdk-python/pyproject.toml new file mode 100644 index 000000000..10745ef69 --- /dev/null +++ b/packages/sdk-python/pyproject.toml @@ -0,0 +1,61 @@ +[build-system] +requires = ["hatchling>=1.27.0"] +build-backend = "hatchling.build" + +[project] +name = "qwen-code-sdk" +version = "0.1.0" +description = "Python SDK for programmatic access to qwen-code CLI" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "Apache-2.0" } +authors = [{ name = "Qwen Team" }] +keywords = ["qwen", "qwen-code", "sdk", "python", "ai", "code-assistant"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = ["typing-extensions>=4.12.0"] + +[project.urls] +Homepage = "https://qwenlm.github.io/qwen-code-docs/" +Repository = "https://github.com/QwenLM/qwen-code" +Issues = "https://github.com/QwenLM/qwen-code/issues" + +[project.optional-dependencies] +dev = [ + "mypy>=1.11.0", + "pytest>=8.3.0", + "pytest-asyncio>=0.24.0", + "ruff>=0.8.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/qwen_code_sdk"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["src"] +asyncio_mode = "auto" + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I", "B", "UP", "ASYNC", "RUF"] + +[tool.mypy] +python_version = "3.10" +strict = true +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +show_error_codes = true +mypy_path = "src" diff --git a/packages/sdk-python/scripts/smoke_real.py b/packages/sdk-python/scripts/smoke_real.py new file mode 100644 index 000000000..820b5a21b --- /dev/null +++ b/packages/sdk-python/scripts/smoke_real.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +"""Run real end-to-end smoke tests against qwen CLI using qwen_code_sdk. + +This script is intentionally lightweight and avoids any test doubles. +It is useful for manual verification after changing SDK runtime behavior. +""" + +from __future__ import annotations + +import sys + +if sys.version_info < (3, 10): # noqa: UP036 + import json + + version = ".".join(str(part) for part in sys.version_info[:3]) + payload = { + "ok": False, + "stage": "startup", + "error": f"Python >=3.10 is required, current version is {version}", + "error_type": "RuntimeError", + } + print(json.dumps(payload, ensure_ascii=False, indent=2)) + raise SystemExit(2) + +import argparse +import asyncio +import json +import subprocess +import threading +from collections.abc import AsyncIterator, Awaitable, Callable +from dataclasses import asdict, dataclass +from pathlib import Path +from queue import Empty, Queue +from typing import Any, TypeVar + +SDK_ROOT = Path(__file__).resolve().parents[1] +SRC_ROOT = SDK_ROOT / "src" +if str(SRC_ROOT) not in sys.path: + sys.path.insert(0, str(SRC_ROOT)) + +from qwen_code_sdk import ( # noqa: E402 + SDKUserMessage, + SyncQuery, + is_sdk_assistant_message, + is_sdk_result_message, + is_sdk_system_message, + query, + query_sync, +) +from qwen_code_sdk.transport import prepare_spawn_info # noqa: E402 + +T = TypeVar("T") + + +@dataclass +class AsyncSingleResult: + ok: bool + assistant_text: str | None + result_text: str | None + session_id: str + + +@dataclass +class AsyncControlResult: + ok: bool + supported_commands_type: str + saw_system_message: bool + saw_result_message: bool + session_id: str + + +@dataclass +class SyncResult: + ok: bool + saw_result_message: bool + result_text: str | None + session_id: str + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run real qwen_code_sdk smoke tests using qwen CLI", + ) + parser.add_argument( + "--qwen", + default="qwen", + help="Path or command for qwen executable (default: qwen)", + ) + parser.add_argument( + "--cwd", + default=str(Path.cwd()), + help="Working directory passed to SDK query options", + ) + parser.add_argument( + "--model", + default=None, + help="Optional model name. If set, script will call set_model(model).", + ) + parser.add_argument( + "--timeout-seconds", + type=float, + default=90.0, + help="Timeout used for control/callback/stream-close options", + ) + parser.add_argument( + "--json-only", + action="store_true", + help="Print only JSON result (no progress logs)", + ) + return parser.parse_args() + + +def check_qwen_cli_available(qwen_cmd: str, timeout_seconds: float) -> str: + spawn_info = prepare_spawn_info(qwen_cmd) + completed = subprocess.run( + [spawn_info.command, *spawn_info.args, "--version"], + check=True, + capture_output=True, + text=True, + timeout=timeout_seconds, + ) + return completed.stdout.strip() + + +def build_options(args: argparse.Namespace) -> dict[str, Any]: + return { + "cwd": args.cwd, + "path_to_qwen_executable": args.qwen, + "permission_mode": "yolo", + "max_session_turns": 1, + "timeout": { + "control_request": args.timeout_seconds, + "can_use_tool": args.timeout_seconds, + "stream_close": args.timeout_seconds, + }, + } + + +def extract_assistant_text(message: dict[str, Any]) -> str: + content = message["message"].get("content", []) + if not isinstance(content, list): + return "" + + text_parts: list[str] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text_parts.append(str(block.get("text", ""))) + return "".join(text_parts) + + +async def run_async_single(args: argparse.Namespace) -> AsyncSingleResult: + token = "SDK_REAL_ASYNC_OK" + options = build_options(args) + q = query( + f"Reply exactly with {token}", + options, + ) + + assistant_text: str | None = None + result_text: str | None = None + try: + async for message in q: + if is_sdk_assistant_message(message): + assistant_text = (assistant_text or "") + extract_assistant_text( + message + ) + if is_sdk_result_message(message): + result_text = str(message.get("result", "")) + finally: + await q.close() + + ok = token in (assistant_text or "") and token in (result_text or "") + return AsyncSingleResult( + ok=ok, + assistant_text=assistant_text, + result_text=result_text, + session_id=q.get_session_id(), + ) + + +async def run_async_controls(args: argparse.Namespace) -> AsyncControlResult: + token = "SDK_REAL_CONTROL_OK" + options = build_options(args) + release_prompt = asyncio.Event() + + async def prompts() -> AsyncIterator[SDKUserMessage]: + await release_prompt.wait() + yield { + "type": "user", + "session_id": "00000000-0000-4000-8000-000000000001", + "message": { + "role": "user", + "content": f"Reply exactly with {token}", + }, + "parent_tool_use_id": None, + } + + q = query(prompts(), options) + + supported: dict[str, Any] | None = None + saw_system_message = False + saw_result_message = False + try: + supported = await q.supported_commands() + await q.set_permission_mode("plan") + await q.set_permission_mode("yolo") + if args.model: + await q.set_model(args.model) + + release_prompt.set() + async for message in q: + if is_sdk_system_message(message): + saw_system_message = True + if is_sdk_result_message(message): + saw_result_message = True + break + finally: + await q.close() + + ok = isinstance(supported, dict) and saw_result_message + return AsyncControlResult( + ok=ok, + supported_commands_type=type(supported).__name__, + saw_system_message=saw_system_message, + saw_result_message=saw_result_message, + session_id=q.get_session_id(), + ) + + +async def run_stage(stage: str, coro: Awaitable[T], timeout_seconds: float) -> T: + try: + return await asyncio.wait_for(coro, timeout=timeout_seconds) + except TimeoutError as exc: + message = f"{stage} timed out after {timeout_seconds} seconds" + raise TimeoutError(message) from exc + + +def run_sync( + args: argparse.Namespace, + on_query: Callable[[SyncQuery], None] | None = None, +) -> SyncResult: + token = "SDK_REAL_SYNC_OK" + options = build_options(args) + q = query_sync( + f"Reply exactly with {token}", + options, + ) + if on_query is not None: + on_query(q) + + saw_result_message = False + result_text: str | None = None + try: + for message in q: + if is_sdk_result_message(message): + saw_result_message = True + result_text = str(message.get("result", "")) + break + finally: + q.close() + + ok = saw_result_message and token in (result_text or "") + return SyncResult( + ok=ok, + saw_result_message=saw_result_message, + result_text=result_text, + session_id=q.get_session_id(), + ) + + +def run_sync_with_timeout(args: argparse.Namespace) -> SyncResult: + result_queue: Queue[SyncResult | BaseException] = Queue(maxsize=1) + query_holder: dict[str, SyncQuery] = {} + + def remember_query(q: SyncQuery) -> None: + query_holder["query"] = q + + def worker() -> None: + try: + result_queue.put(run_sync(args, on_query=remember_query)) + except BaseException as exc: + result_queue.put(exc) + + thread = threading.Thread( + target=worker, + name="qwen-sdk-real-smoke-sync", + daemon=True, + ) + thread.start() + + try: + item = result_queue.get(timeout=args.timeout_seconds) + except Empty as exc: + q = query_holder.get("query") + if q is not None: + q.close() + raise TimeoutError( + f"sync check timed out after {args.timeout_seconds} seconds" + ) from exc + + thread.join(timeout=1.0) + if isinstance(item, BaseException): + raise item + return item + + +def build_failure_payload( + *, + stage: str, + exc: BaseException, + qwen_version: str | None = None, + completed: dict[str, Any] | None = None, +) -> dict[str, Any]: + payload: dict[str, Any] = { + "ok": False, + "stage": stage, + "error": str(exc), + "error_type": type(exc).__name__, + } + if qwen_version is not None: + payload["qwen_version"] = qwen_version + if completed: + payload["completed"] = completed + return payload + + +async def main() -> int: + args = parse_args() + + try: + qwen_version = check_qwen_cli_available(args.qwen, args.timeout_seconds) + except (subprocess.CalledProcessError, OSError, subprocess.TimeoutExpired) as exc: + payload = build_failure_payload(stage="preflight", exc=exc) + print(json.dumps(payload, ensure_ascii=False, indent=2)) + return 2 + + stage = "async single-turn check" + completed: dict[str, Any] = {} + try: + if not args.json_only: + print(f"[smoke] qwen version: {qwen_version}") + print(f"[smoke] running {stage}...") + async_single = await run_stage( + stage, + run_async_single(args), + args.timeout_seconds, + ) + completed["async_single"] = asdict(async_single) + + stage = "async control check" + if not args.json_only: + print(f"[smoke] running {stage}...") + async_controls = await run_stage( + stage, + run_async_controls(args), + args.timeout_seconds, + ) + completed["async_controls"] = asdict(async_controls) + + stage = "sync check" + if not args.json_only: + print(f"[smoke] running {stage}...") + sync_result = run_sync_with_timeout(args) + completed["sync"] = asdict(sync_result) + except Exception as exc: + payload = build_failure_payload( + stage=stage, + exc=exc, + qwen_version=qwen_version, + completed=completed, + ) + print(json.dumps(payload, ensure_ascii=False, indent=2)) + return 1 + + all_ok = async_single.ok and async_controls.ok and sync_result.ok + payload = { + "ok": all_ok, + "qwen_version": qwen_version, + "async_single": asdict(async_single), + "async_controls": asdict(async_controls), + "sync": asdict(sync_result), + } + print(json.dumps(payload, ensure_ascii=False, indent=2)) + return 0 if all_ok else 1 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/packages/sdk-python/src/qwen_code_sdk/__init__.py b/packages/sdk-python/src/qwen_code_sdk/__init__.py new file mode 100644 index 000000000..b2de65ad2 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/__init__.py @@ -0,0 +1,108 @@ +"""qwen_code_sdk package exports.""" + +from __future__ import annotations + +from collections.abc import AsyncIterable, Iterable, Mapping +from typing import Any + +from .errors import ( + AbortError, + ControlRequestTimeoutError, + ProcessExitError, + QwenSDKError, + ValidationError, +) +from .protocol import ( + APIAssistantMessage, + APIUserMessage, + ContentBlock, + SDKAssistantMessage, + SDKMessage, + SDKPartialAssistantMessage, + SDKResultMessage, + SDKSystemMessage, + SDKUserMessage, + TextBlock, + ThinkingBlock, + ToolResultBlock, + ToolUseBlock, + Usage, + is_control_cancel, + is_control_request, + is_control_response, + is_sdk_assistant_message, + is_sdk_partial_assistant_message, + is_sdk_result_message, + is_sdk_system_message, + is_sdk_user_message, +) +from .query import Query, query +from .sync_query import SyncQuery +from .types import ( + AuthType, + CanUseTool, + CanUseToolContext, + PermissionAllowResult, + PermissionDenyResult, + PermissionMode, + PermissionResult, + PermissionSuggestion, + QueryOptions, + QueryOptionsDict, + TimeoutOptions, + TimeoutOptionsDict, +) + + +def query_sync( + prompt: str | Iterable[SDKUserMessage] | AsyncIterable[SDKUserMessage], + options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None, +) -> SyncQuery: + return SyncQuery(prompt=prompt, options=options) + + +__all__ = [ + "APIAssistantMessage", + "APIUserMessage", + "AbortError", + "AuthType", + "CanUseTool", + "CanUseToolContext", + "ContentBlock", + "ControlRequestTimeoutError", + "PermissionAllowResult", + "PermissionDenyResult", + "PermissionMode", + "PermissionResult", + "PermissionSuggestion", + "ProcessExitError", + "Query", + "QueryOptions", + "QueryOptionsDict", + "QwenSDKError", + "SDKAssistantMessage", + "SDKMessage", + "SDKPartialAssistantMessage", + "SDKResultMessage", + "SDKSystemMessage", + "SDKUserMessage", + "SyncQuery", + "TextBlock", + "ThinkingBlock", + "TimeoutOptions", + "TimeoutOptionsDict", + "ToolResultBlock", + "ToolUseBlock", + "Usage", + "ValidationError", + "is_control_cancel", + "is_control_request", + "is_control_response", + "is_sdk_assistant_message", + "is_sdk_partial_assistant_message", + "is_sdk_result_message", + "is_sdk_system_message", + "is_sdk_user_message", + "query", + "query_sync", +] diff --git a/packages/sdk-python/src/qwen_code_sdk/errors.py b/packages/sdk-python/src/qwen_code_sdk/errors.py new file mode 100644 index 000000000..e8837f917 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/errors.py @@ -0,0 +1,27 @@ +"""Error types for qwen_code_sdk.""" + +from __future__ import annotations + + +class QwenSDKError(Exception): + """Base error for all SDK failures.""" + + +class ValidationError(QwenSDKError): + """Raised when query options are invalid.""" + + +class AbortError(QwenSDKError): + """Raised when an operation is aborted by caller or transport.""" + + +class ProcessExitError(QwenSDKError): + """Raised when qwen CLI exits with non-zero status or signal.""" + + def __init__(self, message: str, exit_code: int | None = None) -> None: + super().__init__(message) + self.exit_code = exit_code + + +class ControlRequestTimeoutError(QwenSDKError): + """Raised when a control request times out waiting for response.""" diff --git a/packages/sdk-python/src/qwen_code_sdk/json_lines.py b/packages/sdk-python/src/qwen_code_sdk/json_lines.py new file mode 100644 index 000000000..8b4b33ef7 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/json_lines.py @@ -0,0 +1,14 @@ +"""JSON lines utilities.""" + +from __future__ import annotations + +import json +from typing import Any + + +def serialize_json_line(payload: Any) -> str: + return json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n" + + +def parse_json_line(line: str) -> Any: + return json.loads(line) diff --git a/packages/sdk-python/src/qwen_code_sdk/protocol.py b/packages/sdk-python/src/qwen_code_sdk/protocol.py new file mode 100644 index 000000000..7e5e50b70 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/protocol.py @@ -0,0 +1,357 @@ +"""Protocol message types and helpers for qwen stream-json.""" + +from __future__ import annotations + +from typing import Any, Literal, TypeAlias, TypeGuard + +from typing_extensions import NotRequired, TypedDict + +from .types import PermissionMode, PermissionSuggestion + + +class Annotation(TypedDict): + type: str + value: str + + +class Usage(TypedDict): + input_tokens: int + output_tokens: int + cache_creation_input_tokens: NotRequired[int] + cache_read_input_tokens: NotRequired[int] + total_tokens: NotRequired[int] + + +class ExtendedUsage(Usage, total=False): + server_tool_use: dict[str, int] + service_tier: str + cache_creation: dict[str, int] + + +class CLIPermissionDenial(TypedDict): + tool_name: str + tool_use_id: str + tool_input: Any + + +class TextBlock(TypedDict): + type: Literal["text"] + text: str + annotations: NotRequired[list[Annotation]] + + +class ThinkingBlock(TypedDict): + type: Literal["thinking"] + thinking: str + signature: NotRequired[str] + annotations: NotRequired[list[Annotation]] + + +class ToolUseBlock(TypedDict): + type: Literal["tool_use"] + id: str + name: str + input: Any + annotations: NotRequired[list[Annotation]] + + +class ToolResultBlock(TypedDict): + type: Literal["tool_result"] + tool_use_id: str + content: NotRequired[str | list[ContentBlock]] + is_error: NotRequired[bool] + annotations: NotRequired[list[Annotation]] + + +ContentBlock: TypeAlias = TextBlock | ThinkingBlock | ToolUseBlock | ToolResultBlock + + +class APIUserMessage(TypedDict): + role: Literal["user"] + content: str | list[ContentBlock] + + +class APIAssistantMessage(TypedDict): + role: Literal["assistant"] + content: list[ContentBlock] + id: NotRequired[str] + type: NotRequired[Literal["message"]] + model: NotRequired[str] + stop_reason: NotRequired[str | None] + usage: NotRequired[Usage] + + +class SDKUserMessage(TypedDict): + type: Literal["user"] + session_id: str + message: APIUserMessage + parent_tool_use_id: str | None + uuid: NotRequired[str] + options: NotRequired[dict[str, Any]] + + +class SDKAssistantMessage(TypedDict): + type: Literal["assistant"] + uuid: str + session_id: str + message: APIAssistantMessage + parent_tool_use_id: str | None + + +class MCPServerState(TypedDict): + name: str + status: str + + +class SDKSystemMessage(TypedDict): + type: Literal["system"] + subtype: str + uuid: str + session_id: str + data: NotRequired[Any] + cwd: NotRequired[str] + tools: NotRequired[list[str]] + mcp_servers: NotRequired[list[MCPServerState]] + model: NotRequired[str] + permission_mode: NotRequired[str] + slash_commands: NotRequired[list[str]] + qwen_code_version: NotRequired[str] + output_style: NotRequired[str] + agents: NotRequired[list[str]] + skills: NotRequired[list[str]] + capabilities: NotRequired[dict[str, Any]] + + +class SDKResultMessageSuccess(TypedDict): + type: Literal["result"] + subtype: Literal["success"] + uuid: str + session_id: str + is_error: Literal[False] + duration_ms: int + duration_api_ms: int + num_turns: int + result: str + usage: ExtendedUsage + permission_denials: list[CLIPermissionDenial] + + +class ResultErrorObject(TypedDict): + message: str + type: NotRequired[str] + + +class SDKResultMessageError(TypedDict): + type: Literal["result"] + subtype: Literal["error_max_turns", "error_during_execution"] + uuid: str + session_id: str + is_error: Literal[True] + duration_ms: int + duration_api_ms: int + num_turns: int + usage: ExtendedUsage + permission_denials: list[CLIPermissionDenial] + error: NotRequired[ResultErrorObject] + + +SDKResultMessage: TypeAlias = SDKResultMessageSuccess | SDKResultMessageError + + +class MessageStartStreamEvent(TypedDict): + type: Literal["message_start"] + message: dict[str, Any] + + +class ContentBlockStartEvent(TypedDict): + type: Literal["content_block_start"] + index: int + content_block: ContentBlock + + +class ContentBlockDeltaEvent(TypedDict): + type: Literal["content_block_delta"] + index: int + delta: dict[str, Any] + + +class ContentBlockStopEvent(TypedDict): + type: Literal["content_block_stop"] + index: int + + +class MessageStopStreamEvent(TypedDict): + type: Literal["message_stop"] + + +StreamEvent: TypeAlias = ( + MessageStartStreamEvent + | ContentBlockStartEvent + | ContentBlockDeltaEvent + | ContentBlockStopEvent + | MessageStopStreamEvent +) + + +class SDKPartialAssistantMessage(TypedDict): + type: Literal["stream_event"] + uuid: str + session_id: str + event: StreamEvent + parent_tool_use_id: str | None + + +class CLIControlInterruptRequest(TypedDict): + subtype: Literal["interrupt"] + + +class CLIControlPermissionRequest(TypedDict): + subtype: Literal["can_use_tool"] + tool_name: str + tool_use_id: str + input: Any + permission_suggestions: list[PermissionSuggestion] | None + blocked_path: str | None + + +class CLIControlInitializeRequest(TypedDict): + subtype: Literal["initialize"] + hooks: NotRequired[Any] + mcpServers: NotRequired[dict[str, dict[str, Any]]] + + +class CLIControlSetPermissionModeRequest(TypedDict): + subtype: Literal["set_permission_mode"] + mode: PermissionMode + + +class CLIControlSetModelRequest(TypedDict): + subtype: Literal["set_model"] + model: str + + +class CLIControlMcpStatusRequest(TypedDict): + subtype: Literal["mcp_server_status"] + + +class CLIControlSupportedCommandsRequest(TypedDict): + subtype: Literal["supported_commands"] + + +ControlRequestPayload: TypeAlias = ( + CLIControlInterruptRequest + | CLIControlPermissionRequest + | CLIControlInitializeRequest + | CLIControlSetPermissionModeRequest + | CLIControlSetModelRequest + | CLIControlMcpStatusRequest + | CLIControlSupportedCommandsRequest + | dict[str, Any] +) + + +class CLIControlRequest(TypedDict): + type: Literal["control_request"] + request_id: str + request: ControlRequestPayload + + +class ControlResponseSuccess(TypedDict): + subtype: Literal["success"] + request_id: str + response: Any + + +class ControlResponseError(TypedDict): + subtype: Literal["error"] + request_id: str + error: str | dict[str, Any] + + +class CLIControlResponse(TypedDict): + type: Literal["control_response"] + response: ControlResponseSuccess | ControlResponseError + + +class ControlCancelRequest(TypedDict): + type: Literal["control_cancel_request"] + request_id: NotRequired[str] + + +SDKMessage: TypeAlias = ( + SDKUserMessage + | SDKAssistantMessage + | SDKSystemMessage + | SDKResultMessage + | SDKPartialAssistantMessage +) + + +ControlMessage: TypeAlias = ( + CLIControlRequest | CLIControlResponse | ControlCancelRequest +) + + +def is_sdk_user_message(msg: Any) -> TypeGuard[SDKUserMessage]: + return isinstance(msg, dict) and msg.get("type") == "user" and "message" in msg + + +def is_sdk_assistant_message(msg: Any) -> TypeGuard[SDKAssistantMessage]: + return ( + isinstance(msg, dict) + and msg.get("type") == "assistant" + and "session_id" in msg + and "message" in msg + ) + + +def is_sdk_system_message(msg: Any) -> TypeGuard[SDKSystemMessage]: + return ( + isinstance(msg, dict) + and msg.get("type") == "system" + and "subtype" in msg + and "session_id" in msg + ) + + +def is_sdk_result_message(msg: Any) -> TypeGuard[SDKResultMessage]: + return ( + isinstance(msg, dict) + and msg.get("type") == "result" + and "subtype" in msg + and "session_id" in msg + ) + + +def is_sdk_partial_assistant_message(msg: Any) -> TypeGuard[SDKPartialAssistantMessage]: + return ( + isinstance(msg, dict) + and msg.get("type") == "stream_event" + and "session_id" in msg + and "event" in msg + ) + + +def is_control_request(msg: Any) -> TypeGuard[CLIControlRequest]: + return ( + isinstance(msg, dict) + and msg.get("type") == "control_request" + and "request_id" in msg + and "request" in msg + ) + + +def is_control_response(msg: Any) -> TypeGuard[CLIControlResponse]: + return ( + isinstance(msg, dict) + and msg.get("type") == "control_response" + and "response" in msg + ) + + +def is_control_cancel(msg: Any) -> TypeGuard[ControlCancelRequest]: + return ( + isinstance(msg, dict) + and msg.get("type") == "control_cancel_request" + and "request_id" in msg + ) diff --git a/packages/sdk-python/src/qwen_code_sdk/py.typed b/packages/sdk-python/src/qwen_code_sdk/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/packages/sdk-python/src/qwen_code_sdk/query.py b/packages/sdk-python/src/qwen_code_sdk/query.py new file mode 100644 index 000000000..9bef85781 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/query.py @@ -0,0 +1,607 @@ +"""Async Query implementation for qwen_code_sdk.""" + +from __future__ import annotations + +import asyncio +import contextlib +from collections.abc import AsyncIterable, Mapping, MutableMapping +from dataclasses import dataclass, replace +from types import TracebackType +from typing import Any, cast +from uuid import uuid4 + +from .errors import AbortError, ControlRequestTimeoutError +from .json_lines import serialize_json_line +from .protocol import ( + CLIControlRequest, + CLIControlResponse, + SDKMessage, + SDKUserMessage, + is_control_cancel, + is_control_request, + is_control_response, + is_sdk_assistant_message, + is_sdk_partial_assistant_message, + is_sdk_result_message, + is_sdk_system_message, + is_sdk_user_message, +) +from .transport import ProcessTransport +from .types import ( + CanUseToolContext, + PermissionDenyResult, + QueryOptions, + QueryOptionsDict, +) +from .validation import validate_query_options + +_DONE = object() + + +@dataclass +class _PendingControlRequest: + future: asyncio.Future[dict[str, Any] | None] + cancel_event: asyncio.Event + timeout_handle: asyncio.TimerHandle + + +@dataclass +class _IncomingControlRequest: + task: asyncio.Task[None] + cancel_event: asyncio.Event + + +class Query: + def __init__( + self, + transport: ProcessTransport, + options: QueryOptions, + prompt: str | AsyncIterable[SDKUserMessage], + session_id: str, + ) -> None: + self._transport = transport + self._options = options + self._prompt = prompt + self._single_turn = isinstance(prompt, str) + self._session_id = session_id + self._session_id_locked = bool(options.resume or options.session_id) + + self._message_queue: asyncio.Queue[SDKMessage | Exception | object] = ( + asyncio.Queue() + ) + self._closed = False + self._started = False + self._start_lock = asyncio.Lock() + self._cancel_event = asyncio.Event() + + self._router_task: asyncio.Task[None] | None = None + self._input_task: asyncio.Task[None] | None = None + self._initialize_task: asyncio.Task[None] | None = None + self._first_result_event = asyncio.Event() + self._terminal_event_sent = False + self._exhausted = False + + self._pending_control_requests: dict[str, _PendingControlRequest] = {} + self._incoming_control_requests: dict[str, _IncomingControlRequest] = {} + + async def _ensure_started(self) -> None: + if self._closed: + raise RuntimeError("Query is closed") + if self._started: + return + + async with self._start_lock: + if self._closed: + raise RuntimeError("Query is closed") + if self._started: + return + await self._transport.start() + self._router_task = asyncio.create_task(self._message_router()) + self._initialize_task = asyncio.create_task(self._initialize()) + + if self._single_turn: + self._input_task = asyncio.create_task(self._send_single_turn_prompt()) + else: + self._input_task = asyncio.create_task( + self.stream_input(self._prompt) # type: ignore[arg-type] + ) + self._started = True + + async def _initialize(self) -> None: + try: + payload: dict[str, Any] = {"hooks": None} + await self._send_control_request("initialize", payload) + except Exception as exc: + await self._finish_with_error(exc) + + async def _send_single_turn_prompt(self) -> None: + try: + assert isinstance(self._prompt, str) + await self._wait_initialized() + message: SDKUserMessage = { + "type": "user", + "session_id": self._session_id, + "message": { + "role": "user", + "content": self._prompt, + }, + "parent_tool_use_id": None, + } + + await self._write_payload(message) + except Exception as exc: + await self._finish_with_error(exc) + raise + + async def _wait_initialized(self) -> None: + if self._initialize_task is None: + return + await self._initialize_task + + async def _message_router(self) -> None: + try: + async for message in self._transport.read_messages(): + await self._route_message(message) + if self._closed: + break + + if self._closed: + return + + if self._transport.exit_error is not None: + await self._finish_with_error(self._transport.exit_error) + return + + await self._finish() + except Exception as exc: # pragma: no cover - critical propagation path + await self._finish_with_error(exc) + + async def _route_message(self, message: Any) -> None: + self._maybe_update_session_id(message) + + if is_control_request(message): + self._start_incoming_control_request(message) + return + + if is_control_response(message): + self._handle_control_response(message) + return + + if is_control_cancel(message): + self._handle_control_cancel_request(message) + return + + if is_sdk_result_message(message): + self._first_result_event.set() + if self._single_turn: + self._transport.end_input() + await self._message_queue.put(message) + return + + if ( + is_sdk_system_message(message) + or is_sdk_assistant_message(message) + or is_sdk_user_message(message) + or is_sdk_partial_assistant_message(message) + ): + await self._message_queue.put(message) + return + + def _maybe_update_session_id(self, message: Any) -> None: + if self._session_id_locked or not isinstance(message, Mapping): + return + + session_id = message.get("session_id") + if isinstance(session_id, str) and session_id: + self._session_id = session_id + self._session_id_locked = True + + def _start_incoming_control_request(self, request: CLIControlRequest) -> None: + request_id = request["request_id"] + cancel_event = asyncio.Event() + + async def runner() -> None: + try: + await self._handle_control_request(request, cancel_event) + except asyncio.CancelledError: + pass + except Exception as exc: # pragma: no cover - fatal background path + await self._finish_with_error(exc) + finally: + self._incoming_control_requests.pop(request_id, None) + + task = asyncio.create_task(runner()) + self._incoming_control_requests[request_id] = _IncomingControlRequest( + task=task, + cancel_event=cancel_event, + ) + + async def _handle_control_request( + self, + request: CLIControlRequest, + cancel_event: asyncio.Event, + ) -> None: + request_id = request["request_id"] + payload = request["request"] + subtype = payload.get("subtype") + + try: + if subtype == "can_use_tool": + response = await self._handle_permission_request( + cast(MutableMapping[str, Any], payload), + cancel_event, + ) + elif subtype == "mcp_message": + raise RuntimeError("mcp_message is unsupported in python sdk v1") + else: + raise RuntimeError(f"Unknown control request subtype: {subtype}") + + if cancel_event.is_set(): + return + + await self._send_control_response( + request_id, success=True, response=response + ) + except Exception as exc: + if cancel_event.is_set(): + return + await self._send_control_response( + request_id, + success=False, + response=str(exc), + ) + + async def _handle_permission_request( + self, + payload: MutableMapping[str, Any], + cancel_event: asyncio.Event, + ) -> dict[str, Any]: + tool_name = str(payload.get("tool_name", "")) + tool_input = payload.get("input") + if not isinstance(tool_input, dict): + tool_input = {} + + if self._options.can_use_tool is None: + return {"behavior": "deny", "message": "Denied"} + + context: CanUseToolContext = { + "cancel_event": cancel_event, + "suggestions": payload.get("permission_suggestions"), + "blocked_path": payload.get("blocked_path"), + } + + try: + result = await asyncio.wait_for( + self._options.can_use_tool(tool_name, tool_input, context), + timeout=self._options.timeout.can_use_tool, + ) + except asyncio.TimeoutError: + return { + "behavior": "deny", + "message": "Permission request timed out", + } + except asyncio.CancelledError: + if cancel_event.is_set(): + raise + return { + "behavior": "deny", + "message": "Permission check failed: callback cancelled", + } + except Exception as exc: + return { + "behavior": "deny", + "message": f"Permission check failed: {exc}", + } + + behavior = result.get("behavior") + if behavior == "allow": + return { + "behavior": "allow", + "updatedInput": result.get("updatedInput", tool_input), + } + + deny_result = cast(PermissionDenyResult, result) + return { + "behavior": "deny", + "message": deny_result.get("message", "Denied"), + **( + {"interrupt": deny_result["interrupt"]} + if "interrupt" in deny_result + else {} + ), + } + + def _handle_control_response(self, response: CLIControlResponse) -> None: + payload = response["response"] + request_id = payload["request_id"] + + pending = self._pending_control_requests.pop(request_id, None) + if pending is None: + return + + pending.timeout_handle.cancel() + + if payload["subtype"] == "success": + if not pending.future.done(): + pending.future.set_result(payload.get("response")) + else: + error = payload.get("error", "Unknown control error") + if isinstance(error, dict): + error_message = str(error.get("message", "Unknown control error")) + else: + error_message = str(error) + if not pending.future.done(): + pending.future.set_exception(RuntimeError(error_message)) + + def _handle_control_cancel_request(self, message: Mapping[str, Any]) -> None: + request_id = message.get("request_id") + if not isinstance(request_id, str): + return + + pending = self._pending_control_requests.pop(request_id, None) + if pending is not None: + pending.timeout_handle.cancel() + pending.cancel_event.set() + if not pending.future.done(): + pending.future.set_exception(AbortError("Control request cancelled")) + + incoming = self._incoming_control_requests.get(request_id) + if incoming is None: + return + + incoming.cancel_event.set() + incoming.task.cancel() + + async def _send_control_request( + self, + subtype: str, + data: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + if self._closed: + raise RuntimeError("Query is closed") + + if subtype != "initialize": + await self._wait_initialized() + + request_id = str(uuid4()) + + loop = asyncio.get_running_loop() + future: asyncio.Future[dict[str, Any] | None] = loop.create_future() + cancel_event = asyncio.Event() + + def on_timeout() -> None: + pending = self._pending_control_requests.pop(request_id, None) + if pending is None: + return + pending.cancel_event.set() + if not pending.future.done(): + pending.future.set_exception( + ControlRequestTimeoutError(f"Control request timeout: {subtype}") + ) + + timeout_handle = loop.call_later( + self._options.timeout.control_request, + on_timeout, + ) + + self._pending_control_requests[request_id] = _PendingControlRequest( + future=future, + cancel_event=cancel_event, + timeout_handle=timeout_handle, + ) + + request_payload: dict[str, Any] = {"subtype": subtype} + if data: + request_payload.update(data) + + payload: CLIControlRequest = { + "type": "control_request", + "request_id": request_id, + "request": request_payload, + } + + await self._write_payload(payload) + return await future + + async def _send_control_response( + self, + request_id: str, + *, + success: bool, + response: Any, + ) -> None: + payload: CLIControlResponse + if success: + payload = { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": response, + }, + } + else: + payload = { + "type": "control_response", + "response": { + "subtype": "error", + "request_id": request_id, + "error": str(response), + }, + } + + await self._write_payload(payload) + + async def _write_payload(self, payload: Any) -> None: + self._transport.write(serialize_json_line(payload)) + await self._transport.drain() + + async def stream_input(self, messages: AsyncIterable[SDKUserMessage]) -> None: + try: + if self._closed: + raise RuntimeError("Query is closed") + + await self._wait_initialized() + + async for message in messages: + if self._cancel_event.is_set() or self._closed: + break + await self._write_payload(message) + + if not self._single_turn: + try: + await asyncio.wait_for( + self._first_result_event.wait(), + timeout=self._options.timeout.stream_close, + ) + except asyncio.TimeoutError: + pass + + self._transport.end_input() + except Exception as exc: + await self._finish_with_error(exc) + raise + + async def interrupt(self) -> None: + await self._ensure_started() + await self._send_control_request("interrupt") + + async def set_permission_mode(self, mode: str) -> None: + await self._ensure_started() + await self._send_control_request("set_permission_mode", {"mode": mode}) + + async def set_model(self, model: str) -> None: + await self._ensure_started() + await self._send_control_request("set_model", {"model": model}) + + async def supported_commands(self) -> dict[str, Any] | None: + await self._ensure_started() + return await self._send_control_request("supported_commands") + + async def mcp_server_status(self) -> dict[str, Any] | None: + await self._ensure_started() + return await self._send_control_request("mcp_server_status") + + @property + def control_request_timeout(self) -> float: + return self._options.timeout.control_request + + def get_session_id(self) -> str: + return self._session_id + + def is_closed(self) -> bool: + return self._closed + + def _fail_pending_control_requests(self, error: Exception) -> None: + for request_id, pending in list(self._pending_control_requests.items()): + pending.timeout_handle.cancel() + pending.cancel_event.set() + if not pending.future.done(): + pending.future.set_exception(error) + self._pending_control_requests.pop(request_id, None) + + async def _cancel_incoming_control_requests(self) -> None: + current_task = asyncio.current_task() + tasks: list[asyncio.Task[None]] = [] + + for incoming in list(self._incoming_control_requests.values()): + incoming.cancel_event.set() + if incoming.task is current_task: + continue + incoming.task.cancel() + tasks.append(incoming.task) + + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + async def close(self) -> None: + if self._closed: + return + + self._closed = True + self._cancel_event.set() + + error = RuntimeError("Query is closed") + self._fail_pending_control_requests(error) + await self._cancel_incoming_control_requests() + + await self._transport.close() + + if self._input_task is not None: + self._input_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await self._input_task + + if self._router_task is not None: + with contextlib.suppress(Exception): + await self._router_task + + await self._finish() + + async def _finish(self) -> None: + if self._terminal_event_sent: + return + self._terminal_event_sent = True + await self._message_queue.put(_DONE) + + async def _finish_with_error(self, exc: Exception) -> None: + if self._terminal_event_sent: + return + self._closed = True + self._terminal_event_sent = True + self._cancel_event.set() + self._fail_pending_control_requests(exc) + await self._cancel_incoming_control_requests() + await self._transport.close() + await self._message_queue.put(exc) + await self._message_queue.put(_DONE) + + def __aiter__(self) -> Query: + return self + + async def __anext__(self) -> SDKMessage: + if self._exhausted: + raise StopAsyncIteration + await self._ensure_started() + item = await self._message_queue.get() + + if item is _DONE: + self._exhausted = True + raise StopAsyncIteration + + if isinstance(item, Exception): + raise item + + return cast(SDKMessage, item) + + async def __aenter__(self) -> Query: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + +def query( + prompt: str | AsyncIterable[SDKUserMessage], + options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None, +) -> Query: + if isinstance(options, QueryOptions): + parsed_options = replace(options) + else: + parsed_options = QueryOptions.from_mapping(options) + + validate_query_options(parsed_options) + + session_id = parsed_options.resume or parsed_options.session_id + if session_id is None and not parsed_options.continue_session: + session_id = str(uuid4()) + if parsed_options.resume is None and not parsed_options.continue_session: + parsed_options = replace(parsed_options, session_id=session_id) + + transport = ProcessTransport(parsed_options) + return Query(transport, parsed_options, prompt, session_id or "") diff --git a/packages/sdk-python/src/qwen_code_sdk/sync_query.py b/packages/sdk-python/src/qwen_code_sdk/sync_query.py new file mode 100644 index 000000000..c08713dd6 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/sync_query.py @@ -0,0 +1,217 @@ +"""Synchronous wrapper around the async Query API.""" + +from __future__ import annotations + +import asyncio +import threading +import warnings +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Mapping +from queue import Queue +from typing import Any, cast + +from .protocol import SDKMessage, SDKUserMessage +from .query import Query, query +from .types import QueryOptions, QueryOptionsDict + +_STOP = object() +_SYNC_TIMEOUT_MARGIN = 5.0 + + +class SyncQuery: + def __init__( + self, + prompt: str | Iterable[SDKUserMessage] | AsyncIterable[SDKUserMessage], + options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None, + ) -> None: + self._queue: Queue[SDKMessage | Exception | object] = Queue() + self._ready = threading.Event() + self._shutdown = threading.Event() + self._stop_sent = threading.Event() + self._exhausted = False + self._query: Query | None = None + self._consumer_task: asyncio.Task[None] | None = None + + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread( + target=self._run_loop, + name="qwen-sdk-sync-loop", + daemon=True, + ) + self._thread.start() + + if isinstance(prompt, str) or isinstance(prompt, AsyncIterable): + source_prompt: str | AsyncIterable[SDKUserMessage] = prompt + else: + source_prompt = _iterable_to_async(prompt) + + future = asyncio.run_coroutine_threadsafe( + self._bootstrap(source_prompt, options), + self._loop, + ) + try: + future.result() + except Exception: + self._stop_loop() + raise + + def _run_loop(self) -> None: + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + async def _bootstrap( + self, + prompt: str | AsyncIterable[SDKUserMessage], + options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None, + ) -> None: + self._query = query(prompt=prompt, options=options) + self._ready.set() + self._consumer_task = asyncio.create_task(self._consume()) + + async def _consume(self) -> None: + assert self._query is not None + try: + async for message in self._query: + self._queue.put(message) + except Exception as exc: + self._queue.put(exc) + finally: + if not self._stop_sent.is_set(): + self._stop_sent.set() + self._queue.put(_STOP) + + def _require_query(self) -> Query: + self._ready.wait(timeout=30) + if self._query is None: + raise RuntimeError("SyncQuery failed to initialize") + return self._query + + def __iter__(self) -> SyncQuery: + return self + + def __next__(self) -> SDKMessage: + if self._exhausted: + raise StopIteration + item = self._queue.get() + + if item is _STOP: + self._exhausted = True + raise StopIteration + + if isinstance(item, Exception): + raise item + + return cast(SDKMessage, item) + + def __enter__(self) -> SyncQuery: + return self + + def __exit__(self, *_args: object) -> None: + self.close() + + def interrupt(self) -> None: + q = self._require_query() + asyncio.run_coroutine_threadsafe(q.interrupt(), self._loop).result( + timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN + ) + + def set_model(self, model: str) -> None: + q = self._require_query() + asyncio.run_coroutine_threadsafe(q.set_model(model), self._loop).result( + timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN + ) + + def set_permission_mode(self, mode: str) -> None: + q = self._require_query() + asyncio.run_coroutine_threadsafe( + q.set_permission_mode(mode), + self._loop, + ).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN) + + def supported_commands(self) -> Any: + q = self._require_query() + return asyncio.run_coroutine_threadsafe( + q.supported_commands(), + self._loop, + ).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN) + + def mcp_server_status(self) -> Any: + q = self._require_query() + return asyncio.run_coroutine_threadsafe( + q.mcp_server_status(), + self._loop, + ).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN) + + def get_session_id(self) -> str: + q = self._require_query() + return q.get_session_id() + + def is_closed(self) -> bool: + q = self._require_query() + return q.is_closed() + + def close(self) -> None: + if self._shutdown.is_set(): + return + + self._shutdown.set() + + q = self._query + if q is not None: + try: + asyncio.run_coroutine_threadsafe(q.close(), self._loop).result( + timeout=30 + ) + except Exception: + pass + + # Wait for _consume() to put _STOP before stopping the loop, + # otherwise consumers blocked on queue.get() will deadlock. + if self._consumer_task is not None: + try: + asyncio.run_coroutine_threadsafe( + self._await_consumer(), self._loop + ).result(timeout=5) + except Exception: + pass + + if not self._stop_sent.is_set(): + self._stop_sent.set() + self._queue.put(_STOP) + self._stop_loop() + + async def _await_consumer(self) -> None: + if self._consumer_task is not None: + try: + await asyncio.wait_for(self._consumer_task, timeout=5.0) + except Exception: + pass + + def _stop_loop(self) -> None: + if self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join(timeout=5) + if not self._loop.is_closed(): + self._loop.close() + + def __del__(self) -> None: + try: + if not self._shutdown.is_set(): + warnings.warn( + "SyncQuery was not closed. " + "Use 'with SyncQuery(...) as q:' or call q.close() explicitly.", + ResourceWarning, + stacklevel=1, + ) + try: + self.close() + except Exception: + pass + except AttributeError: + pass + + +async def _iterable_to_async( + messages: Iterable[SDKUserMessage], +) -> AsyncIterator[SDKUserMessage]: + for message in messages: + yield message diff --git a/packages/sdk-python/src/qwen_code_sdk/transport.py b/packages/sdk-python/src/qwen_code_sdk/transport.py new file mode 100644 index 000000000..854236494 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/transport.py @@ -0,0 +1,243 @@ +"""Process transport for qwen CLI stream-json protocol.""" + +from __future__ import annotations + +import asyncio +import json +import os +import subprocess +import sys +from collections.abc import AsyncIterator +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .errors import ProcessExitError +from .json_lines import parse_json_line +from .types import QueryOptions + + +@dataclass(frozen=True) +class SpawnInfo: + command: str + args: list[str] + + +def prepare_spawn_info(path_to_qwen_executable: str | None) -> SpawnInfo: + if path_to_qwen_executable is None: + return SpawnInfo(command="qwen", args=[]) + + spec = path_to_qwen_executable + if os.path.sep not in spec and ( + os.path.altsep is None or os.path.altsep not in spec + ): + return SpawnInfo(command=spec, args=[]) + + path = Path(spec).expanduser().resolve() + + suffix = path.suffix.lower() + if suffix == ".py": + return SpawnInfo(command=sys.executable, args=[str(path)]) + if suffix in {".js", ".mjs", ".cjs"}: + return SpawnInfo(command="node", args=[str(path)]) + + return SpawnInfo(command=str(path), args=[]) + + +class ProcessTransport: + def __init__(self, options: QueryOptions): + self._options = options + self._process: asyncio.subprocess.Process | None = None + self._stderr_task: asyncio.Task[None] | None = None + self._closed = False + self._input_closed = False + self._exit_error: Exception | None = None + + @property + def exit_error(self) -> Exception | None: + return self._exit_error + + @property + def is_closed(self) -> bool: + return self._closed + + async def start(self) -> None: + if self._closed: + raise RuntimeError("Transport is closed") + if self._process is not None: + return + + spawn_info = prepare_spawn_info(self._options.path_to_qwen_executable) + args = [*spawn_info.args, *build_cli_arguments(self._options)] + stderr_target = ( + asyncio.subprocess.PIPE + if self._options.debug or self._options.stderr is not None + else subprocess.DEVNULL + ) + + self._process = await asyncio.create_subprocess_exec( + spawn_info.command, + *args, + cwd=self._options.cwd, + env={**os.environ, **(self._options.env or {})}, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=stderr_target, + ) + + if self._options.debug or self._options.stderr is not None: + self._stderr_task = asyncio.create_task(self._forward_stderr()) + + async def _forward_stderr(self) -> None: + if self._process is None or self._process.stderr is None: + return + + while True: + chunk = await self._process.stderr.readline() + if not chunk: + return + text = chunk.decode("utf-8", errors="replace").rstrip("\n") + try: + if self._options.stderr is not None: + self._options.stderr(text) + elif self._options.debug: + print(text, file=sys.stderr) + except Exception: + print(text, file=sys.stderr) + + def write(self, data: str) -> None: + if self._closed: + raise RuntimeError("Transport is closed") + if self._process is None or self._process.stdin is None: + raise RuntimeError("Transport is not started") + if self._input_closed: + raise RuntimeError("Transport input is already closed") + + self._process.stdin.write(data.encode("utf-8")) + + async def drain(self) -> None: + if self._process is None or self._process.stdin is None: + return + await self._process.stdin.drain() + + def end_input(self) -> None: + if self._closed or self._input_closed: + return + if self._process is None or self._process.stdin is None: + return + self._process.stdin.close() + self._input_closed = True + + async def read_messages(self) -> AsyncIterator[Any]: + if self._process is None or self._process.stdout is None: + raise RuntimeError("Transport is not started") + + while True: + line = await self._process.stdout.readline() + if not line: + break + + raw = line.decode("utf-8", errors="replace").strip() + if not raw: + continue + try: + yield parse_json_line(raw) + except json.JSONDecodeError: + continue + + await self._finalize_exit() + + async def wait_for_exit(self) -> None: + if self._process is None: + return + await self._finalize_exit() + + async def _finalize_exit(self) -> None: + if self._process is None: + return + + return_code = self._process.returncode + if return_code is None: + return_code = await self._process.wait() + + if return_code != 0 and self._exit_error is None: + self._exit_error = ProcessExitError( + f"CLI process exited with code {return_code}", + exit_code=return_code, + ) + + if self._stderr_task is not None: + await self._stderr_task + self._stderr_task = None + + async def close(self) -> None: + if self._closed: + return + + self._closed = True + + if self._process is None: + return + + if self._process.stdin is not None and not self._input_closed: + self._process.stdin.close() + self._input_closed = True + + if self._process.returncode is None: + self._process.terminate() + try: + await asyncio.wait_for(self._process.wait(), timeout=5.0) + except asyncio.TimeoutError: + self._process.kill() + await self._process.wait() + + await self._finalize_exit() + + +def build_cli_arguments(options: QueryOptions) -> list[str]: + args: list[str] = [ + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--channel=SDK", + ] + + if options.model: + args.extend(["--model", options.model]) + + if options.system_prompt: + args.extend(["--system-prompt", options.system_prompt]) + + if options.append_system_prompt: + args.extend(["--append-system-prompt", options.append_system_prompt]) + + if options.permission_mode: + args.extend(["--approval-mode", options.permission_mode]) + + if options.max_session_turns is not None: + args.extend(["--max-session-turns", str(options.max_session_turns)]) + + if options.core_tools: + args.extend(["--core-tools", ",".join(options.core_tools)]) + + if options.exclude_tools: + args.extend(["--exclude-tools", ",".join(options.exclude_tools)]) + + if options.allowed_tools: + args.extend(["--allowed-tools", ",".join(options.allowed_tools)]) + + if options.auth_type: + args.extend(["--auth-type", options.auth_type]) + + if options.include_partial_messages: + args.append("--include-partial-messages") + + if options.resume: + args.extend(["--resume", options.resume]) + elif options.continue_session: + args.append("--continue") + elif options.session_id: + args.extend(["--session-id", options.session_id]) + + return args diff --git a/packages/sdk-python/src/qwen_code_sdk/types.py b/packages/sdk-python/src/qwen_code_sdk/types.py new file mode 100644 index 000000000..3d8ec7203 --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/types.py @@ -0,0 +1,323 @@ +"""Public type definitions for qwen_code_sdk.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping, MutableMapping +from dataclasses import dataclass +from inspect import Parameter, Signature, iscoroutinefunction, signature +from typing import ( + Any, + Literal, + TypeAlias, + TypedDict, + cast, +) + +from typing_extensions import NotRequired + +PermissionMode: TypeAlias = Literal["default", "plan", "auto-edit", "yolo"] +AuthType: TypeAlias = Literal[ + "openai", + "anthropic", + "qwen-oauth", + "gemini", + "vertex-ai", +] + + +class PermissionSuggestion(TypedDict): + type: Literal["allow", "deny", "modify"] + label: str + description: NotRequired[str] + modifiedInput: NotRequired[Any] + + +class PermissionAllowResult(TypedDict): + behavior: Literal["allow"] + updatedInput: NotRequired[dict[str, Any]] + + +class PermissionDenyResult(TypedDict): + behavior: Literal["deny"] + message: NotRequired[str] + interrupt: NotRequired[bool] + + +PermissionResult: TypeAlias = PermissionAllowResult | PermissionDenyResult + + +class CanUseToolContext(TypedDict): + cancel_event: Any + suggestions: list[PermissionSuggestion] | None + blocked_path: str | None + + +CanUseTool: TypeAlias = Callable[ + [str, dict[str, Any], CanUseToolContext], + Awaitable[PermissionResult], +] + + +class TimeoutOptionsDict(TypedDict, total=False): + """Timeout configuration. All values are in seconds.""" + + can_use_tool: float + control_request: float + stream_close: float + + +@dataclass(frozen=True) +class TimeoutOptions: + can_use_tool: float = 60.0 + control_request: float = 60.0 + stream_close: float = 60.0 + + @classmethod + def from_mapping(cls, value: Mapping[str, Any] | None) -> TimeoutOptions: + if value is None: + return cls() + + def _read(name: str, default: float) -> float: + raw = value.get(name, default) + if isinstance(raw, bool) or not isinstance(raw, (int, float)): + raise TypeError(f"timeout.{name} must be a positive number") + if raw <= 0: + raise ValueError(f"timeout.{name} must be a positive number") + return float(raw) + + return cls( + can_use_tool=_read("can_use_tool", 60.0), + control_request=_read("control_request", 60.0), + stream_close=_read("stream_close", 60.0), + ) + + +class QueryOptionsDict(TypedDict, total=False): + cwd: str + model: str + path_to_qwen_executable: str + permission_mode: PermissionMode + can_use_tool: CanUseTool + env: dict[str, str] + system_prompt: str + append_system_prompt: str + debug: bool + max_session_turns: int + core_tools: list[str] + exclude_tools: list[str] + allowed_tools: list[str] + auth_type: AuthType + include_partial_messages: bool + resume: str + continue_session: bool + session_id: str + timeout: TimeoutOptionsDict + mcp_servers: dict[str, dict[str, Any]] + stderr: Callable[[str], None] + + +@dataclass +class QueryOptions: + cwd: str | None = None + model: str | None = None + path_to_qwen_executable: str | None = None + permission_mode: PermissionMode | None = None + can_use_tool: CanUseTool | None = None + env: dict[str, str] | None = None + system_prompt: str | None = None + append_system_prompt: str | None = None + debug: bool = False + max_session_turns: int | None = None + core_tools: list[str] | None = None + exclude_tools: list[str] | None = None + allowed_tools: list[str] | None = None + auth_type: AuthType | None = None + include_partial_messages: bool = False + resume: str | None = None + continue_session: bool = False + session_id: str | None = None + timeout: TimeoutOptions = TimeoutOptions() + mcp_servers: dict[str, dict[str, Any]] | None = None + stderr: Callable[[str], None] | None = None + + @classmethod + def from_mapping(cls, value: Mapping[str, Any] | None) -> QueryOptions: + if value is None: + return cls() + + data: MutableMapping[str, Any] = dict(value) + timeout = TimeoutOptions.from_mapping(data.get("timeout")) + + return cls( + cwd=_as_optional_str(data, "cwd"), + model=_as_optional_str(data, "model"), + path_to_qwen_executable=_as_optional_str(data, "path_to_qwen_executable"), + permission_mode=cast( + PermissionMode | None, + _as_optional_str(data, "permission_mode"), + ), + can_use_tool=cast( + CanUseTool | None, + _as_optional_callable(data, "can_use_tool"), + ), + env=_as_optional_str_dict(data, "env"), + system_prompt=_as_optional_str(data, "system_prompt"), + append_system_prompt=_as_optional_str(data, "append_system_prompt"), + debug=_as_optional_bool(data, "debug") or False, + max_session_turns=_as_optional_int(data, "max_session_turns"), + core_tools=_as_optional_str_list(data, "core_tools"), + exclude_tools=_as_optional_str_list(data, "exclude_tools"), + allowed_tools=_as_optional_str_list(data, "allowed_tools"), + auth_type=cast( + AuthType | None, + _as_optional_str(data, "auth_type"), + ), + include_partial_messages=_as_optional_bool(data, "include_partial_messages") + or False, + resume=_as_optional_str(data, "resume"), + continue_session=_as_optional_bool(data, "continue_session") or False, + session_id=_as_optional_str(data, "session_id"), + timeout=timeout, + mcp_servers=_as_optional_nested_dict(data, "mcp_servers"), + stderr=cast( + Callable[[str], None] | None, + _as_optional_callable(data, "stderr"), + ), + ) + + +def _as_optional_str(data: Mapping[str, Any], key: str) -> str | None: + raw = data.get(key) + if raw is None: + return None + if not isinstance(raw, str): + raise TypeError(f"{key} must be a string") + return raw + + +def _as_optional_int(data: Mapping[str, Any], key: str) -> int | None: + raw = data.get(key) + if raw is None: + return None + if isinstance(raw, bool) or not isinstance(raw, int): + raise TypeError(f"{key} must be an integer") + return int(raw) + + +def _as_optional_bool(data: Mapping[str, Any], key: str) -> bool | None: + raw = data.get(key) + if raw is None: + return None + if not isinstance(raw, bool): + raise TypeError(f"{key} must be a boolean") + return raw + + +def _as_optional_callable( + data: Mapping[str, Any], key: str +) -> Callable[..., Any] | None: + raw = data.get(key) + if raw is None: + return None + if not callable(raw): + raise TypeError(f"{key} must be callable") + if key == "can_use_tool": + _validate_can_use_tool_callable(raw, error_type=TypeError) + elif key == "stderr": + _validate_stderr_callable(raw, error_type=TypeError) + return cast(Callable[..., Any], raw) + + +def _validate_can_use_tool_callable(value: object, error_type: type[Exception]) -> None: + if not callable(value): + raise error_type("can_use_tool must be callable") + + if not iscoroutinefunction(value): + raise error_type("can_use_tool must be an async callable") + + try: + sig = signature(value) + except (TypeError, ValueError): + return + + if not _supports_argument_count(sig, 3): + raise error_type("can_use_tool must accept exactly 3 positional arguments") + + +def _validate_stderr_callable(value: object, error_type: type[Exception]) -> None: + if not callable(value): + raise error_type("stderr must be callable") + + try: + sig = signature(value) + except (TypeError, ValueError): + return + + if not _supports_argument_count(sig, 1): + raise error_type("stderr must accept exactly 1 positional argument") + + +def _supports_argument_count(sig: Signature, count: int) -> bool: + params = list(sig.parameters.values()) + positional_params = [ + param + for param in params + if param.kind + in ( + Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + required_positional = [ + param for param in positional_params if param.default is Parameter.empty + ] + has_var_positional = any(param.kind is Parameter.VAR_POSITIONAL for param in params) + + if len(required_positional) > count: + return False + if has_var_positional: + return True + return len(positional_params) >= count + + +def _as_optional_str_dict(data: Mapping[str, Any], key: str) -> dict[str, str] | None: + raw = data.get(key) + if raw is None: + return None + if not isinstance(raw, Mapping): + raise TypeError(f"{key} must be a mapping of string to string") + + parsed: dict[str, str] = {} + for k, v in raw.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise TypeError(f"{key} must be a mapping of string to string") + parsed[k] = v + return parsed + + +def _as_optional_str_list(data: Mapping[str, Any], key: str) -> list[str] | None: + raw = data.get(key) + if raw is None: + return None + if not isinstance(raw, list): + raise TypeError(f"{key} must be a list of strings") + if any(not isinstance(item, str) for item in raw): + raise TypeError(f"{key} must be a list of strings") + return list(raw) + + +def _as_optional_nested_dict( + data: Mapping[str, Any], key: str +) -> dict[str, dict[str, Any]] | None: + raw = data.get(key) + if raw is None: + return None + if not isinstance(raw, Mapping): + raise TypeError(f"{key} must be a mapping") + + parsed: dict[str, dict[str, Any]] = {} + for k, v in raw.items(): + if not isinstance(k, str) or not isinstance(v, Mapping): + raise TypeError(f"{key} must be a mapping of string to mapping") + parsed[k] = dict(v) + return parsed diff --git a/packages/sdk-python/src/qwen_code_sdk/validation.py b/packages/sdk-python/src/qwen_code_sdk/validation.py new file mode 100644 index 000000000..f19fe4cfb --- /dev/null +++ b/packages/sdk-python/src/qwen_code_sdk/validation.py @@ -0,0 +1,94 @@ +"""Validation helpers for query options.""" + +from __future__ import annotations + +from collections.abc import Callable +from uuid import RFC_4122, UUID + +from .errors import ValidationError +from .types import ( + QueryOptions, + _validate_can_use_tool_callable, + _validate_stderr_callable, +) + +_VALID_PERMISSION_MODES = {"default", "plan", "auto-edit", "yolo"} +_VALID_AUTH_TYPES = {"openai", "anthropic", "qwen-oauth", "gemini", "vertex-ai"} + + +def validate_query_options(options: QueryOptions) -> None: + if ( + options.permission_mode + and options.permission_mode not in _VALID_PERMISSION_MODES + ): + raise ValidationError( + f"Invalid permission_mode: {options.permission_mode!r}. " + "Expected one of: default, plan, auto-edit, yolo." + ) + + if options.auth_type and options.auth_type not in _VALID_AUTH_TYPES: + raise ValidationError( + f"Invalid auth_type: {options.auth_type!r}. " + "Expected one of: openai, anthropic, qwen-oauth, gemini, vertex-ai." + ) + + _validate_optional_callable(options.can_use_tool, _validate_can_use_tool_callable) + _validate_optional_callable(options.stderr, _validate_stderr_callable) + + if options.resume and options.continue_session: + raise ValidationError( + "Cannot use resume together with continue_session. " + "Use continue_session for latest session " + "or resume for a specific session ID." + ) + + if options.session_id and (options.resume or options.continue_session): + raise ValidationError( + "Cannot use session_id with resume or continue_session. " + "session_id starts a new session, " + "resume/continue_session restore existing sessions." + ) + + if options.session_id: + validate_session_id(options.session_id, "session_id") + + if options.resume: + validate_session_id(options.resume, "resume") + + if options.max_session_turns is not None and options.max_session_turns < -1: + raise ValidationError("max_session_turns must be -1 or a non-negative integer") + + if ( + options.path_to_qwen_executable is not None + and not options.path_to_qwen_executable.strip() + ): + raise ValidationError("path_to_qwen_executable cannot be empty") + + if options.mcp_servers: + raise ValidationError( + "mcp_servers is not supported in Python SDK v1. " + "Remove the mcp_servers option or use the TypeScript SDK." + ) + + +def _validate_optional_callable( + value: object, + validator: Callable[[object, type[ValidationError]], None], +) -> None: + if value is None: + return + validator(value, ValidationError) + + +def validate_session_id(value: str, param_name: str) -> None: + try: + parsed = UUID(value) + except ValueError as exc: + raise ValidationError( + f"Invalid {param_name}: {value!r}. Must be a valid UUID." + ) from exc + + if parsed.variant != RFC_4122: + raise ValidationError( + f"Invalid {param_name}: {value!r}. UUID variant must be RFC 4122." + ) diff --git a/packages/sdk-python/tests/integration/conftest.py b/packages/sdk-python/tests/integration/conftest.py new file mode 100644 index 000000000..f73a5e6ac --- /dev/null +++ b/packages/sdk-python/tests/integration/conftest.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +import os +import stat +import textwrap +from pathlib import Path + +import pytest + + +@pytest.fixture() +def fake_qwen_path(tmp_path: Path) -> str: + script_path = tmp_path / "fake_qwen.py" + script_path.write_text( + textwrap.dedent( + """ + #!/usr/bin/env python3 + import argparse + import json + import sys + import uuid + + + def send(message): + sys.stdout.write(json.dumps(message, separators=(",", ":")) + "\\n") + sys.stdout.flush() + + + def parse_user_content(message): + payload = message.get("message", {}) + content = payload.get("content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + text_parts = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text_parts.append(str(block.get("text", ""))) + return " ".join(text_parts) + return str(content) + + + def build_system_message(): + return { + "type": "system", + "subtype": "init", + "uuid": session_id, + "session_id": session_id, + "cwd": ".", + "tools": ["Read", "Edit", "Bash"], + "mcp_servers": [], + "model": state["model"], + "permission_mode": state["permission_mode"], + "qwen_code_version": "fake-1.0.0", + "capabilities": { + "canSetModel": True, + "canSetPermissionMode": True, + }, + } + + + def build_assistant_message(text): + return { + "type": "assistant", + "uuid": str(uuid.uuid4()), + "session_id": session_id, + "message": { + "id": str(uuid.uuid4()), + "type": "message", + "role": "assistant", + "model": state["model"], + "content": [ + { + "type": "text", + "text": text, + } + ], + "usage": { + "input_tokens": 1, + "output_tokens": 1, + }, + }, + "parent_tool_use_id": None, + } + + + def build_result_message(result_text): + return { + "type": "result", + "subtype": "success", + "uuid": str(uuid.uuid4()), + "session_id": session_id, + "is_error": False, + "duration_ms": 5, + "duration_api_ms": 1, + "num_turns": 1, + "result": result_text, + "usage": { + "input_tokens": 1, + "output_tokens": 1, + }, + "permission_denials": [], + } + + + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--model") + parser.add_argument("--approval-mode") + parser.add_argument("--include-partial-messages", action="store_true") + parser.add_argument("--session-id") + parser.add_argument("--resume") + parser.add_argument( + "--continue", + dest="continue_session", + action="store_true", + ) + args, _ = parser.parse_known_args() + + session_id = ( + args.resume + or args.session_id + or ( + "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee" + if args.continue_session + else str(uuid.uuid4()) + ) + ) + state = { + "model": args.model or "coder-model", + "permission_mode": args.approval_mode or "default", + "include_partial": bool(args.include_partial_messages), + } + + pending_permission = None + pending_unknown_control = None + + for line in sys.stdin: + line = line.strip() + if not line: + continue + message = json.loads(line) + msg_type = message.get("type") + + if msg_type == "control_request": + request_id = message["request_id"] + request = message["request"] + subtype = request.get("subtype") + + if subtype == "initialize": + send( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": {}, + }, + } + ) + send(build_system_message()) + elif subtype == "set_model": + state["model"] = request["model"] + send( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": {}, + }, + } + ) + send(build_system_message()) + elif subtype == "set_permission_mode": + state["permission_mode"] = request["mode"] + send( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": {}, + }, + } + ) + send(build_system_message()) + elif subtype == "interrupt": + send( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": {}, + }, + } + ) + elif subtype == "supported_commands": + send( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": { + "commands": [ + "initialize", + "interrupt", + "set_model", + "set_permission_mode", + ] + }, + }, + } + ) + elif subtype == "mcp_server_status": + send( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": request_id, + "response": {"servers": []}, + }, + } + ) + else: + send( + { + "type": "control_response", + "response": { + "subtype": "error", + "request_id": request_id, + "error": f"unsupported request: {subtype}", + }, + } + ) + + elif msg_type == "user": + prompt = parse_user_content(message) + + if "exit nonzero" in prompt: + sys.exit(9) + + if "request unknown control" in prompt: + request_id = str(uuid.uuid4()) + pending_unknown_control = { + "request_id": request_id, + "prompt": prompt, + } + send( + { + "type": "control_request", + "request_id": request_id, + "request": { + "subtype": "something_new", + "payload": {}, + }, + } + ) + continue + + if "use tool" in prompt or "create file" in prompt: + tool_use_id = str(uuid.uuid4()) + send( + { + "type": "assistant", + "uuid": str(uuid.uuid4()), + "session_id": session_id, + "message": { + "id": str(uuid.uuid4()), + "type": "message", + "role": "assistant", + "model": state["model"], + "content": [ + { + "type": "tool_use", + "id": tool_use_id, + "name": "write_file", + "input": { + "path": "demo.txt", + "content": "hello", + }, + } + ], + "usage": {"input_tokens": 1, "output_tokens": 1}, + }, + "parent_tool_use_id": None, + } + ) + request_id = str(uuid.uuid4()) + pending_permission = { + "request_id": request_id, + "tool_use_id": tool_use_id, + "prompt": prompt, + } + send( + { + "type": "control_request", + "request_id": request_id, + "request": { + "subtype": "can_use_tool", + "tool_name": "write_file", + "tool_use_id": tool_use_id, + "input": {"path": "demo.txt", "content": "hello"}, + "permission_suggestions": [ + {"type": "allow", "label": "Allow write"} + ], + "blocked_path": None, + }, + } + ) + continue + + if state["include_partial"]: + send( + { + "type": "stream_event", + "uuid": str(uuid.uuid4()), + "session_id": session_id, + "event": { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "partial"}, + }, + "parent_tool_use_id": None, + } + ) + + send(build_assistant_message(f"Echo: {prompt}")) + send(build_result_message(f"done: {prompt}")) + + elif msg_type == "control_response": + payload = message.get("response", {}) + request_id = payload.get("request_id") + + if ( + pending_unknown_control + and request_id == pending_unknown_control["request_id"] + ): + if payload.get("subtype") != "error": + sys.exit(3) + prompt = pending_unknown_control["prompt"] + pending_unknown_control = None + send( + build_assistant_message( + f"Unknown control handled for: {prompt}" + ) + ) + send(build_result_message(f"unknown-control: {prompt}")) + continue + + if ( + pending_permission + and request_id == pending_permission["request_id"] + ): + prompt = pending_permission["prompt"] + tool_use_id = pending_permission["tool_use_id"] + pending_permission = None + + behavior = "deny" + if payload.get("subtype") == "success": + response_payload = payload.get("response") or {} + behavior = response_payload.get("behavior", "deny") + + is_allowed = behavior == "allow" + send( + { + "type": "user", + "session_id": session_id, + "message": { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "is_error": not is_allowed, + "content": "ok" if is_allowed else "denied", + } + ], + }, + "parent_tool_use_id": tool_use_id, + } + ) + send(build_assistant_message(f"tool handled: {prompt}")) + send(build_result_message(f"tool-result: {prompt}")) + continue + """ + ).strip() + + "\n", + encoding="utf-8", + ) + script_path.chmod(script_path.stat().st_mode | stat.S_IEXEC) + return str(script_path) + + +@pytest.fixture(autouse=True) +def disable_history_expansion() -> None: + # No-op fixture used as explicit marker for deterministic test env. + os.environ.setdefault("PYTHONUTF8", "1") diff --git a/packages/sdk-python/tests/integration/test_async_query.py b/packages/sdk-python/tests/integration/test_async_query.py new file mode 100644 index 000000000..3aca1e807 --- /dev/null +++ b/packages/sdk-python/tests/integration/test_async_query.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator, Callable +from typing import Any + +import pytest +from qwen_code_sdk import ( + ProcessExitError, + SDKUserMessage, + is_sdk_assistant_message, + is_sdk_partial_assistant_message, + is_sdk_result_message, + is_sdk_system_message, + is_sdk_user_message, + query, +) + +CONTINUED_SESSION_ID = "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee" +VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" +RESUME_UUID = "223e4567-e89b-12d3-a456-426614174000" + + +async def _collect_messages(result: Any) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + async for message in result: + messages.append(message) + return messages + + +async def _wait_for(predicate: Callable[[], bool], timeout: float = 2.0) -> None: + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while loop.time() < deadline: + if predicate(): + return + await asyncio.sleep(0.01) + raise AssertionError("timed out waiting for expected SDK state") + + +def _tool_result_error_flag(message: dict[str, Any]) -> bool: + content = message["message"]["content"] + assert isinstance(content, list) + return bool(content[0]["is_error"]) + + +@pytest.mark.asyncio +async def test_single_turn_query(fake_qwen_path: str) -> None: + result = query( + "hello world", + { + "path_to_qwen_executable": fake_qwen_path, + }, + ) + messages = await _collect_messages(result) + + assistant = next( + message for message in messages if is_sdk_assistant_message(message) + ) + final = next(message for message in messages if is_sdk_result_message(message)) + + assert assistant["message"]["content"][0]["text"] == "Echo: hello world" + assert final["result"] == "done: hello world" + await result.close() + + +@pytest.mark.asyncio +async def test_include_partial_messages(fake_qwen_path: str) -> None: + result = query( + "stream partial", + { + "path_to_qwen_executable": fake_qwen_path, + "include_partial_messages": True, + }, + ) + messages = await _collect_messages(result) + + partial = next( + message for message in messages if is_sdk_partial_assistant_message(message) + ) + assert partial["event"]["type"] == "content_block_delta" + await result.close() + + +@pytest.mark.asyncio +async def test_default_permission_callback_denies_tool_use(fake_qwen_path: str) -> None: + result = query( + "use tool now", + { + "path_to_qwen_executable": fake_qwen_path, + }, + ) + messages = await _collect_messages(result) + + tool_result = next( + message + for message in messages + if is_sdk_user_message(message) + and isinstance(message["message"]["content"], list) + ) + assert _tool_result_error_flag(tool_result) is True + await result.close() + + +@pytest.mark.asyncio +async def test_permission_callback_can_allow_tool_use(fake_qwen_path: str) -> None: + async def can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + assert tool_name == "write_file" + assert tool_input["path"] == "demo.txt" + assert context["suggestions"][0]["type"] == "allow" + return {"behavior": "allow", "updatedInput": tool_input} + + result = query( + "create file with use tool", + { + "path_to_qwen_executable": fake_qwen_path, + "can_use_tool": can_use_tool, + }, + ) + messages = await _collect_messages(result) + + tool_result = next( + message + for message in messages + if is_sdk_user_message(message) + and isinstance(message["message"]["content"], list) + ) + assert _tool_result_error_flag(tool_result) is False + await result.close() + + +@pytest.mark.asyncio +async def test_unknown_control_requests_are_rejected(fake_qwen_path: str) -> None: + result = query( + "request unknown control", + { + "path_to_qwen_executable": fake_qwen_path, + }, + ) + messages = await _collect_messages(result) + + final = next(message for message in messages if is_sdk_result_message(message)) + assert final["result"] == "unknown-control: request unknown control" + await result.close() + + +@pytest.mark.asyncio +async def test_dynamic_controls_and_status(fake_qwen_path: str) -> None: + release_input = asyncio.Event() + + async def prompts() -> AsyncIterator[SDKUserMessage]: + yield { + "type": "user", + "session_id": VALID_UUID, + "message": { + "role": "user", + "content": "first turn", + }, + "parent_tool_use_id": None, + } + await release_input.wait() + + result = query( + prompts(), + { + "path_to_qwen_executable": fake_qwen_path, + "session_id": VALID_UUID, + }, + ) + + messages: list[dict[str, Any]] = [] + + async def consume() -> list[dict[str, Any]]: + async for message in result: + messages.append(message) + return messages + + collector = asyncio.create_task(consume()) + await _wait_for(lambda: any(is_sdk_result_message(message) for message in messages)) + + assert await result.supported_commands() == { + "commands": [ + "initialize", + "interrupt", + "set_model", + "set_permission_mode", + ] + } + assert await result.mcp_server_status() == {"servers": []} + + await result.set_model("new-model") + await result.set_permission_mode("plan") + release_input.set() + await collector + + system_messages = [ + message for message in messages if is_sdk_system_message(message) + ] + assert any(message["model"] == "new-model" for message in system_messages) + assert any(message["permission_mode"] == "plan" for message in system_messages) + await result.close() + + +@pytest.mark.asyncio +async def test_session_id_resume_and_continue(fake_qwen_path: str) -> None: + explicit = query( + "hello explicit", + { + "path_to_qwen_executable": fake_qwen_path, + "session_id": VALID_UUID, + }, + ) + explicit_messages = await _collect_messages(explicit) + assert explicit.get_session_id() == VALID_UUID + assert all(message["session_id"] == VALID_UUID for message in explicit_messages) + await explicit.close() + + resumed = query( + "hello resume", + { + "path_to_qwen_executable": fake_qwen_path, + "resume": RESUME_UUID, + }, + ) + resumed_messages = await _collect_messages(resumed) + assert resumed.get_session_id() == RESUME_UUID + assert all(message["session_id"] == RESUME_UUID for message in resumed_messages) + await resumed.close() + + continued = query( + "hello continue", + { + "path_to_qwen_executable": fake_qwen_path, + "continue_session": True, + }, + ) + continued_messages = await _collect_messages(continued) + assert continued.get_session_id() == CONTINUED_SESSION_ID + assert any( + message["session_id"] == CONTINUED_SESSION_ID for message in continued_messages + ) + await continued.close() + + +@pytest.mark.asyncio +async def test_non_zero_process_exit_is_propagated(fake_qwen_path: str) -> None: + result = query( + "please exit nonzero", + { + "path_to_qwen_executable": fake_qwen_path, + }, + ) + + with pytest.raises(ProcessExitError, match="code 9"): + await _collect_messages(result) + + await result.close() + + +@pytest.mark.asyncio +async def test_async_context_manager(fake_qwen_path: str) -> None: + async with query( + "hello context", + { + "path_to_qwen_executable": fake_qwen_path, + }, + ) as result: + messages = await _collect_messages(result) + + assert result.is_closed() + final = next(m for m in messages if is_sdk_result_message(m)) + assert final["result"] == "done: hello context" diff --git a/packages/sdk-python/tests/integration/test_sync_query.py b/packages/sdk-python/tests/integration/test_sync_query.py new file mode 100644 index 000000000..6f129ccbf --- /dev/null +++ b/packages/sdk-python/tests/integration/test_sync_query.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import threading +import time + +import pytest +import qwen_code_sdk.sync_query as sync_query_module +from qwen_code_sdk import is_sdk_result_message, query_sync +from qwen_code_sdk.sync_query import SyncQuery + + +def test_sync_query_single_turn(fake_qwen_path: str) -> None: + result = query_sync( + "hello sync", + { + "path_to_qwen_executable": fake_qwen_path, + }, + ) + + commands = result.supported_commands() + messages = list(result) + + assert commands["commands"][0] == "initialize" + assert any( + is_sdk_result_message(message) and message["result"] == "done: hello sync" + for message in messages + ) + + result.close() + result.close() + + +def test_sync_query_bootstrap_failure_cleans_up_loop_thread( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def raising_query(*args: object, **kwargs: object) -> object: + raise RuntimeError("bootstrap failed") + + monkeypatch.setattr(sync_query_module, "query", raising_query) + + baseline_threads = { + thread.ident + for thread in threading.enumerate() + if thread.name == "qwen-sdk-sync-loop" + } + + with pytest.raises(RuntimeError, match="bootstrap failed"): + SyncQuery("hello") + + deadline = time.time() + 1.0 + while time.time() < deadline: + active_threads = { + thread.ident + for thread in threading.enumerate() + if thread.name == "qwen-sdk-sync-loop" + } + if active_threads == baseline_threads: + break + time.sleep(0.01) + + active_threads = { + thread.ident + for thread in threading.enumerate() + if thread.name == "qwen-sdk-sync-loop" + } + assert active_threads == baseline_threads + + +def test_sync_query_context_manager(fake_qwen_path: str) -> None: + with query_sync( + "hello context", + { + "path_to_qwen_executable": fake_qwen_path, + }, + ) as result: + messages = list(result) + assert any( + is_sdk_result_message(m) and m["result"] == "done: hello context" + for m in messages + ) + + assert result.is_closed() diff --git a/packages/sdk-python/tests/unit/test_query_core.py b/packages/sdk-python/tests/unit/test_query_core.py new file mode 100644 index 000000000..0dd8f3f62 --- /dev/null +++ b/packages/sdk-python/tests/unit/test_query_core.py @@ -0,0 +1,612 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from typing import Any, cast + +import pytest +from qwen_code_sdk.errors import AbortError, ControlRequestTimeoutError +from qwen_code_sdk.json_lines import parse_json_line +from qwen_code_sdk.query import Query +from qwen_code_sdk.types import QueryOptions, TimeoutOptions + +VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" +_EOF = object() + + +class FakeTransport: + def __init__(self) -> None: + self.writes: list[dict[str, Any]] = [] + self.exit_error: Exception | None = None + self.closed = False + self.close_calls = 0 + self.input_closed = False + self._queue: asyncio.Queue[dict[str, Any] | object] = asyncio.Queue() + + async def start(self) -> None: + return None + + def write(self, data: str) -> None: + self.writes.append(parse_json_line(data)) + + async def drain(self) -> None: + return None + + def end_input(self) -> None: + self.input_closed = True + + async def read_messages(self): # type: ignore[no-untyped-def] + while True: + item = await self._queue.get() + if item is _EOF: + break + yield item + + async def close(self) -> None: + self.closed = True + self.close_calls += 1 + self.input_closed = True + self._queue.put_nowait(_EOF) + + def push(self, payload: dict[str, Any]) -> None: + self._queue.put_nowait(payload) + + +async def _wait_for(predicate: Callable[[], bool], timeout: float = 1.0) -> None: + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while loop.time() < deadline: + if predicate(): + return + await asyncio.sleep(0.01) + raise AssertionError("timed out waiting for test condition") + + +async def _wait_for_request( + transport: FakeTransport, + subtype: str, + timeout: float = 1.0, +) -> dict[str, Any]: + await _wait_for( + lambda: any( + payload.get("type") == "control_request" + and payload.get("request", {}).get("subtype") == subtype + for payload in transport.writes + ), + timeout=timeout, + ) + for payload in transport.writes: + if ( + payload.get("type") == "control_request" + and payload.get("request", {}).get("subtype") == subtype + ): + return payload + raise AssertionError(f"missing control request: {subtype}") + + +async def _wait_for_control_response( + transport: FakeTransport, + request_id: str, + timeout: float = 1.0, +) -> dict[str, Any]: + await _wait_for( + lambda: any( + payload.get("type") == "control_response" + and payload.get("response", {}).get("request_id") == request_id + for payload in transport.writes + ), + timeout=timeout, + ) + for payload in transport.writes: + if ( + payload.get("type") == "control_response" + and payload.get("response", {}).get("request_id") == request_id + ): + return payload + raise AssertionError(f"missing control response: {request_id}") + + +async def _start_query(transport: FakeTransport) -> Query: + query = Query( + transport=transport, # type: ignore[arg-type] + options=QueryOptions( + timeout=TimeoutOptions( + can_use_tool=0.05, + control_request=0.05, + stream_close=0.05, + ) + ), + prompt="hello", + session_id=VALID_UUID, + ) + await query._ensure_started() + + init_request = await _wait_for_request(transport, "initialize") + transport.push( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": init_request["request_id"], + "response": {}, + }, + } + ) + await _wait_for( + lambda: any(payload.get("type") == "user" for payload in transport.writes) + ) + return query + + +@pytest.mark.asyncio +async def test_unknown_control_request_returns_error_response() -> None: + transport = FakeTransport() + query = await _start_query(transport) + + transport.push( + { + "type": "control_request", + "request_id": "unknown-1", + "request": { + "subtype": "something_new", + }, + } + ) + + response = await _wait_for_control_response(transport, "unknown-1") + + assert response["response"]["subtype"] == "error" + assert "Unknown control request subtype" in response["response"]["error"] + await query.close() + + +@pytest.mark.asyncio +async def test_control_request_times_out() -> None: + transport = FakeTransport() + query = await _start_query(transport) + + with pytest.raises(ControlRequestTimeoutError, match="supported_commands"): + await query.supported_commands() + + await query.close() + + +@pytest.mark.asyncio +async def test_control_request_cancel_propagates_abort_error() -> None: + transport = FakeTransport() + query = await _start_query(transport) + + task = asyncio.create_task(query.supported_commands()) + request = await _wait_for_request(transport, "supported_commands") + transport.push( + { + "type": "control_cancel_request", + "request_id": request["request_id"], + } + ) + + with pytest.raises(AbortError, match="Control request cancelled"): + await task + + await query.close() + + +@pytest.mark.asyncio +async def test_incoming_control_request_cancel_does_not_block_router() -> None: + transport = FakeTransport() + started = asyncio.Event() + cancelled = asyncio.Event() + captured_cancel_events: list[asyncio.Event] = [] + + async def can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + assert tool_name == "write_file" + assert tool_input["path"] == "demo.txt" + cancel_event = cast(asyncio.Event, context["cancel_event"]) + captured_cancel_events.append(cancel_event) + started.set() + try: + await cancel_event.wait() + cancelled.set() + return {"behavior": "deny", "message": "Cancelled"} + except asyncio.CancelledError: + if cancel_event.is_set(): + cancelled.set() + raise + + query = Query( + transport=transport, # type: ignore[arg-type] + options=QueryOptions( + can_use_tool=can_use_tool, + timeout=TimeoutOptions( + can_use_tool=1.0, + control_request=0.2, + stream_close=0.05, + ), + ), + prompt="hello", + session_id=VALID_UUID, + ) + await query._ensure_started() + + init_request = await _wait_for_request(transport, "initialize") + transport.push( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": init_request["request_id"], + "response": {}, + }, + } + ) + await _wait_for( + lambda: any(payload.get("type") == "user" for payload in transport.writes) + ) + + transport.push( + { + "type": "control_request", + "request_id": "incoming-1", + "request": { + "subtype": "can_use_tool", + "tool_name": "write_file", + "tool_use_id": "tool-1", + "input": {"path": "demo.txt", "content": "hello"}, + "permission_suggestions": [], + "blocked_path": None, + }, + } + ) + + await _wait_for(lambda: started.is_set()) + assert captured_cancel_events[0] is not query._cancel_event + + supported_commands_task = asyncio.create_task(query.supported_commands()) + supported_request = await _wait_for_request(transport, "supported_commands") + + transport.push( + { + "type": "control_cancel_request", + "request_id": "incoming-1", + } + ) + await _wait_for(lambda: cancelled.is_set()) + + transport.push( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": supported_request["request_id"], + "response": {"commands": ["supported_commands"]}, + }, + } + ) + + assert await supported_commands_task == {"commands": ["supported_commands"]} + assert all( + not ( + payload.get("type") == "control_response" + and payload.get("response", {}).get("request_id") == "incoming-1" + ) + for payload in transport.writes + ) + await query.close() + + +@pytest.mark.asyncio +async def test_permission_request_passes_blocked_path_to_callback() -> None: + transport = FakeTransport() + captured_context: dict[str, Any] | None = None + + async def can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + nonlocal captured_context + assert tool_name == "write_file" + assert tool_input["path"] == "demo.txt" + captured_context = context + return {"behavior": "deny", "message": "blocked"} + + query = Query( + transport=transport, # type: ignore[arg-type] + options=QueryOptions( + can_use_tool=can_use_tool, + timeout=TimeoutOptions( + can_use_tool=1.0, + control_request=0.2, + stream_close=0.05, + ), + ), + prompt="hello", + session_id=VALID_UUID, + ) + await query._ensure_started() + + init_request = await _wait_for_request(transport, "initialize") + transport.push( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": init_request["request_id"], + "response": {}, + }, + } + ) + await _wait_for( + lambda: any(payload.get("type") == "user" for payload in transport.writes) + ) + + transport.push( + { + "type": "control_request", + "request_id": "incoming-2", + "request": { + "subtype": "can_use_tool", + "tool_name": "write_file", + "tool_use_id": "tool-2", + "input": {"path": "demo.txt", "content": "hello"}, + "permission_suggestions": [], + "blocked_path": "/tmp/demo.txt", + }, + } + ) + + response = await _wait_for_control_response(transport, "incoming-2") + + assert captured_context is not None + assert isinstance(captured_context["cancel_event"], asyncio.Event) + assert captured_context["suggestions"] == [] + assert captured_context["blocked_path"] == "/tmp/demo.txt" + assert response["response"]["subtype"] == "success" + assert response["response"]["response"] == { + "behavior": "deny", + "message": "blocked", + } + await query.close() + + +@pytest.mark.asyncio +async def test_permission_request_cancelled_callback_returns_deny() -> None: + transport = FakeTransport() + + async def can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + context: dict[str, Any], + ) -> dict[str, Any]: + assert tool_name == "write_file" + assert tool_input["path"] == "demo.txt" + assert isinstance(context["cancel_event"], asyncio.Event) + raise asyncio.CancelledError() + + query = Query( + transport=transport, # type: ignore[arg-type] + options=QueryOptions( + can_use_tool=can_use_tool, + timeout=TimeoutOptions( + can_use_tool=1.0, + control_request=0.2, + stream_close=0.05, + ), + ), + prompt="hello", + session_id=VALID_UUID, + ) + await query._ensure_started() + + init_request = await _wait_for_request(transport, "initialize") + transport.push( + { + "type": "control_response", + "response": { + "subtype": "success", + "request_id": init_request["request_id"], + "response": {}, + }, + } + ) + await _wait_for( + lambda: any(payload.get("type") == "user" for payload in transport.writes) + ) + + transport.push( + { + "type": "control_request", + "request_id": "incoming-3", + "request": { + "subtype": "can_use_tool", + "tool_name": "write_file", + "tool_use_id": "tool-3", + "input": {"path": "demo.txt", "content": "hello"}, + "permission_suggestions": [], + "blocked_path": None, + }, + } + ) + + response = await _wait_for_control_response(transport, "incoming-3") + + assert response["response"]["subtype"] == "success" + assert response["response"]["response"] == { + "behavior": "deny", + "message": "Permission check failed: callback cancelled", + } + await query.close() + + +@pytest.mark.asyncio +async def test_finish_with_error_closes_transport_and_fails_pending_requests() -> None: + transport = FakeTransport() + query = await _start_query(transport) + + supported_commands_task = asyncio.create_task(query.supported_commands()) + await _wait_for_request(transport, "supported_commands") + + await query._finish_with_error(RuntimeError("boom")) + + with pytest.raises(RuntimeError, match="boom"): + await supported_commands_task + + assert query.is_closed() is True + assert transport.closed is True + + +@pytest.mark.asyncio +async def test_ensure_started_raises_after_close() -> None: + transport = FakeTransport() + query = await _start_query(transport) + + await query.close() + + with pytest.raises(RuntimeError, match="Query is closed"): + await query.supported_commands() + + +@pytest.mark.asyncio +async def test_anext_after_exhaustion_raises_stop_async_iteration() -> None: + """After the async iterator is exhausted, subsequent __anext__ calls must + raise StopAsyncIteration immediately instead of blocking.""" + transport = FakeTransport() + query = await _start_query(transport) + + # Deliver one assistant message, then a result to end the turn. + transport.push( + { + "type": "assistant", + "session_id": VALID_UUID, + "message": { + "role": "assistant", + "content": [{"type": "text", "text": "hi"}], + }, + } + ) + transport.push( + { + "type": "result", + "session_id": VALID_UUID, + "result": "done", + "is_error": False, + "duration_ms": 10, + "duration_api_ms": 5, + "num_turns": 1, + } + ) + # Signal end of transport stream so the router finishes naturally. + await transport.close() + + # Consume all messages until exhaustion. + messages: list[Any] = [] + with pytest.raises(StopAsyncIteration): + while True: + messages.append(await query.__anext__()) + + assert len(messages) >= 1 + + # The iterator is now exhausted — a second call must raise immediately. + with pytest.raises(StopAsyncIteration): + await query.__anext__() + + +@pytest.mark.asyncio +async def test_initialize_failure_no_unhandled_task_exception( + recwarn: pytest.WarningsChecker, +) -> None: + """When _initialize fails, no 'Task exception was never retrieved' warning + should appear — _finish_with_error already surfaces the error.""" + transport = FakeTransport() + query = Query( + transport=transport, # type: ignore[arg-type] + options=QueryOptions( + timeout=TimeoutOptions( + can_use_tool=0.05, + control_request=0.05, + stream_close=0.05, + ) + ), + prompt="hello", + session_id=VALID_UUID, + ) + await query._ensure_started() + + # Let the initialize request time out — this triggers _finish_with_error + # inside _initialize. + init_request = await _wait_for_request(transport, "initialize") + assert init_request is not None # init was sent + + # Don't respond to initialize — let the control-request timeout fire. + # The error propagates through _message_queue. + with pytest.raises(ControlRequestTimeoutError): + await query.__anext__() + + await query.close() + + # Give the event loop a moment to report any unhandled task exceptions. + await asyncio.sleep(0.1) + + # No "Task exception was never retrieved" warnings should have appeared. + task_warnings = [w for w in recwarn.list if "never retrieved" in str(w.message)] + assert task_warnings == [] + + +@pytest.mark.asyncio +async def test_async_context_manager_closes_on_exit() -> None: + transport = FakeTransport() + query = Query( + transport=transport, # type: ignore[arg-type] + options=QueryOptions( + timeout=TimeoutOptions( + can_use_tool=0.05, + control_request=0.05, + stream_close=0.05, + ) + ), + prompt="hello", + session_id=VALID_UUID, + ) + + async with query as q: + assert q is query + assert not q.is_closed() + + assert query.is_closed() is True + + +def test_sync_next_after_exhaustion_raises_stop_iteration() -> None: + """After the sync iterator is exhausted, subsequent __next__ calls must + raise StopIteration immediately instead of blocking on queue.get().""" + from queue import Queue + + from qwen_code_sdk.sync_query import _STOP, SyncQuery + + # Build a minimal SyncQuery without spawning the real event-loop thread. + sq = object.__new__(SyncQuery) + sq._queue = Queue() + sq._exhausted = False + + # Put one message then the sentinel. + msg_payload = { + "type": "assistant", + "message": {"role": "assistant", "content": []}, + } + sq._queue.put(msg_payload) + sq._queue.put(_STOP) + + # First call returns the message. + msg = next(sq) + assert msg["type"] == "assistant" + + # Second call should exhaust. + with pytest.raises(StopIteration): + next(sq) + + # Third call must raise immediately, not block. + with pytest.raises(StopIteration): + next(sq) diff --git a/packages/sdk-python/tests/unit/test_transport.py b/packages/sdk-python/tests/unit/test_transport.py new file mode 100644 index 000000000..340d1c1e6 --- /dev/null +++ b/packages/sdk-python/tests/unit/test_transport.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import asyncio +import subprocess +import sys +from pathlib import Path +from typing import Any + +import pytest +from qwen_code_sdk.transport import build_cli_arguments, prepare_spawn_info +from qwen_code_sdk.types import QueryOptions, TimeoutOptions + +VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" + + +class DummyProcess: + def __init__(self) -> None: + self.stdin = None + self.stdout = None + self.stderr = None + self.returncode = 0 + + +def test_build_cli_arguments_maps_supported_options() -> None: + args = build_cli_arguments( + QueryOptions( + model="qwen3-coder", + system_prompt="system prompt", + append_system_prompt="append prompt", + permission_mode="auto-edit", + max_session_turns=7, + core_tools=["Read", "Edit"], + exclude_tools=["Bash(rm *)"], + allowed_tools=["Bash(git status)"], + auth_type="openai", + include_partial_messages=True, + session_id=VALID_UUID, + ) + ) + + assert args == [ + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--channel=SDK", + "--model", + "qwen3-coder", + "--system-prompt", + "system prompt", + "--append-system-prompt", + "append prompt", + "--approval-mode", + "auto-edit", + "--max-session-turns", + "7", + "--core-tools", + "Read,Edit", + "--exclude-tools", + "Bash(rm *)", + "--allowed-tools", + "Bash(git status)", + "--auth-type", + "openai", + "--include-partial-messages", + "--session-id", + VALID_UUID, + ] + + +def test_cli_argument_precedence_prefers_resume_then_continue_then_session_id() -> None: + args = build_cli_arguments( + QueryOptions( + resume=VALID_UUID, + continue_session=True, + session_id="223e4567-e89b-12d3-a456-426614174000", + ) + ) + + assert "--resume" in args + assert "--continue" not in args + assert "--session-id" not in args + + +def test_prepare_spawn_info_uses_runtime_for_python_scripts(tmp_path: Path) -> None: + script_path = tmp_path / "fake-qwen.py" + script_path.write_text("print('ok')\n", encoding="utf-8") + + spawn_info = prepare_spawn_info(str(script_path)) + + assert spawn_info.command == sys.executable + assert spawn_info.args == [str(script_path.resolve())] + + +def test_prepare_spawn_info_uses_node_for_javascript_files(tmp_path: Path) -> None: + script_path = tmp_path / "fake-qwen.js" + script_path.write_text("console.log('ok');\n", encoding="utf-8") + + spawn_info = prepare_spawn_info(str(script_path)) + + assert spawn_info.command == "node" + assert spawn_info.args == [str(script_path.resolve())] + + +def test_prepare_spawn_info_keeps_plain_command_names() -> None: + spawn_info = prepare_spawn_info("qwen-custom") + + assert spawn_info.command == "qwen-custom" + assert spawn_info.args == [] + + +@pytest.mark.asyncio +async def test_transport_discards_stderr_when_debug_is_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: dict[str, Any] = {} + + async def fake_create_subprocess_exec(*args: Any, **kwargs: Any) -> DummyProcess: + captured["args"] = args + captured["kwargs"] = kwargs + return DummyProcess() + + monkeypatch.setattr( + asyncio, + "create_subprocess_exec", + fake_create_subprocess_exec, + ) + + transport_module = __import__( + "qwen_code_sdk.transport", + fromlist=["ProcessTransport"], + ) + transport = transport_module.ProcessTransport( + QueryOptions(timeout=TimeoutOptions()) + ) + + await transport.start() + + assert captured["kwargs"]["stderr"] is subprocess.DEVNULL + + +def test_prepare_spawn_info_defaults_to_qwen_when_none() -> None: + spawn_info = prepare_spawn_info(None) + + assert spawn_info.command == "qwen" + assert spawn_info.args == [] + + +def test_prepare_spawn_info_uses_node_for_mjs_files(tmp_path: Path) -> None: + script_path = tmp_path / "cli.mjs" + script_path.write_text("export default {};\n", encoding="utf-8") + + spawn_info = prepare_spawn_info(str(script_path)) + + assert spawn_info.command == "node" + assert spawn_info.args == [str(script_path.resolve())] + + +def test_prepare_spawn_info_uses_node_for_cjs_files(tmp_path: Path) -> None: + script_path = tmp_path / "cli.cjs" + script_path.write_text("module.exports = {};\n", encoding="utf-8") + + spawn_info = prepare_spawn_info(str(script_path)) + + assert spawn_info.command == "node" + assert spawn_info.args == [str(script_path.resolve())] + + +@pytest.mark.asyncio +async def test_transport_start_raises_after_close( + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def fake_create_subprocess_exec(*args: Any, **kwargs: Any) -> DummyProcess: + return DummyProcess() + + monkeypatch.setattr( + asyncio, + "create_subprocess_exec", + fake_create_subprocess_exec, + ) + + transport_module = __import__( + "qwen_code_sdk.transport", + fromlist=["ProcessTransport"], + ) + transport = transport_module.ProcessTransport( + QueryOptions(timeout=TimeoutOptions()) + ) + transport._closed = True + + with pytest.raises(RuntimeError, match="Transport is closed"): + await transport.start() + + +@pytest.mark.asyncio +async def test_read_messages_skips_malformed_json_lines() -> None: + """Malformed JSON lines should be skipped, not crash the stream.""" + + class FakeStdout: + def __init__(self, lines: list[bytes]) -> None: + self._lines = iter(lines) + + async def readline(self) -> bytes: + return next(self._lines, b"") + + transport_module = __import__( + "qwen_code_sdk.transport", + fromlist=["ProcessTransport"], + ) + transport = transport_module.ProcessTransport( + QueryOptions(timeout=TimeoutOptions()) + ) + + class FakeProcess: + returncode = 0 + stdin = None + stderr = None + + def __init__(self) -> None: + self.stdout = FakeStdout( + [ + b"not valid json\n", + b'{"type":"system","subtype":"init","uuid":"u","session_id":"s"}\n', + b"also bad\n", + b"", + ] + ) + + async def wait(self) -> int: + return 0 + + transport._process = FakeProcess() + + messages: list[Any] = [] + async for msg in transport.read_messages(): + messages.append(msg) + + assert len(messages) == 1 + assert messages[0]["type"] == "system" + + +@pytest.mark.asyncio +async def test_stderr_callback_exceptions_do_not_fail_transport() -> None: + class FakeStdout: + async def readline(self) -> bytes: + return b"" + + class FakeStderr: + def __init__(self) -> None: + self._lines = iter([b"error message\n", b""]) + + async def readline(self) -> bytes: + return next(self._lines, b"") + + transport_module = __import__( + "qwen_code_sdk.transport", + fromlist=["ProcessTransport"], + ) + + callback_calls = 0 + + def stderr_callback(text: str) -> None: + nonlocal callback_calls + callback_calls += 1 + assert text == "error message" + raise RuntimeError("sink failed") + + transport = transport_module.ProcessTransport( + QueryOptions( + stderr=stderr_callback, + timeout=TimeoutOptions(), + ) + ) + + class FakeProcess: + returncode = 0 + stdin = None + + def __init__(self) -> None: + self.stdout = FakeStdout() + self.stderr = FakeStderr() + + async def wait(self) -> int: + return 0 + + transport._process = FakeProcess() + transport._stderr_task = asyncio.create_task(transport._forward_stderr()) + + await transport.wait_for_exit() + + assert callback_calls == 1 diff --git a/packages/sdk-python/tests/unit/test_validation.py b/packages/sdk-python/tests/unit/test_validation.py new file mode 100644 index 000000000..dd85ef2e9 --- /dev/null +++ b/packages/sdk-python/tests/unit/test_validation.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing import Any, cast + +import pytest +from qwen_code_sdk.errors import ValidationError +from qwen_code_sdk.types import QueryOptions, TimeoutOptions +from qwen_code_sdk.validation import validate_query_options + +VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" + + +def test_rejects_resume_with_continue_session() -> None: + with pytest.raises(ValidationError, match="resume together with continue_session"): + validate_query_options( + QueryOptions( + resume=VALID_UUID, + continue_session=True, + ) + ) + + +def test_rejects_session_id_with_resume() -> None: + with pytest.raises(ValidationError, match="Cannot use session_id with resume"): + validate_query_options( + QueryOptions( + session_id=VALID_UUID, + resume="223e4567-e89b-12d3-a456-426614174000", + ) + ) + + +def test_rejects_invalid_session_id() -> None: + with pytest.raises(ValidationError, match="Invalid session_id"): + validate_query_options(QueryOptions(session_id="not-a-uuid")) + + +def test_rejects_invalid_resume() -> None: + with pytest.raises(ValidationError, match="Invalid resume"): + validate_query_options(QueryOptions(resume="not-a-uuid")) + + +def test_rejects_invalid_permission_mode() -> None: + with pytest.raises(ValidationError, match="Invalid permission_mode"): + validate_query_options( + QueryOptions.from_mapping({"permission_mode": "unsafe-mode"}) + ) + + +def test_rejects_invalid_auth_type() -> None: + with pytest.raises(ValidationError, match="Invalid auth_type"): + validate_query_options(QueryOptions.from_mapping({"auth_type": "custom"})) + + +def test_from_mapping_rejects_non_callable_can_use_tool() -> None: + with pytest.raises(TypeError, match="can_use_tool must be callable"): + QueryOptions.from_mapping({"can_use_tool": "bad"}) + + +def test_from_mapping_rejects_non_callable_stderr() -> None: + with pytest.raises(TypeError, match="stderr must be callable"): + QueryOptions.from_mapping({"stderr": "bad"}) + + +def test_validation_rejects_non_callable_can_use_tool() -> None: + with pytest.raises(ValidationError, match="can_use_tool must be callable"): + validate_query_options(QueryOptions(can_use_tool=cast(Any, "bad"))) + + +def test_validation_rejects_non_callable_stderr() -> None: + with pytest.raises(ValidationError, match="stderr must be callable"): + validate_query_options(QueryOptions(stderr=cast(Any, "bad"))) + + +def test_from_mapping_rejects_sync_can_use_tool() -> None: + def can_use_tool( # type: ignore[no-untyped-def] + tool_name, tool_input, context + ): + return {"behavior": "deny", "message": "bad"} + + with pytest.raises(TypeError, match="can_use_tool must be an async callable"): + QueryOptions.from_mapping({"can_use_tool": can_use_tool}) + + +def test_validation_rejects_sync_can_use_tool() -> None: + def can_use_tool( # type: ignore[no-untyped-def] + tool_name, tool_input, context + ): + return {"behavior": "deny", "message": "bad"} + + with pytest.raises(ValidationError, match="can_use_tool must be an async callable"): + validate_query_options(QueryOptions(can_use_tool=cast(Any, can_use_tool))) + + +def test_from_mapping_rejects_can_use_tool_with_wrong_arity() -> None: + async def can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + ) -> dict[str, str]: + return {"behavior": "deny"} + + with pytest.raises( + TypeError, match="can_use_tool must accept exactly 3 positional arguments" + ): + QueryOptions.from_mapping({"can_use_tool": can_use_tool}) + + +def test_validation_rejects_can_use_tool_with_wrong_arity() -> None: + async def can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + ) -> dict[str, str]: + return {"behavior": "deny"} + + with pytest.raises( + ValidationError, + match="can_use_tool must accept exactly 3 positional arguments", + ): + validate_query_options(QueryOptions(can_use_tool=cast(Any, can_use_tool))) + + +def test_from_mapping_rejects_stderr_with_wrong_arity() -> None: + def stderr() -> None: + return None + + with pytest.raises( + TypeError, match="stderr must accept exactly 1 positional argument" + ): + QueryOptions.from_mapping({"stderr": stderr}) + + +def test_validation_rejects_stderr_with_wrong_arity() -> None: + def stderr() -> None: + return None + + with pytest.raises( + ValidationError, match="stderr must accept exactly 1 positional argument" + ): + validate_query_options(QueryOptions(stderr=cast(Any, stderr))) + + +def test_rejects_invalid_max_session_turns() -> None: + with pytest.raises(ValidationError, match="max_session_turns"): + validate_query_options(QueryOptions(max_session_turns=-2)) + + +def test_rejects_empty_qwen_executable_path() -> None: + with pytest.raises( + ValidationError, match="path_to_qwen_executable cannot be empty" + ): + validate_query_options(QueryOptions(path_to_qwen_executable=" ")) + + +def test_timeout_rejects_non_numeric_value() -> None: + with pytest.raises(TypeError, match=r"timeout\.can_use_tool must be a positive"): + TimeoutOptions.from_mapping({"can_use_tool": "fast"}) + + +def test_timeout_rejects_negative_value() -> None: + pattern = r"timeout\.control_request must be a positive" + with pytest.raises(ValueError, match=pattern): + TimeoutOptions.from_mapping({"control_request": -1}) + + +def test_timeout_rejects_boolean_value() -> None: + with pytest.raises(TypeError, match=r"timeout\.stream_close must be a positive"): + TimeoutOptions.from_mapping({"stream_close": True}) + + +def test_rejects_mcp_servers() -> None: + with pytest.raises(ValidationError, match="mcp_servers is not supported"): + validate_query_options( + QueryOptions(mcp_servers={"my-server": {"command": "node", "args": []}}) + )