mirror of
https://github.com/QwenLM/qwen-code.git
synced 2026-04-28 03:30:40 +00:00
* Codex worktree snapshot: startup-cleanup Co-authored-by: Codex * Add Python SDK real smoke test Adds a repository-only real E2E smoke script for the Python SDK, plus npm and developer documentation entry points. Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): address review findings — bugs, type safety, and test coverage - Fix prepare_spawn_info: JS files now use "node" instead of sys.executable - Fix protocol.py: correct total=False misuse on 7 TypedDicts (required fields were optional) - Fix query.py: add _closed guard in _ensure_started, suppress exceptions in close() - Fix sync_query.py: prevent close() deadlock, add context manager, add timeouts - Fix transport.py: handle malformed JSON lines, add _closed guard in start() - Fix validation.py: use uuid.RFC_4122 instead of magic UUID - Fix __init__.py: export TextBlock, widen query_sync signature - Remove dead code: ensure_not_aborted, write_json_line, _thread_error - Add 12 new tests (29 → 41): context managers, JSON skip, closed guards, spawn info, timeouts Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): address wenshao review — session_id, bool validation, debug stderr - Fix continue_session=True generating a wrong random session_id - Add _as_optional_bool helper for strict type validation on bool fields - Default debug stderr to sys.stderr when no custom callback is provided Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): address remaining wenshao review feedback Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * test(cli): harden settings dialog restart prompt test Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): review fixes — UUID compat, stderr fallback, sync cleanup - Remove UUID version restriction to support v6/v7/v8 (RFC 9562) - Always write to sys.stderr when stderr callback raises (was silent when debug=False) - Prevent duplicate _STOP sentinel in SyncQuery.close() via _stop_sent flag - Add ruff format --check to CI workflow - Fix smoke_real.py version guard: fail early before imports instead of NameError - Apply ruff format to existing files Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): remaining review fixes — exit_code attr, guard strictness, sync timeout - Add exit_code attribute to ProcessExitError for programmatic access - Strengthen is_control_response/is_control_cancel guards to require payload fields, preventing misrouting of malformed messages - Expose control_request_timeout property on Query so SyncQuery uses the configured timeout instead of a hardcoded 30s default - Use dataclasses.replace() instead of direct mutation on frozen-style QueryOptions in query() factory - Add ResourceWarning in SyncQuery.__del__ when not properly closed Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): add exit_code default and guard __del__ against partial GC - Give ProcessExitError.exit_code a default value (-1) so user code can construct the exception with just a message string - Wrap SyncQuery.__del__ in try/except AttributeError to prevent crashes when the object is partially garbage-collected Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): review fixes — resource leak, type safety, CI matrix, docs - Fix SyncQuery.__del__ to call close() on GC instead of only warning - Replace hasattr duck-type check with isinstance(prompt, AsyncIterable) - Type-validate permission_mode/auth_type in QueryOptions.from_mapping - Use TypeGuard return types on all is_sdk_*/is_control_* predicates - Add 5s margin to sync wrapper timeouts to prevent error type masking - Expand CI matrix to test Python 3.10, 3.11, 3.12 - Change ProcessExitError.exit_code default from -1 to None - Add stderr to docs QueryOptions listing - Update README sync example to use context manager pattern Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): preserve iterator exhaustion state and suppress detached task warning - Add _exhausted flag to Query.__anext__ and SyncQuery.__next__ so repeated iteration after end-of-stream raises Stop(Async)Iteration instead of blocking forever. - Remove re-raise in _initialize() to prevent asyncio "Task exception was never retrieved" warning on detached tasks; the error is already surfaced via _finish_with_error(). Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): reject mcp_servers at validation time and add iterator/init tests - Reject mcp_servers in validate_query_options() with a clear error instead of advertising MCP support to the CLI and then failing at runtime when mcp_message arrives. - Remove dead mcp_servers branch from _initialize(). - Add tests for async/sync iterator exhaustion, detached init task warning suppression, and mcp_servers validation. Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * fix(sdk-python): fix ruff lint errors in new tests - Use ControlRequestTimeoutError instead of bare Exception (B017) - Fix import sorting for stdlib vs third-party (I001) - Break long line to stay within 88-char limit (E501) Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> * style(sdk-python): apply ruff format to new tests Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com> --------- Co-authored-by: jinye.djy <jinye.djy@alibaba-inc.com> Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
This commit is contained in:
parent
202be6ec7d
commit
e384338145
25 changed files with 4676 additions and 14 deletions
59
.github/workflows/sdk-python.yml
vendored
Normal file
59
.github/workflows/sdk-python.yml
vendored
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
name: 'SDK Python'
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'release/**'
|
||||
paths:
|
||||
- 'packages/sdk-python/**'
|
||||
- 'docs/developers/sdk-python.md'
|
||||
- 'docs/developers/_meta.ts'
|
||||
- 'README.md'
|
||||
- 'package.json'
|
||||
- '.github/workflows/sdk-python.yml'
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'release/**'
|
||||
paths:
|
||||
- 'packages/sdk-python/**'
|
||||
- 'docs/developers/sdk-python.md'
|
||||
- 'docs/developers/_meta.ts'
|
||||
- 'README.md'
|
||||
- 'package.json'
|
||||
- '.github/workflows/sdk-python.yml'
|
||||
|
||||
jobs:
|
||||
sdk-python:
|
||||
name: 'SDK Python (${{ matrix.python-version }})'
|
||||
runs-on: 'ubuntu-latest'
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ['3.10', '3.11', '3.12']
|
||||
steps:
|
||||
- name: 'Checkout'
|
||||
uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' # ratchet:actions/checkout@v5
|
||||
|
||||
- name: 'Set up Python'
|
||||
uses: 'actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065' # ratchet:actions/setup-python@v5
|
||||
with:
|
||||
python-version: '${{ matrix.python-version }}'
|
||||
|
||||
- name: 'Install SDK test dependencies'
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -e 'packages/sdk-python[dev]'
|
||||
|
||||
- name: 'Run Ruff'
|
||||
run: 'python -m ruff check --config packages/sdk-python/pyproject.toml packages/sdk-python'
|
||||
|
||||
- name: 'Run Ruff Format'
|
||||
run: 'python -m ruff format --check --config packages/sdk-python/pyproject.toml packages/sdk-python'
|
||||
|
||||
- name: 'Run Mypy'
|
||||
run: 'python -m mypy --config-file packages/sdk-python/pyproject.toml packages/sdk-python/src'
|
||||
|
||||
- name: 'Run Pytest'
|
||||
run: 'python -m pytest -c packages/sdk-python/pyproject.toml packages/sdk-python/tests -q'
|
||||
35
README.md
35
README.md
|
|
@ -424,7 +424,7 @@ As an open-source terminal agent, you can use Qwen Code in four primary ways:
|
|||
1. Interactive mode (terminal UI)
|
||||
2. Headless mode (scripts, CI)
|
||||
3. IDE integration (VS Code, Zed)
|
||||
4. TypeScript SDK
|
||||
4. SDKs (TypeScript, Python, Java)
|
||||
|
||||
#### Interactive mode
|
||||
|
||||
|
|
@ -452,11 +452,38 @@ Use Qwen Code inside your editor (VS Code, Zed, and JetBrains IDEs):
|
|||
- [Use in Zed](https://qwenlm.github.io/qwen-code-docs/en/users/integration-zed/)
|
||||
- [Use in JetBrains IDEs](https://qwenlm.github.io/qwen-code-docs/en/users/integration-jetbrains/)
|
||||
|
||||
#### TypeScript SDK
|
||||
#### SDKs
|
||||
|
||||
Build on top of Qwen Code with the TypeScript SDK:
|
||||
Build on top of Qwen Code with the available SDKs:
|
||||
|
||||
- [Use the Qwen Code SDK](./packages/sdk-typescript/README.md)
|
||||
- TypeScript: [Use the Qwen Code SDK](./packages/sdk-typescript/README.md)
|
||||
- Python: [Use the Python SDK](./packages/sdk-python/README.md)
|
||||
- Java: [Use the Java SDK](./packages/sdk-java/qwencode/README.md)
|
||||
|
||||
Python SDK example:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
from qwen_code_sdk import is_sdk_result_message, query
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
result = query(
|
||||
"Summarize the repository layout.",
|
||||
{
|
||||
"cwd": "/path/to/project",
|
||||
"path_to_qwen_executable": "qwen",
|
||||
},
|
||||
)
|
||||
|
||||
async for message in result:
|
||||
if is_sdk_result_message(message):
|
||||
print(message["result"])
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## Commands & Shortcuts
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ export default {
|
|||
type: 'separator',
|
||||
},
|
||||
'sdk-typescript': 'Typescript SDK',
|
||||
'sdk-java': 'Java SDK(alpha)',
|
||||
'sdk-python': 'Python SDK (alpha)',
|
||||
'sdk-java': 'Java SDK (alpha)',
|
||||
'Dive Into Qwen Code': {
|
||||
title: 'Dive Into Qwen Code',
|
||||
type: 'separator',
|
||||
|
|
|
|||
168
docs/developers/sdk-python.md
Normal file
168
docs/developers/sdk-python.md
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
# Python SDK
|
||||
|
||||
## `qwen-code-sdk`
|
||||
|
||||
`qwen-code-sdk` is an experimental Python SDK for Qwen Code. v1 targets the
|
||||
existing `stream-json` CLI protocol and keeps the transport surface small and
|
||||
testable.
|
||||
|
||||
## Scope
|
||||
|
||||
- Package name: `qwen-code-sdk`
|
||||
- Import path: `qwen_code_sdk`
|
||||
- Runtime requirement: Python `>=3.10`
|
||||
- CLI dependency: external `qwen` executable is required in v1
|
||||
- Transport scope: process transport only
|
||||
- Not included in v1: ACP transport, SDK-embedded MCP servers
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
pip install qwen-code-sdk
|
||||
```
|
||||
|
||||
If `qwen` is not on `PATH`, pass `path_to_qwen_executable` explicitly.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
from qwen_code_sdk import is_sdk_result_message, query
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
result = query(
|
||||
"Explain the repository structure.",
|
||||
{
|
||||
"cwd": "/path/to/project",
|
||||
"path_to_qwen_executable": "qwen",
|
||||
},
|
||||
)
|
||||
|
||||
async for message in result:
|
||||
if is_sdk_result_message(message):
|
||||
print(message["result"])
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## API Surface
|
||||
|
||||
### Top-level entry points
|
||||
|
||||
- `query(prompt, options=None) -> Query`
|
||||
- `query_sync(prompt, options=None) -> SyncQuery`
|
||||
|
||||
`prompt` supports either:
|
||||
|
||||
- `str` for single-turn requests
|
||||
- `AsyncIterable[SDKUserMessage]` for multi-turn streams
|
||||
|
||||
### `Query`
|
||||
|
||||
- Async iterable over SDK messages
|
||||
- `close()`
|
||||
- `interrupt()`
|
||||
- `set_model(model)`
|
||||
- `set_permission_mode(mode)`
|
||||
- `supported_commands()`
|
||||
- `mcp_server_status()`
|
||||
- `get_session_id()`
|
||||
- `is_closed()`
|
||||
|
||||
### `QueryOptions`
|
||||
|
||||
Supported options in v1:
|
||||
|
||||
- `cwd`
|
||||
- `model`
|
||||
- `path_to_qwen_executable`
|
||||
- `permission_mode`
|
||||
- `can_use_tool`
|
||||
- `env`
|
||||
- `system_prompt`
|
||||
- `append_system_prompt`
|
||||
- `debug`
|
||||
- `max_session_turns`
|
||||
- `core_tools`
|
||||
- `exclude_tools`
|
||||
- `allowed_tools`
|
||||
- `auth_type`
|
||||
- `include_partial_messages`
|
||||
- `resume`
|
||||
- `continue_session`
|
||||
- `session_id`
|
||||
- `timeout`
|
||||
- `mcp_servers`
|
||||
- `stderr`
|
||||
|
||||
Session argument priority is fixed as:
|
||||
|
||||
1. `resume`
|
||||
2. `continue_session`
|
||||
3. `session_id`
|
||||
|
||||
## Permission Handling
|
||||
|
||||
When the CLI emits a `can_use_tool` control request, the SDK routes it through
|
||||
`can_use_tool(tool_name, tool_input, context)`.
|
||||
|
||||
- Default behavior: deny
|
||||
- Default timeout: 60 seconds
|
||||
- Timeout fallback: deny
|
||||
- Callback exceptions: converted to deny with an error message
|
||||
- Callback context: `cancel_event`, `suggestions`, and `blocked_path`
|
||||
- Callback contract: `can_use_tool` must be async with 3 positional arguments;
|
||||
`stderr` must accept 1 positional string argument
|
||||
|
||||
## Error Model
|
||||
|
||||
- `ValidationError`: invalid options, invalid UUIDs, unsupported combinations
|
||||
- `ControlRequestTimeoutError`: initialize, interrupt, or other control request
|
||||
timed out
|
||||
- `ProcessExitError`: CLI exited non-zero
|
||||
- `AbortError`: control request or session was cancelled
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If the SDK cannot start the CLI:
|
||||
|
||||
- Verify `qwen --version` works in the target environment
|
||||
- Pass `path_to_qwen_executable` if your shell uses `nvm`, `pyenv`, or other
|
||||
non-standard PATH setup
|
||||
- Use `debug=True` or `stderr=print` to surface CLI stderr while debugging
|
||||
|
||||
If session control calls time out:
|
||||
|
||||
- Check that the target `qwen` version supports `--input-format stream-json`
|
||||
- Increase `timeout.control_request`
|
||||
- Verify that no wrapper script is swallowing stdout/stderr
|
||||
|
||||
## Repository Integration
|
||||
|
||||
Repository-level helper commands:
|
||||
|
||||
- `npm run test:sdk:python`
|
||||
- `npm run lint:sdk:python`
|
||||
- `npm run typecheck:sdk:python`
|
||||
- `npm run smoke:sdk:python -- --qwen qwen`
|
||||
|
||||
## Real E2E Smoke
|
||||
|
||||
For a real runtime check (actual `qwen` process + real model call), run from
|
||||
the repository root. The npm helper uses `python3`, so ensure it resolves to a
|
||||
Python `>=3.10` interpreter:
|
||||
|
||||
```bash
|
||||
npm run smoke:sdk:python -- --qwen qwen
|
||||
```
|
||||
|
||||
This script runs:
|
||||
|
||||
- async single-turn query
|
||||
- async control flow (`supported_commands`, permission mode updates)
|
||||
- sync `query_sync` query
|
||||
|
||||
It prints JSON and returns non-zero on failure.
|
||||
|
|
@ -43,6 +43,7 @@
|
|||
"test:integration:sandbox:podman": "cross-env QWEN_SANDBOX=podman vitest run --root ./integration-tests",
|
||||
"test:integration:sdk:sandbox:none": "cross-env QWEN_SANDBOX=false vitest run --root ./integration-tests --poolOptions.threads.maxThreads 2 sdk-typescript",
|
||||
"test:integration:sdk:sandbox:docker": "cross-env QWEN_SANDBOX=docker npm run build:sandbox && QWEN_SANDBOX=docker vitest run --root ./integration-tests --poolOptions.threads.maxThreads 2 sdk-typescript",
|
||||
"test:sdk:python": "python3 -m pytest -c packages/sdk-python/pyproject.toml packages/sdk-python/tests -q",
|
||||
"test:integration:cli:sandbox:none": "cross-env QWEN_SANDBOX=false vitest run --root ./integration-tests cli",
|
||||
"test:integration:cli:sandbox:docker": "cross-env QWEN_SANDBOX=docker npm run build:sandbox && QWEN_SANDBOX=docker vitest run --root ./integration-tests cli",
|
||||
"test:integration:interactive:sandbox:none": "cross-env QWEN_SANDBOX=false vitest run --root ./integration-tests interactive",
|
||||
|
|
@ -53,9 +54,12 @@
|
|||
"lint": "eslint . --ext .ts,.tsx && eslint integration-tests",
|
||||
"lint:fix": "eslint . --fix && eslint integration-tests --fix",
|
||||
"lint:ci": "eslint . --ext .ts,.tsx --max-warnings 0 && eslint integration-tests --max-warnings 0",
|
||||
"lint:sdk:python": "python3 -m ruff check --config packages/sdk-python/pyproject.toml packages/sdk-python",
|
||||
"lint:all": "node scripts/lint.js",
|
||||
"format": "prettier --experimental-cli --write .",
|
||||
"typecheck": "npm run typecheck --workspaces --if-present",
|
||||
"typecheck:sdk:python": "python3 -m mypy --config-file packages/sdk-python/pyproject.toml packages/sdk-python/src",
|
||||
"smoke:sdk:python": "python3 packages/sdk-python/scripts/smoke_real.py",
|
||||
"check-i18n": "npm run check-i18n --workspace=packages/cli",
|
||||
"preflight": "npm run clean && npm ci && npm run format && npm run lint:ci && npm run build && npm run typecheck && npm run test:ci",
|
||||
"prepare": "husky && npm run build && npm run bundle",
|
||||
|
|
|
|||
|
|
@ -963,11 +963,25 @@ describe('SettingsDialog', () => {
|
|||
</KeypressProvider>,
|
||||
);
|
||||
|
||||
// Trigger a restart-required setting change: navigate to "Language: UI" (2nd item) and toggle it.
|
||||
stdin.write(TerminalKeys.DOWN_ARROW as string);
|
||||
await wait();
|
||||
stdin.write(TerminalKeys.ENTER as string);
|
||||
await wait();
|
||||
await waitFor(() => {
|
||||
expect(lastFrame()).toContain('Tool Approval Mode');
|
||||
});
|
||||
|
||||
const languageIndex = getDialogSettingKeys().indexOf('general.language');
|
||||
expect(languageIndex).toBeGreaterThanOrEqual(0);
|
||||
|
||||
const press = async (key: string) => {
|
||||
act(() => {
|
||||
stdin.write(key);
|
||||
});
|
||||
await wait();
|
||||
};
|
||||
|
||||
// Trigger a restart-required setting change by toggling the UI language setting.
|
||||
for (let i = 0; i < languageIndex; i++) {
|
||||
await press(TerminalKeys.DOWN_ARROW as string);
|
||||
}
|
||||
await press(TerminalKeys.ENTER as string);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(lastFrame()).toContain(
|
||||
|
|
@ -976,10 +990,8 @@ describe('SettingsDialog', () => {
|
|||
});
|
||||
|
||||
// Switch scopes; restart prompt should remain visible.
|
||||
stdin.write(TerminalKeys.TAB as string);
|
||||
await wait();
|
||||
stdin.write('2');
|
||||
await wait();
|
||||
await press(TerminalKeys.TAB as string);
|
||||
await press('2');
|
||||
|
||||
await waitFor(() => {
|
||||
expect(lastFrame()).toContain(
|
||||
|
|
|
|||
117
packages/sdk-python/README.md
Normal file
117
packages/sdk-python/README.md
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
# qwen-code-sdk
|
||||
|
||||
Experimental Python SDK for programmatic access to Qwen Code through the
|
||||
`stream-json` protocol.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install qwen-code-sdk
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python `>=3.10`
|
||||
- External `qwen` CLI installed and available in `PATH`
|
||||
|
||||
You can also point the SDK at an explicit CLI binary or script with
|
||||
`path_to_qwen_executable`.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
from qwen_code_sdk import is_sdk_result_message, query
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
result = query(
|
||||
"List the top-level packages in this repository.",
|
||||
{
|
||||
"cwd": "/path/to/project",
|
||||
"path_to_qwen_executable": "qwen",
|
||||
},
|
||||
)
|
||||
|
||||
async for message in result:
|
||||
if is_sdk_result_message(message):
|
||||
print(message["result"])
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## Sync API
|
||||
|
||||
```python
|
||||
from qwen_code_sdk import query_sync
|
||||
|
||||
|
||||
with query_sync(
|
||||
"Say hello",
|
||||
{
|
||||
"path_to_qwen_executable": "qwen",
|
||||
},
|
||||
) as result:
|
||||
for message in result:
|
||||
print(message)
|
||||
```
|
||||
|
||||
## Main APIs
|
||||
|
||||
- `query(prompt, options=None) -> Query`
|
||||
- `query_sync(prompt, options=None) -> SyncQuery`
|
||||
- `Query.close()`, `interrupt()`, `set_model()`, `set_permission_mode()`
|
||||
- `Query.supported_commands()`, `mcp_server_status()`, `get_session_id()`
|
||||
|
||||
`prompt` accepts either a single `str` or an `AsyncIterable[SDKUserMessage]`
|
||||
for multi-turn sessions.
|
||||
|
||||
## Permission Callback
|
||||
|
||||
```python
|
||||
from qwen_code_sdk import query
|
||||
|
||||
|
||||
async def can_use_tool(tool_name, tool_input, context):
|
||||
if tool_name == "write_file":
|
||||
return {"behavior": "deny", "message": "Writes disabled in this app"}
|
||||
return {"behavior": "allow", "updatedInput": tool_input}
|
||||
|
||||
|
||||
result = query(
|
||||
"Create hello.txt",
|
||||
{
|
||||
"path_to_qwen_executable": "qwen",
|
||||
"can_use_tool": can_use_tool,
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
The callback defaults to deny. If it does not return within 60 seconds, the SDK
|
||||
auto-denies the tool request.
|
||||
|
||||
The `context` argument includes `cancel_event`, `suggestions`, and
|
||||
`blocked_path` when the CLI provides a path-specific permission target.
|
||||
`can_use_tool` must be an `async def` callback accepting
|
||||
`(tool_name, tool_input, context)`. `stderr` must accept a single `str`.
|
||||
|
||||
## Errors
|
||||
|
||||
- `ValidationError`: invalid query options or malformed session identifiers
|
||||
- `ControlRequestTimeoutError`: CLI control operation exceeded timeout
|
||||
- `ProcessExitError`: `qwen` exited with a non-zero code
|
||||
- `AbortError`: query or control request was cancelled
|
||||
|
||||
## Current Scope
|
||||
|
||||
`0.1.x` is intentionally narrow:
|
||||
|
||||
- Uses external `qwen` CLI via process transport
|
||||
- Targets `stream-json` parity with the TypeScript SDK core flow
|
||||
- Does not yet implement ACP transport
|
||||
- Does not yet embed MCP servers inside the SDK process
|
||||
|
||||
See [developer documentation](../../docs/developers/sdk-python.md) for more
|
||||
detail.
|
||||
61
packages/sdk-python/pyproject.toml
Normal file
61
packages/sdk-python/pyproject.toml
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
[build-system]
|
||||
requires = ["hatchling>=1.27.0"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "qwen-code-sdk"
|
||||
version = "0.1.0"
|
||||
description = "Python SDK for programmatic access to qwen-code CLI"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = { text = "Apache-2.0" }
|
||||
authors = [{ name = "Qwen Team" }]
|
||||
keywords = ["qwen", "qwen-code", "sdk", "python", "ai", "code-assistant"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
]
|
||||
dependencies = ["typing-extensions>=4.12.0"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://qwenlm.github.io/qwen-code-docs/"
|
||||
Repository = "https://github.com/QwenLM/qwen-code"
|
||||
Issues = "https://github.com/QwenLM/qwen-code/issues"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"mypy>=1.11.0",
|
||||
"pytest>=8.3.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"ruff>=0.8.0",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/qwen_code_sdk"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["src"]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "B", "UP", "ASYNC", "RUF"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
strict = true
|
||||
warn_unused_configs = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
show_error_codes = true
|
||||
mypy_path = "src"
|
||||
388
packages/sdk-python/scripts/smoke_real.py
Normal file
388
packages/sdk-python/scripts/smoke_real.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Run real end-to-end smoke tests against qwen CLI using qwen_code_sdk.
|
||||
|
||||
This script is intentionally lightweight and avoids any test doubles.
|
||||
It is useful for manual verification after changing SDK runtime behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10): # noqa: UP036
|
||||
import json
|
||||
|
||||
version = ".".join(str(part) for part in sys.version_info[:3])
|
||||
payload = {
|
||||
"ok": False,
|
||||
"stage": "startup",
|
||||
"error": f"Python >=3.10 is required, current version is {version}",
|
||||
"error_type": "RuntimeError",
|
||||
}
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
raise SystemExit(2)
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
import threading
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from typing import Any, TypeVar
|
||||
|
||||
SDK_ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC_ROOT = SDK_ROOT / "src"
|
||||
if str(SRC_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(SRC_ROOT))
|
||||
|
||||
from qwen_code_sdk import ( # noqa: E402
|
||||
SDKUserMessage,
|
||||
SyncQuery,
|
||||
is_sdk_assistant_message,
|
||||
is_sdk_result_message,
|
||||
is_sdk_system_message,
|
||||
query,
|
||||
query_sync,
|
||||
)
|
||||
from qwen_code_sdk.transport import prepare_spawn_info # noqa: E402
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AsyncSingleResult:
|
||||
ok: bool
|
||||
assistant_text: str | None
|
||||
result_text: str | None
|
||||
session_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AsyncControlResult:
|
||||
ok: bool
|
||||
supported_commands_type: str
|
||||
saw_system_message: bool
|
||||
saw_result_message: bool
|
||||
session_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncResult:
|
||||
ok: bool
|
||||
saw_result_message: bool
|
||||
result_text: str | None
|
||||
session_id: str
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run real qwen_code_sdk smoke tests using qwen CLI",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--qwen",
|
||||
default="qwen",
|
||||
help="Path or command for qwen executable (default: qwen)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cwd",
|
||||
default=str(Path.cwd()),
|
||||
help="Working directory passed to SDK query options",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=None,
|
||||
help="Optional model name. If set, script will call set_model(model).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout-seconds",
|
||||
type=float,
|
||||
default=90.0,
|
||||
help="Timeout used for control/callback/stream-close options",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json-only",
|
||||
action="store_true",
|
||||
help="Print only JSON result (no progress logs)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def check_qwen_cli_available(qwen_cmd: str, timeout_seconds: float) -> str:
|
||||
spawn_info = prepare_spawn_info(qwen_cmd)
|
||||
completed = subprocess.run(
|
||||
[spawn_info.command, *spawn_info.args, "--version"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
return completed.stdout.strip()
|
||||
|
||||
|
||||
def build_options(args: argparse.Namespace) -> dict[str, Any]:
|
||||
return {
|
||||
"cwd": args.cwd,
|
||||
"path_to_qwen_executable": args.qwen,
|
||||
"permission_mode": "yolo",
|
||||
"max_session_turns": 1,
|
||||
"timeout": {
|
||||
"control_request": args.timeout_seconds,
|
||||
"can_use_tool": args.timeout_seconds,
|
||||
"stream_close": args.timeout_seconds,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def extract_assistant_text(message: dict[str, Any]) -> str:
|
||||
content = message["message"].get("content", [])
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(str(block.get("text", "")))
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
async def run_async_single(args: argparse.Namespace) -> AsyncSingleResult:
|
||||
token = "SDK_REAL_ASYNC_OK"
|
||||
options = build_options(args)
|
||||
q = query(
|
||||
f"Reply exactly with {token}",
|
||||
options,
|
||||
)
|
||||
|
||||
assistant_text: str | None = None
|
||||
result_text: str | None = None
|
||||
try:
|
||||
async for message in q:
|
||||
if is_sdk_assistant_message(message):
|
||||
assistant_text = (assistant_text or "") + extract_assistant_text(
|
||||
message
|
||||
)
|
||||
if is_sdk_result_message(message):
|
||||
result_text = str(message.get("result", ""))
|
||||
finally:
|
||||
await q.close()
|
||||
|
||||
ok = token in (assistant_text or "") and token in (result_text or "")
|
||||
return AsyncSingleResult(
|
||||
ok=ok,
|
||||
assistant_text=assistant_text,
|
||||
result_text=result_text,
|
||||
session_id=q.get_session_id(),
|
||||
)
|
||||
|
||||
|
||||
async def run_async_controls(args: argparse.Namespace) -> AsyncControlResult:
|
||||
token = "SDK_REAL_CONTROL_OK"
|
||||
options = build_options(args)
|
||||
release_prompt = asyncio.Event()
|
||||
|
||||
async def prompts() -> AsyncIterator[SDKUserMessage]:
|
||||
await release_prompt.wait()
|
||||
yield {
|
||||
"type": "user",
|
||||
"session_id": "00000000-0000-4000-8000-000000000001",
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": f"Reply exactly with {token}",
|
||||
},
|
||||
"parent_tool_use_id": None,
|
||||
}
|
||||
|
||||
q = query(prompts(), options)
|
||||
|
||||
supported: dict[str, Any] | None = None
|
||||
saw_system_message = False
|
||||
saw_result_message = False
|
||||
try:
|
||||
supported = await q.supported_commands()
|
||||
await q.set_permission_mode("plan")
|
||||
await q.set_permission_mode("yolo")
|
||||
if args.model:
|
||||
await q.set_model(args.model)
|
||||
|
||||
release_prompt.set()
|
||||
async for message in q:
|
||||
if is_sdk_system_message(message):
|
||||
saw_system_message = True
|
||||
if is_sdk_result_message(message):
|
||||
saw_result_message = True
|
||||
break
|
||||
finally:
|
||||
await q.close()
|
||||
|
||||
ok = isinstance(supported, dict) and saw_result_message
|
||||
return AsyncControlResult(
|
||||
ok=ok,
|
||||
supported_commands_type=type(supported).__name__,
|
||||
saw_system_message=saw_system_message,
|
||||
saw_result_message=saw_result_message,
|
||||
session_id=q.get_session_id(),
|
||||
)
|
||||
|
||||
|
||||
async def run_stage(stage: str, coro: Awaitable[T], timeout_seconds: float) -> T:
|
||||
try:
|
||||
return await asyncio.wait_for(coro, timeout=timeout_seconds)
|
||||
except TimeoutError as exc:
|
||||
message = f"{stage} timed out after {timeout_seconds} seconds"
|
||||
raise TimeoutError(message) from exc
|
||||
|
||||
|
||||
def run_sync(
|
||||
args: argparse.Namespace,
|
||||
on_query: Callable[[SyncQuery], None] | None = None,
|
||||
) -> SyncResult:
|
||||
token = "SDK_REAL_SYNC_OK"
|
||||
options = build_options(args)
|
||||
q = query_sync(
|
||||
f"Reply exactly with {token}",
|
||||
options,
|
||||
)
|
||||
if on_query is not None:
|
||||
on_query(q)
|
||||
|
||||
saw_result_message = False
|
||||
result_text: str | None = None
|
||||
try:
|
||||
for message in q:
|
||||
if is_sdk_result_message(message):
|
||||
saw_result_message = True
|
||||
result_text = str(message.get("result", ""))
|
||||
break
|
||||
finally:
|
||||
q.close()
|
||||
|
||||
ok = saw_result_message and token in (result_text or "")
|
||||
return SyncResult(
|
||||
ok=ok,
|
||||
saw_result_message=saw_result_message,
|
||||
result_text=result_text,
|
||||
session_id=q.get_session_id(),
|
||||
)
|
||||
|
||||
|
||||
def run_sync_with_timeout(args: argparse.Namespace) -> SyncResult:
|
||||
result_queue: Queue[SyncResult | BaseException] = Queue(maxsize=1)
|
||||
query_holder: dict[str, SyncQuery] = {}
|
||||
|
||||
def remember_query(q: SyncQuery) -> None:
|
||||
query_holder["query"] = q
|
||||
|
||||
def worker() -> None:
|
||||
try:
|
||||
result_queue.put(run_sync(args, on_query=remember_query))
|
||||
except BaseException as exc:
|
||||
result_queue.put(exc)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=worker,
|
||||
name="qwen-sdk-real-smoke-sync",
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
item = result_queue.get(timeout=args.timeout_seconds)
|
||||
except Empty as exc:
|
||||
q = query_holder.get("query")
|
||||
if q is not None:
|
||||
q.close()
|
||||
raise TimeoutError(
|
||||
f"sync check timed out after {args.timeout_seconds} seconds"
|
||||
) from exc
|
||||
|
||||
thread.join(timeout=1.0)
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
|
||||
def build_failure_payload(
|
||||
*,
|
||||
stage: str,
|
||||
exc: BaseException,
|
||||
qwen_version: str | None = None,
|
||||
completed: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"ok": False,
|
||||
"stage": stage,
|
||||
"error": str(exc),
|
||||
"error_type": type(exc).__name__,
|
||||
}
|
||||
if qwen_version is not None:
|
||||
payload["qwen_version"] = qwen_version
|
||||
if completed:
|
||||
payload["completed"] = completed
|
||||
return payload
|
||||
|
||||
|
||||
async def main() -> int:
|
||||
args = parse_args()
|
||||
|
||||
try:
|
||||
qwen_version = check_qwen_cli_available(args.qwen, args.timeout_seconds)
|
||||
except (subprocess.CalledProcessError, OSError, subprocess.TimeoutExpired) as exc:
|
||||
payload = build_failure_payload(stage="preflight", exc=exc)
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
return 2
|
||||
|
||||
stage = "async single-turn check"
|
||||
completed: dict[str, Any] = {}
|
||||
try:
|
||||
if not args.json_only:
|
||||
print(f"[smoke] qwen version: {qwen_version}")
|
||||
print(f"[smoke] running {stage}...")
|
||||
async_single = await run_stage(
|
||||
stage,
|
||||
run_async_single(args),
|
||||
args.timeout_seconds,
|
||||
)
|
||||
completed["async_single"] = asdict(async_single)
|
||||
|
||||
stage = "async control check"
|
||||
if not args.json_only:
|
||||
print(f"[smoke] running {stage}...")
|
||||
async_controls = await run_stage(
|
||||
stage,
|
||||
run_async_controls(args),
|
||||
args.timeout_seconds,
|
||||
)
|
||||
completed["async_controls"] = asdict(async_controls)
|
||||
|
||||
stage = "sync check"
|
||||
if not args.json_only:
|
||||
print(f"[smoke] running {stage}...")
|
||||
sync_result = run_sync_with_timeout(args)
|
||||
completed["sync"] = asdict(sync_result)
|
||||
except Exception as exc:
|
||||
payload = build_failure_payload(
|
||||
stage=stage,
|
||||
exc=exc,
|
||||
qwen_version=qwen_version,
|
||||
completed=completed,
|
||||
)
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
return 1
|
||||
|
||||
all_ok = async_single.ok and async_controls.ok and sync_result.ok
|
||||
payload = {
|
||||
"ok": all_ok,
|
||||
"qwen_version": qwen_version,
|
||||
"async_single": asdict(async_single),
|
||||
"async_controls": asdict(async_controls),
|
||||
"sync": asdict(sync_result),
|
||||
}
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
return 0 if all_ok else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(asyncio.run(main()))
|
||||
108
packages/sdk-python/src/qwen_code_sdk/__init__.py
Normal file
108
packages/sdk-python/src/qwen_code_sdk/__init__.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
"""qwen_code_sdk package exports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable, Iterable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from .errors import (
|
||||
AbortError,
|
||||
ControlRequestTimeoutError,
|
||||
ProcessExitError,
|
||||
QwenSDKError,
|
||||
ValidationError,
|
||||
)
|
||||
from .protocol import (
|
||||
APIAssistantMessage,
|
||||
APIUserMessage,
|
||||
ContentBlock,
|
||||
SDKAssistantMessage,
|
||||
SDKMessage,
|
||||
SDKPartialAssistantMessage,
|
||||
SDKResultMessage,
|
||||
SDKSystemMessage,
|
||||
SDKUserMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
Usage,
|
||||
is_control_cancel,
|
||||
is_control_request,
|
||||
is_control_response,
|
||||
is_sdk_assistant_message,
|
||||
is_sdk_partial_assistant_message,
|
||||
is_sdk_result_message,
|
||||
is_sdk_system_message,
|
||||
is_sdk_user_message,
|
||||
)
|
||||
from .query import Query, query
|
||||
from .sync_query import SyncQuery
|
||||
from .types import (
|
||||
AuthType,
|
||||
CanUseTool,
|
||||
CanUseToolContext,
|
||||
PermissionAllowResult,
|
||||
PermissionDenyResult,
|
||||
PermissionMode,
|
||||
PermissionResult,
|
||||
PermissionSuggestion,
|
||||
QueryOptions,
|
||||
QueryOptionsDict,
|
||||
TimeoutOptions,
|
||||
TimeoutOptionsDict,
|
||||
)
|
||||
|
||||
|
||||
def query_sync(
|
||||
prompt: str | Iterable[SDKUserMessage] | AsyncIterable[SDKUserMessage],
|
||||
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None,
|
||||
) -> SyncQuery:
|
||||
return SyncQuery(prompt=prompt, options=options)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"APIAssistantMessage",
|
||||
"APIUserMessage",
|
||||
"AbortError",
|
||||
"AuthType",
|
||||
"CanUseTool",
|
||||
"CanUseToolContext",
|
||||
"ContentBlock",
|
||||
"ControlRequestTimeoutError",
|
||||
"PermissionAllowResult",
|
||||
"PermissionDenyResult",
|
||||
"PermissionMode",
|
||||
"PermissionResult",
|
||||
"PermissionSuggestion",
|
||||
"ProcessExitError",
|
||||
"Query",
|
||||
"QueryOptions",
|
||||
"QueryOptionsDict",
|
||||
"QwenSDKError",
|
||||
"SDKAssistantMessage",
|
||||
"SDKMessage",
|
||||
"SDKPartialAssistantMessage",
|
||||
"SDKResultMessage",
|
||||
"SDKSystemMessage",
|
||||
"SDKUserMessage",
|
||||
"SyncQuery",
|
||||
"TextBlock",
|
||||
"ThinkingBlock",
|
||||
"TimeoutOptions",
|
||||
"TimeoutOptionsDict",
|
||||
"ToolResultBlock",
|
||||
"ToolUseBlock",
|
||||
"Usage",
|
||||
"ValidationError",
|
||||
"is_control_cancel",
|
||||
"is_control_request",
|
||||
"is_control_response",
|
||||
"is_sdk_assistant_message",
|
||||
"is_sdk_partial_assistant_message",
|
||||
"is_sdk_result_message",
|
||||
"is_sdk_system_message",
|
||||
"is_sdk_user_message",
|
||||
"query",
|
||||
"query_sync",
|
||||
]
|
||||
27
packages/sdk-python/src/qwen_code_sdk/errors.py
Normal file
27
packages/sdk-python/src/qwen_code_sdk/errors.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Error types for qwen_code_sdk."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class QwenSDKError(Exception):
|
||||
"""Base error for all SDK failures."""
|
||||
|
||||
|
||||
class ValidationError(QwenSDKError):
|
||||
"""Raised when query options are invalid."""
|
||||
|
||||
|
||||
class AbortError(QwenSDKError):
|
||||
"""Raised when an operation is aborted by caller or transport."""
|
||||
|
||||
|
||||
class ProcessExitError(QwenSDKError):
|
||||
"""Raised when qwen CLI exits with non-zero status or signal."""
|
||||
|
||||
def __init__(self, message: str, exit_code: int | None = None) -> None:
|
||||
super().__init__(message)
|
||||
self.exit_code = exit_code
|
||||
|
||||
|
||||
class ControlRequestTimeoutError(QwenSDKError):
|
||||
"""Raised when a control request times out waiting for response."""
|
||||
14
packages/sdk-python/src/qwen_code_sdk/json_lines.py
Normal file
14
packages/sdk-python/src/qwen_code_sdk/json_lines.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""JSON lines utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def serialize_json_line(payload: Any) -> str:
|
||||
return json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + "\n"
|
||||
|
||||
|
||||
def parse_json_line(line: str) -> Any:
|
||||
return json.loads(line)
|
||||
357
packages/sdk-python/src/qwen_code_sdk/protocol.py
Normal file
357
packages/sdk-python/src/qwen_code_sdk/protocol.py
Normal file
|
|
@ -0,0 +1,357 @@
|
|||
"""Protocol message types and helpers for qwen stream-json."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, TypeAlias, TypeGuard
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .types import PermissionMode, PermissionSuggestion
|
||||
|
||||
|
||||
class Annotation(TypedDict):
|
||||
type: str
|
||||
value: str
|
||||
|
||||
|
||||
class Usage(TypedDict):
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_creation_input_tokens: NotRequired[int]
|
||||
cache_read_input_tokens: NotRequired[int]
|
||||
total_tokens: NotRequired[int]
|
||||
|
||||
|
||||
class ExtendedUsage(Usage, total=False):
|
||||
server_tool_use: dict[str, int]
|
||||
service_tier: str
|
||||
cache_creation: dict[str, int]
|
||||
|
||||
|
||||
class CLIPermissionDenial(TypedDict):
|
||||
tool_name: str
|
||||
tool_use_id: str
|
||||
tool_input: Any
|
||||
|
||||
|
||||
class TextBlock(TypedDict):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
annotations: NotRequired[list[Annotation]]
|
||||
|
||||
|
||||
class ThinkingBlock(TypedDict):
|
||||
type: Literal["thinking"]
|
||||
thinking: str
|
||||
signature: NotRequired[str]
|
||||
annotations: NotRequired[list[Annotation]]
|
||||
|
||||
|
||||
class ToolUseBlock(TypedDict):
|
||||
type: Literal["tool_use"]
|
||||
id: str
|
||||
name: str
|
||||
input: Any
|
||||
annotations: NotRequired[list[Annotation]]
|
||||
|
||||
|
||||
class ToolResultBlock(TypedDict):
|
||||
type: Literal["tool_result"]
|
||||
tool_use_id: str
|
||||
content: NotRequired[str | list[ContentBlock]]
|
||||
is_error: NotRequired[bool]
|
||||
annotations: NotRequired[list[Annotation]]
|
||||
|
||||
|
||||
ContentBlock: TypeAlias = TextBlock | ThinkingBlock | ToolUseBlock | ToolResultBlock
|
||||
|
||||
|
||||
class APIUserMessage(TypedDict):
|
||||
role: Literal["user"]
|
||||
content: str | list[ContentBlock]
|
||||
|
||||
|
||||
class APIAssistantMessage(TypedDict):
|
||||
role: Literal["assistant"]
|
||||
content: list[ContentBlock]
|
||||
id: NotRequired[str]
|
||||
type: NotRequired[Literal["message"]]
|
||||
model: NotRequired[str]
|
||||
stop_reason: NotRequired[str | None]
|
||||
usage: NotRequired[Usage]
|
||||
|
||||
|
||||
class SDKUserMessage(TypedDict):
|
||||
type: Literal["user"]
|
||||
session_id: str
|
||||
message: APIUserMessage
|
||||
parent_tool_use_id: str | None
|
||||
uuid: NotRequired[str]
|
||||
options: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class SDKAssistantMessage(TypedDict):
|
||||
type: Literal["assistant"]
|
||||
uuid: str
|
||||
session_id: str
|
||||
message: APIAssistantMessage
|
||||
parent_tool_use_id: str | None
|
||||
|
||||
|
||||
class MCPServerState(TypedDict):
|
||||
name: str
|
||||
status: str
|
||||
|
||||
|
||||
class SDKSystemMessage(TypedDict):
|
||||
type: Literal["system"]
|
||||
subtype: str
|
||||
uuid: str
|
||||
session_id: str
|
||||
data: NotRequired[Any]
|
||||
cwd: NotRequired[str]
|
||||
tools: NotRequired[list[str]]
|
||||
mcp_servers: NotRequired[list[MCPServerState]]
|
||||
model: NotRequired[str]
|
||||
permission_mode: NotRequired[str]
|
||||
slash_commands: NotRequired[list[str]]
|
||||
qwen_code_version: NotRequired[str]
|
||||
output_style: NotRequired[str]
|
||||
agents: NotRequired[list[str]]
|
||||
skills: NotRequired[list[str]]
|
||||
capabilities: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class SDKResultMessageSuccess(TypedDict):
|
||||
type: Literal["result"]
|
||||
subtype: Literal["success"]
|
||||
uuid: str
|
||||
session_id: str
|
||||
is_error: Literal[False]
|
||||
duration_ms: int
|
||||
duration_api_ms: int
|
||||
num_turns: int
|
||||
result: str
|
||||
usage: ExtendedUsage
|
||||
permission_denials: list[CLIPermissionDenial]
|
||||
|
||||
|
||||
class ResultErrorObject(TypedDict):
|
||||
message: str
|
||||
type: NotRequired[str]
|
||||
|
||||
|
||||
class SDKResultMessageError(TypedDict):
|
||||
type: Literal["result"]
|
||||
subtype: Literal["error_max_turns", "error_during_execution"]
|
||||
uuid: str
|
||||
session_id: str
|
||||
is_error: Literal[True]
|
||||
duration_ms: int
|
||||
duration_api_ms: int
|
||||
num_turns: int
|
||||
usage: ExtendedUsage
|
||||
permission_denials: list[CLIPermissionDenial]
|
||||
error: NotRequired[ResultErrorObject]
|
||||
|
||||
|
||||
SDKResultMessage: TypeAlias = SDKResultMessageSuccess | SDKResultMessageError
|
||||
|
||||
|
||||
class MessageStartStreamEvent(TypedDict):
|
||||
type: Literal["message_start"]
|
||||
message: dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockStartEvent(TypedDict):
|
||||
type: Literal["content_block_start"]
|
||||
index: int
|
||||
content_block: ContentBlock
|
||||
|
||||
|
||||
class ContentBlockDeltaEvent(TypedDict):
|
||||
type: Literal["content_block_delta"]
|
||||
index: int
|
||||
delta: dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockStopEvent(TypedDict):
|
||||
type: Literal["content_block_stop"]
|
||||
index: int
|
||||
|
||||
|
||||
class MessageStopStreamEvent(TypedDict):
|
||||
type: Literal["message_stop"]
|
||||
|
||||
|
||||
StreamEvent: TypeAlias = (
|
||||
MessageStartStreamEvent
|
||||
| ContentBlockStartEvent
|
||||
| ContentBlockDeltaEvent
|
||||
| ContentBlockStopEvent
|
||||
| MessageStopStreamEvent
|
||||
)
|
||||
|
||||
|
||||
class SDKPartialAssistantMessage(TypedDict):
|
||||
type: Literal["stream_event"]
|
||||
uuid: str
|
||||
session_id: str
|
||||
event: StreamEvent
|
||||
parent_tool_use_id: str | None
|
||||
|
||||
|
||||
class CLIControlInterruptRequest(TypedDict):
|
||||
subtype: Literal["interrupt"]
|
||||
|
||||
|
||||
class CLIControlPermissionRequest(TypedDict):
|
||||
subtype: Literal["can_use_tool"]
|
||||
tool_name: str
|
||||
tool_use_id: str
|
||||
input: Any
|
||||
permission_suggestions: list[PermissionSuggestion] | None
|
||||
blocked_path: str | None
|
||||
|
||||
|
||||
class CLIControlInitializeRequest(TypedDict):
|
||||
subtype: Literal["initialize"]
|
||||
hooks: NotRequired[Any]
|
||||
mcpServers: NotRequired[dict[str, dict[str, Any]]]
|
||||
|
||||
|
||||
class CLIControlSetPermissionModeRequest(TypedDict):
|
||||
subtype: Literal["set_permission_mode"]
|
||||
mode: PermissionMode
|
||||
|
||||
|
||||
class CLIControlSetModelRequest(TypedDict):
|
||||
subtype: Literal["set_model"]
|
||||
model: str
|
||||
|
||||
|
||||
class CLIControlMcpStatusRequest(TypedDict):
|
||||
subtype: Literal["mcp_server_status"]
|
||||
|
||||
|
||||
class CLIControlSupportedCommandsRequest(TypedDict):
|
||||
subtype: Literal["supported_commands"]
|
||||
|
||||
|
||||
ControlRequestPayload: TypeAlias = (
|
||||
CLIControlInterruptRequest
|
||||
| CLIControlPermissionRequest
|
||||
| CLIControlInitializeRequest
|
||||
| CLIControlSetPermissionModeRequest
|
||||
| CLIControlSetModelRequest
|
||||
| CLIControlMcpStatusRequest
|
||||
| CLIControlSupportedCommandsRequest
|
||||
| dict[str, Any]
|
||||
)
|
||||
|
||||
|
||||
class CLIControlRequest(TypedDict):
|
||||
type: Literal["control_request"]
|
||||
request_id: str
|
||||
request: ControlRequestPayload
|
||||
|
||||
|
||||
class ControlResponseSuccess(TypedDict):
|
||||
subtype: Literal["success"]
|
||||
request_id: str
|
||||
response: Any
|
||||
|
||||
|
||||
class ControlResponseError(TypedDict):
|
||||
subtype: Literal["error"]
|
||||
request_id: str
|
||||
error: str | dict[str, Any]
|
||||
|
||||
|
||||
class CLIControlResponse(TypedDict):
|
||||
type: Literal["control_response"]
|
||||
response: ControlResponseSuccess | ControlResponseError
|
||||
|
||||
|
||||
class ControlCancelRequest(TypedDict):
|
||||
type: Literal["control_cancel_request"]
|
||||
request_id: NotRequired[str]
|
||||
|
||||
|
||||
SDKMessage: TypeAlias = (
|
||||
SDKUserMessage
|
||||
| SDKAssistantMessage
|
||||
| SDKSystemMessage
|
||||
| SDKResultMessage
|
||||
| SDKPartialAssistantMessage
|
||||
)
|
||||
|
||||
|
||||
ControlMessage: TypeAlias = (
|
||||
CLIControlRequest | CLIControlResponse | ControlCancelRequest
|
||||
)
|
||||
|
||||
|
||||
def is_sdk_user_message(msg: Any) -> TypeGuard[SDKUserMessage]:
|
||||
return isinstance(msg, dict) and msg.get("type") == "user" and "message" in msg
|
||||
|
||||
|
||||
def is_sdk_assistant_message(msg: Any) -> TypeGuard[SDKAssistantMessage]:
|
||||
return (
|
||||
isinstance(msg, dict)
|
||||
and msg.get("type") == "assistant"
|
||||
and "session_id" in msg
|
||||
and "message" in msg
|
||||
)
|
||||
|
||||
|
||||
def is_sdk_system_message(msg: Any) -> TypeGuard[SDKSystemMessage]:
|
||||
return (
|
||||
isinstance(msg, dict)
|
||||
and msg.get("type") == "system"
|
||||
and "subtype" in msg
|
||||
and "session_id" in msg
|
||||
)
|
||||
|
||||
|
||||
def is_sdk_result_message(msg: Any) -> TypeGuard[SDKResultMessage]:
|
||||
return (
|
||||
isinstance(msg, dict)
|
||||
and msg.get("type") == "result"
|
||||
and "subtype" in msg
|
||||
and "session_id" in msg
|
||||
)
|
||||
|
||||
|
||||
def is_sdk_partial_assistant_message(msg: Any) -> TypeGuard[SDKPartialAssistantMessage]:
|
||||
return (
|
||||
isinstance(msg, dict)
|
||||
and msg.get("type") == "stream_event"
|
||||
and "session_id" in msg
|
||||
and "event" in msg
|
||||
)
|
||||
|
||||
|
||||
def is_control_request(msg: Any) -> TypeGuard[CLIControlRequest]:
|
||||
return (
|
||||
isinstance(msg, dict)
|
||||
and msg.get("type") == "control_request"
|
||||
and "request_id" in msg
|
||||
and "request" in msg
|
||||
)
|
||||
|
||||
|
||||
def is_control_response(msg: Any) -> TypeGuard[CLIControlResponse]:
|
||||
return (
|
||||
isinstance(msg, dict)
|
||||
and msg.get("type") == "control_response"
|
||||
and "response" in msg
|
||||
)
|
||||
|
||||
|
||||
def is_control_cancel(msg: Any) -> TypeGuard[ControlCancelRequest]:
|
||||
return (
|
||||
isinstance(msg, dict)
|
||||
and msg.get("type") == "control_cancel_request"
|
||||
and "request_id" in msg
|
||||
)
|
||||
0
packages/sdk-python/src/qwen_code_sdk/py.typed
Normal file
0
packages/sdk-python/src/qwen_code_sdk/py.typed
Normal file
607
packages/sdk-python/src/qwen_code_sdk/query.py
Normal file
607
packages/sdk-python/src/qwen_code_sdk/query.py
Normal file
|
|
@ -0,0 +1,607 @@
|
|||
"""Async Query implementation for qwen_code_sdk."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from collections.abc import AsyncIterable, Mapping, MutableMapping
|
||||
from dataclasses import dataclass, replace
|
||||
from types import TracebackType
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from .errors import AbortError, ControlRequestTimeoutError
|
||||
from .json_lines import serialize_json_line
|
||||
from .protocol import (
|
||||
CLIControlRequest,
|
||||
CLIControlResponse,
|
||||
SDKMessage,
|
||||
SDKUserMessage,
|
||||
is_control_cancel,
|
||||
is_control_request,
|
||||
is_control_response,
|
||||
is_sdk_assistant_message,
|
||||
is_sdk_partial_assistant_message,
|
||||
is_sdk_result_message,
|
||||
is_sdk_system_message,
|
||||
is_sdk_user_message,
|
||||
)
|
||||
from .transport import ProcessTransport
|
||||
from .types import (
|
||||
CanUseToolContext,
|
||||
PermissionDenyResult,
|
||||
QueryOptions,
|
||||
QueryOptionsDict,
|
||||
)
|
||||
from .validation import validate_query_options
|
||||
|
||||
_DONE = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PendingControlRequest:
|
||||
future: asyncio.Future[dict[str, Any] | None]
|
||||
cancel_event: asyncio.Event
|
||||
timeout_handle: asyncio.TimerHandle
|
||||
|
||||
|
||||
@dataclass
|
||||
class _IncomingControlRequest:
|
||||
task: asyncio.Task[None]
|
||||
cancel_event: asyncio.Event
|
||||
|
||||
|
||||
class Query:
|
||||
def __init__(
|
||||
self,
|
||||
transport: ProcessTransport,
|
||||
options: QueryOptions,
|
||||
prompt: str | AsyncIterable[SDKUserMessage],
|
||||
session_id: str,
|
||||
) -> None:
|
||||
self._transport = transport
|
||||
self._options = options
|
||||
self._prompt = prompt
|
||||
self._single_turn = isinstance(prompt, str)
|
||||
self._session_id = session_id
|
||||
self._session_id_locked = bool(options.resume or options.session_id)
|
||||
|
||||
self._message_queue: asyncio.Queue[SDKMessage | Exception | object] = (
|
||||
asyncio.Queue()
|
||||
)
|
||||
self._closed = False
|
||||
self._started = False
|
||||
self._start_lock = asyncio.Lock()
|
||||
self._cancel_event = asyncio.Event()
|
||||
|
||||
self._router_task: asyncio.Task[None] | None = None
|
||||
self._input_task: asyncio.Task[None] | None = None
|
||||
self._initialize_task: asyncio.Task[None] | None = None
|
||||
self._first_result_event = asyncio.Event()
|
||||
self._terminal_event_sent = False
|
||||
self._exhausted = False
|
||||
|
||||
self._pending_control_requests: dict[str, _PendingControlRequest] = {}
|
||||
self._incoming_control_requests: dict[str, _IncomingControlRequest] = {}
|
||||
|
||||
async def _ensure_started(self) -> None:
|
||||
if self._closed:
|
||||
raise RuntimeError("Query is closed")
|
||||
if self._started:
|
||||
return
|
||||
|
||||
async with self._start_lock:
|
||||
if self._closed:
|
||||
raise RuntimeError("Query is closed")
|
||||
if self._started:
|
||||
return
|
||||
await self._transport.start()
|
||||
self._router_task = asyncio.create_task(self._message_router())
|
||||
self._initialize_task = asyncio.create_task(self._initialize())
|
||||
|
||||
if self._single_turn:
|
||||
self._input_task = asyncio.create_task(self._send_single_turn_prompt())
|
||||
else:
|
||||
self._input_task = asyncio.create_task(
|
||||
self.stream_input(self._prompt) # type: ignore[arg-type]
|
||||
)
|
||||
self._started = True
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
try:
|
||||
payload: dict[str, Any] = {"hooks": None}
|
||||
await self._send_control_request("initialize", payload)
|
||||
except Exception as exc:
|
||||
await self._finish_with_error(exc)
|
||||
|
||||
async def _send_single_turn_prompt(self) -> None:
|
||||
try:
|
||||
assert isinstance(self._prompt, str)
|
||||
await self._wait_initialized()
|
||||
message: SDKUserMessage = {
|
||||
"type": "user",
|
||||
"session_id": self._session_id,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": self._prompt,
|
||||
},
|
||||
"parent_tool_use_id": None,
|
||||
}
|
||||
|
||||
await self._write_payload(message)
|
||||
except Exception as exc:
|
||||
await self._finish_with_error(exc)
|
||||
raise
|
||||
|
||||
async def _wait_initialized(self) -> None:
|
||||
if self._initialize_task is None:
|
||||
return
|
||||
await self._initialize_task
|
||||
|
||||
async def _message_router(self) -> None:
|
||||
try:
|
||||
async for message in self._transport.read_messages():
|
||||
await self._route_message(message)
|
||||
if self._closed:
|
||||
break
|
||||
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
if self._transport.exit_error is not None:
|
||||
await self._finish_with_error(self._transport.exit_error)
|
||||
return
|
||||
|
||||
await self._finish()
|
||||
except Exception as exc: # pragma: no cover - critical propagation path
|
||||
await self._finish_with_error(exc)
|
||||
|
||||
async def _route_message(self, message: Any) -> None:
|
||||
self._maybe_update_session_id(message)
|
||||
|
||||
if is_control_request(message):
|
||||
self._start_incoming_control_request(message)
|
||||
return
|
||||
|
||||
if is_control_response(message):
|
||||
self._handle_control_response(message)
|
||||
return
|
||||
|
||||
if is_control_cancel(message):
|
||||
self._handle_control_cancel_request(message)
|
||||
return
|
||||
|
||||
if is_sdk_result_message(message):
|
||||
self._first_result_event.set()
|
||||
if self._single_turn:
|
||||
self._transport.end_input()
|
||||
await self._message_queue.put(message)
|
||||
return
|
||||
|
||||
if (
|
||||
is_sdk_system_message(message)
|
||||
or is_sdk_assistant_message(message)
|
||||
or is_sdk_user_message(message)
|
||||
or is_sdk_partial_assistant_message(message)
|
||||
):
|
||||
await self._message_queue.put(message)
|
||||
return
|
||||
|
||||
def _maybe_update_session_id(self, message: Any) -> None:
|
||||
if self._session_id_locked or not isinstance(message, Mapping):
|
||||
return
|
||||
|
||||
session_id = message.get("session_id")
|
||||
if isinstance(session_id, str) and session_id:
|
||||
self._session_id = session_id
|
||||
self._session_id_locked = True
|
||||
|
||||
def _start_incoming_control_request(self, request: CLIControlRequest) -> None:
|
||||
request_id = request["request_id"]
|
||||
cancel_event = asyncio.Event()
|
||||
|
||||
async def runner() -> None:
|
||||
try:
|
||||
await self._handle_control_request(request, cancel_event)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc: # pragma: no cover - fatal background path
|
||||
await self._finish_with_error(exc)
|
||||
finally:
|
||||
self._incoming_control_requests.pop(request_id, None)
|
||||
|
||||
task = asyncio.create_task(runner())
|
||||
self._incoming_control_requests[request_id] = _IncomingControlRequest(
|
||||
task=task,
|
||||
cancel_event=cancel_event,
|
||||
)
|
||||
|
||||
async def _handle_control_request(
|
||||
self,
|
||||
request: CLIControlRequest,
|
||||
cancel_event: asyncio.Event,
|
||||
) -> None:
|
||||
request_id = request["request_id"]
|
||||
payload = request["request"]
|
||||
subtype = payload.get("subtype")
|
||||
|
||||
try:
|
||||
if subtype == "can_use_tool":
|
||||
response = await self._handle_permission_request(
|
||||
cast(MutableMapping[str, Any], payload),
|
||||
cancel_event,
|
||||
)
|
||||
elif subtype == "mcp_message":
|
||||
raise RuntimeError("mcp_message is unsupported in python sdk v1")
|
||||
else:
|
||||
raise RuntimeError(f"Unknown control request subtype: {subtype}")
|
||||
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
|
||||
await self._send_control_response(
|
||||
request_id, success=True, response=response
|
||||
)
|
||||
except Exception as exc:
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
await self._send_control_response(
|
||||
request_id,
|
||||
success=False,
|
||||
response=str(exc),
|
||||
)
|
||||
|
||||
async def _handle_permission_request(
|
||||
self,
|
||||
payload: MutableMapping[str, Any],
|
||||
cancel_event: asyncio.Event,
|
||||
) -> dict[str, Any]:
|
||||
tool_name = str(payload.get("tool_name", ""))
|
||||
tool_input = payload.get("input")
|
||||
if not isinstance(tool_input, dict):
|
||||
tool_input = {}
|
||||
|
||||
if self._options.can_use_tool is None:
|
||||
return {"behavior": "deny", "message": "Denied"}
|
||||
|
||||
context: CanUseToolContext = {
|
||||
"cancel_event": cancel_event,
|
||||
"suggestions": payload.get("permission_suggestions"),
|
||||
"blocked_path": payload.get("blocked_path"),
|
||||
}
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._options.can_use_tool(tool_name, tool_input, context),
|
||||
timeout=self._options.timeout.can_use_tool,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"behavior": "deny",
|
||||
"message": "Permission request timed out",
|
||||
}
|
||||
except asyncio.CancelledError:
|
||||
if cancel_event.is_set():
|
||||
raise
|
||||
return {
|
||||
"behavior": "deny",
|
||||
"message": "Permission check failed: callback cancelled",
|
||||
}
|
||||
except Exception as exc:
|
||||
return {
|
||||
"behavior": "deny",
|
||||
"message": f"Permission check failed: {exc}",
|
||||
}
|
||||
|
||||
behavior = result.get("behavior")
|
||||
if behavior == "allow":
|
||||
return {
|
||||
"behavior": "allow",
|
||||
"updatedInput": result.get("updatedInput", tool_input),
|
||||
}
|
||||
|
||||
deny_result = cast(PermissionDenyResult, result)
|
||||
return {
|
||||
"behavior": "deny",
|
||||
"message": deny_result.get("message", "Denied"),
|
||||
**(
|
||||
{"interrupt": deny_result["interrupt"]}
|
||||
if "interrupt" in deny_result
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
def _handle_control_response(self, response: CLIControlResponse) -> None:
|
||||
payload = response["response"]
|
||||
request_id = payload["request_id"]
|
||||
|
||||
pending = self._pending_control_requests.pop(request_id, None)
|
||||
if pending is None:
|
||||
return
|
||||
|
||||
pending.timeout_handle.cancel()
|
||||
|
||||
if payload["subtype"] == "success":
|
||||
if not pending.future.done():
|
||||
pending.future.set_result(payload.get("response"))
|
||||
else:
|
||||
error = payload.get("error", "Unknown control error")
|
||||
if isinstance(error, dict):
|
||||
error_message = str(error.get("message", "Unknown control error"))
|
||||
else:
|
||||
error_message = str(error)
|
||||
if not pending.future.done():
|
||||
pending.future.set_exception(RuntimeError(error_message))
|
||||
|
||||
def _handle_control_cancel_request(self, message: Mapping[str, Any]) -> None:
|
||||
request_id = message.get("request_id")
|
||||
if not isinstance(request_id, str):
|
||||
return
|
||||
|
||||
pending = self._pending_control_requests.pop(request_id, None)
|
||||
if pending is not None:
|
||||
pending.timeout_handle.cancel()
|
||||
pending.cancel_event.set()
|
||||
if not pending.future.done():
|
||||
pending.future.set_exception(AbortError("Control request cancelled"))
|
||||
|
||||
incoming = self._incoming_control_requests.get(request_id)
|
||||
if incoming is None:
|
||||
return
|
||||
|
||||
incoming.cancel_event.set()
|
||||
incoming.task.cancel()
|
||||
|
||||
async def _send_control_request(
|
||||
self,
|
||||
subtype: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
if self._closed:
|
||||
raise RuntimeError("Query is closed")
|
||||
|
||||
if subtype != "initialize":
|
||||
await self._wait_initialized()
|
||||
|
||||
request_id = str(uuid4())
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[dict[str, Any] | None] = loop.create_future()
|
||||
cancel_event = asyncio.Event()
|
||||
|
||||
def on_timeout() -> None:
|
||||
pending = self._pending_control_requests.pop(request_id, None)
|
||||
if pending is None:
|
||||
return
|
||||
pending.cancel_event.set()
|
||||
if not pending.future.done():
|
||||
pending.future.set_exception(
|
||||
ControlRequestTimeoutError(f"Control request timeout: {subtype}")
|
||||
)
|
||||
|
||||
timeout_handle = loop.call_later(
|
||||
self._options.timeout.control_request,
|
||||
on_timeout,
|
||||
)
|
||||
|
||||
self._pending_control_requests[request_id] = _PendingControlRequest(
|
||||
future=future,
|
||||
cancel_event=cancel_event,
|
||||
timeout_handle=timeout_handle,
|
||||
)
|
||||
|
||||
request_payload: dict[str, Any] = {"subtype": subtype}
|
||||
if data:
|
||||
request_payload.update(data)
|
||||
|
||||
payload: CLIControlRequest = {
|
||||
"type": "control_request",
|
||||
"request_id": request_id,
|
||||
"request": request_payload,
|
||||
}
|
||||
|
||||
await self._write_payload(payload)
|
||||
return await future
|
||||
|
||||
async def _send_control_response(
|
||||
self,
|
||||
request_id: str,
|
||||
*,
|
||||
success: bool,
|
||||
response: Any,
|
||||
) -> None:
|
||||
payload: CLIControlResponse
|
||||
if success:
|
||||
payload = {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": response,
|
||||
},
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "error",
|
||||
"request_id": request_id,
|
||||
"error": str(response),
|
||||
},
|
||||
}
|
||||
|
||||
await self._write_payload(payload)
|
||||
|
||||
async def _write_payload(self, payload: Any) -> None:
|
||||
self._transport.write(serialize_json_line(payload))
|
||||
await self._transport.drain()
|
||||
|
||||
async def stream_input(self, messages: AsyncIterable[SDKUserMessage]) -> None:
|
||||
try:
|
||||
if self._closed:
|
||||
raise RuntimeError("Query is closed")
|
||||
|
||||
await self._wait_initialized()
|
||||
|
||||
async for message in messages:
|
||||
if self._cancel_event.is_set() or self._closed:
|
||||
break
|
||||
await self._write_payload(message)
|
||||
|
||||
if not self._single_turn:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._first_result_event.wait(),
|
||||
timeout=self._options.timeout.stream_close,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
self._transport.end_input()
|
||||
except Exception as exc:
|
||||
await self._finish_with_error(exc)
|
||||
raise
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
await self._ensure_started()
|
||||
await self._send_control_request("interrupt")
|
||||
|
||||
async def set_permission_mode(self, mode: str) -> None:
|
||||
await self._ensure_started()
|
||||
await self._send_control_request("set_permission_mode", {"mode": mode})
|
||||
|
||||
async def set_model(self, model: str) -> None:
|
||||
await self._ensure_started()
|
||||
await self._send_control_request("set_model", {"model": model})
|
||||
|
||||
async def supported_commands(self) -> dict[str, Any] | None:
|
||||
await self._ensure_started()
|
||||
return await self._send_control_request("supported_commands")
|
||||
|
||||
async def mcp_server_status(self) -> dict[str, Any] | None:
|
||||
await self._ensure_started()
|
||||
return await self._send_control_request("mcp_server_status")
|
||||
|
||||
@property
|
||||
def control_request_timeout(self) -> float:
|
||||
return self._options.timeout.control_request
|
||||
|
||||
def get_session_id(self) -> str:
|
||||
return self._session_id
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def _fail_pending_control_requests(self, error: Exception) -> None:
|
||||
for request_id, pending in list(self._pending_control_requests.items()):
|
||||
pending.timeout_handle.cancel()
|
||||
pending.cancel_event.set()
|
||||
if not pending.future.done():
|
||||
pending.future.set_exception(error)
|
||||
self._pending_control_requests.pop(request_id, None)
|
||||
|
||||
async def _cancel_incoming_control_requests(self) -> None:
|
||||
current_task = asyncio.current_task()
|
||||
tasks: list[asyncio.Task[None]] = []
|
||||
|
||||
for incoming in list(self._incoming_control_requests.values()):
|
||||
incoming.cancel_event.set()
|
||||
if incoming.task is current_task:
|
||||
continue
|
||||
incoming.task.cancel()
|
||||
tasks.append(incoming.task)
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
self._cancel_event.set()
|
||||
|
||||
error = RuntimeError("Query is closed")
|
||||
self._fail_pending_control_requests(error)
|
||||
await self._cancel_incoming_control_requests()
|
||||
|
||||
await self._transport.close()
|
||||
|
||||
if self._input_task is not None:
|
||||
self._input_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError, Exception):
|
||||
await self._input_task
|
||||
|
||||
if self._router_task is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await self._router_task
|
||||
|
||||
await self._finish()
|
||||
|
||||
async def _finish(self) -> None:
|
||||
if self._terminal_event_sent:
|
||||
return
|
||||
self._terminal_event_sent = True
|
||||
await self._message_queue.put(_DONE)
|
||||
|
||||
async def _finish_with_error(self, exc: Exception) -> None:
|
||||
if self._terminal_event_sent:
|
||||
return
|
||||
self._closed = True
|
||||
self._terminal_event_sent = True
|
||||
self._cancel_event.set()
|
||||
self._fail_pending_control_requests(exc)
|
||||
await self._cancel_incoming_control_requests()
|
||||
await self._transport.close()
|
||||
await self._message_queue.put(exc)
|
||||
await self._message_queue.put(_DONE)
|
||||
|
||||
def __aiter__(self) -> Query:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> SDKMessage:
|
||||
if self._exhausted:
|
||||
raise StopAsyncIteration
|
||||
await self._ensure_started()
|
||||
item = await self._message_queue.get()
|
||||
|
||||
if item is _DONE:
|
||||
self._exhausted = True
|
||||
raise StopAsyncIteration
|
||||
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
|
||||
return cast(SDKMessage, item)
|
||||
|
||||
async def __aenter__(self) -> Query:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
|
||||
def query(
|
||||
prompt: str | AsyncIterable[SDKUserMessage],
|
||||
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None,
|
||||
) -> Query:
|
||||
if isinstance(options, QueryOptions):
|
||||
parsed_options = replace(options)
|
||||
else:
|
||||
parsed_options = QueryOptions.from_mapping(options)
|
||||
|
||||
validate_query_options(parsed_options)
|
||||
|
||||
session_id = parsed_options.resume or parsed_options.session_id
|
||||
if session_id is None and not parsed_options.continue_session:
|
||||
session_id = str(uuid4())
|
||||
if parsed_options.resume is None and not parsed_options.continue_session:
|
||||
parsed_options = replace(parsed_options, session_id=session_id)
|
||||
|
||||
transport = ProcessTransport(parsed_options)
|
||||
return Query(transport, parsed_options, prompt, session_id or "")
|
||||
217
packages/sdk-python/src/qwen_code_sdk/sync_query.py
Normal file
217
packages/sdk-python/src/qwen_code_sdk/sync_query.py
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
"""Synchronous wrapper around the async Query API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Mapping
|
||||
from queue import Queue
|
||||
from typing import Any, cast
|
||||
|
||||
from .protocol import SDKMessage, SDKUserMessage
|
||||
from .query import Query, query
|
||||
from .types import QueryOptions, QueryOptionsDict
|
||||
|
||||
_STOP = object()
|
||||
_SYNC_TIMEOUT_MARGIN = 5.0
|
||||
|
||||
|
||||
class SyncQuery:
|
||||
def __init__(
|
||||
self,
|
||||
prompt: str | Iterable[SDKUserMessage] | AsyncIterable[SDKUserMessage],
|
||||
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._queue: Queue[SDKMessage | Exception | object] = Queue()
|
||||
self._ready = threading.Event()
|
||||
self._shutdown = threading.Event()
|
||||
self._stop_sent = threading.Event()
|
||||
self._exhausted = False
|
||||
self._query: Query | None = None
|
||||
self._consumer_task: asyncio.Task[None] | None = None
|
||||
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = threading.Thread(
|
||||
target=self._run_loop,
|
||||
name="qwen-sdk-sync-loop",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
if isinstance(prompt, str) or isinstance(prompt, AsyncIterable):
|
||||
source_prompt: str | AsyncIterable[SDKUserMessage] = prompt
|
||||
else:
|
||||
source_prompt = _iterable_to_async(prompt)
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._bootstrap(source_prompt, options),
|
||||
self._loop,
|
||||
)
|
||||
try:
|
||||
future.result()
|
||||
except Exception:
|
||||
self._stop_loop()
|
||||
raise
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
async def _bootstrap(
|
||||
self,
|
||||
prompt: str | AsyncIterable[SDKUserMessage],
|
||||
options: QueryOptions | QueryOptionsDict | Mapping[str, Any] | None,
|
||||
) -> None:
|
||||
self._query = query(prompt=prompt, options=options)
|
||||
self._ready.set()
|
||||
self._consumer_task = asyncio.create_task(self._consume())
|
||||
|
||||
async def _consume(self) -> None:
|
||||
assert self._query is not None
|
||||
try:
|
||||
async for message in self._query:
|
||||
self._queue.put(message)
|
||||
except Exception as exc:
|
||||
self._queue.put(exc)
|
||||
finally:
|
||||
if not self._stop_sent.is_set():
|
||||
self._stop_sent.set()
|
||||
self._queue.put(_STOP)
|
||||
|
||||
def _require_query(self) -> Query:
|
||||
self._ready.wait(timeout=30)
|
||||
if self._query is None:
|
||||
raise RuntimeError("SyncQuery failed to initialize")
|
||||
return self._query
|
||||
|
||||
def __iter__(self) -> SyncQuery:
|
||||
return self
|
||||
|
||||
def __next__(self) -> SDKMessage:
|
||||
if self._exhausted:
|
||||
raise StopIteration
|
||||
item = self._queue.get()
|
||||
|
||||
if item is _STOP:
|
||||
self._exhausted = True
|
||||
raise StopIteration
|
||||
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
|
||||
return cast(SDKMessage, item)
|
||||
|
||||
def __enter__(self) -> SyncQuery:
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args: object) -> None:
|
||||
self.close()
|
||||
|
||||
def interrupt(self) -> None:
|
||||
q = self._require_query()
|
||||
asyncio.run_coroutine_threadsafe(q.interrupt(), self._loop).result(
|
||||
timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN
|
||||
)
|
||||
|
||||
def set_model(self, model: str) -> None:
|
||||
q = self._require_query()
|
||||
asyncio.run_coroutine_threadsafe(q.set_model(model), self._loop).result(
|
||||
timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN
|
||||
)
|
||||
|
||||
def set_permission_mode(self, mode: str) -> None:
|
||||
q = self._require_query()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
q.set_permission_mode(mode),
|
||||
self._loop,
|
||||
).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN)
|
||||
|
||||
def supported_commands(self) -> Any:
|
||||
q = self._require_query()
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
q.supported_commands(),
|
||||
self._loop,
|
||||
).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN)
|
||||
|
||||
def mcp_server_status(self) -> Any:
|
||||
q = self._require_query()
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
q.mcp_server_status(),
|
||||
self._loop,
|
||||
).result(timeout=q.control_request_timeout + _SYNC_TIMEOUT_MARGIN)
|
||||
|
||||
def get_session_id(self) -> str:
|
||||
q = self._require_query()
|
||||
return q.get_session_id()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
q = self._require_query()
|
||||
return q.is_closed()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._shutdown.is_set():
|
||||
return
|
||||
|
||||
self._shutdown.set()
|
||||
|
||||
q = self._query
|
||||
if q is not None:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(q.close(), self._loop).result(
|
||||
timeout=30
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Wait for _consume() to put _STOP before stopping the loop,
|
||||
# otherwise consumers blocked on queue.get() will deadlock.
|
||||
if self._consumer_task is not None:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._await_consumer(), self._loop
|
||||
).result(timeout=5)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not self._stop_sent.is_set():
|
||||
self._stop_sent.set()
|
||||
self._queue.put(_STOP)
|
||||
self._stop_loop()
|
||||
|
||||
async def _await_consumer(self) -> None:
|
||||
if self._consumer_task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self._consumer_task, timeout=5.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _stop_loop(self) -> None:
|
||||
if self._loop.is_running():
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
self._thread.join(timeout=5)
|
||||
if not self._loop.is_closed():
|
||||
self._loop.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
if not self._shutdown.is_set():
|
||||
warnings.warn(
|
||||
"SyncQuery was not closed. "
|
||||
"Use 'with SyncQuery(...) as q:' or call q.close() explicitly.",
|
||||
ResourceWarning,
|
||||
stacklevel=1,
|
||||
)
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
async def _iterable_to_async(
|
||||
messages: Iterable[SDKUserMessage],
|
||||
) -> AsyncIterator[SDKUserMessage]:
|
||||
for message in messages:
|
||||
yield message
|
||||
243
packages/sdk-python/src/qwen_code_sdk/transport.py
Normal file
243
packages/sdk-python/src/qwen_code_sdk/transport.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Process transport for qwen CLI stream-json protocol."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .errors import ProcessExitError
|
||||
from .json_lines import parse_json_line
|
||||
from .types import QueryOptions
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpawnInfo:
|
||||
command: str
|
||||
args: list[str]
|
||||
|
||||
|
||||
def prepare_spawn_info(path_to_qwen_executable: str | None) -> SpawnInfo:
|
||||
if path_to_qwen_executable is None:
|
||||
return SpawnInfo(command="qwen", args=[])
|
||||
|
||||
spec = path_to_qwen_executable
|
||||
if os.path.sep not in spec and (
|
||||
os.path.altsep is None or os.path.altsep not in spec
|
||||
):
|
||||
return SpawnInfo(command=spec, args=[])
|
||||
|
||||
path = Path(spec).expanduser().resolve()
|
||||
|
||||
suffix = path.suffix.lower()
|
||||
if suffix == ".py":
|
||||
return SpawnInfo(command=sys.executable, args=[str(path)])
|
||||
if suffix in {".js", ".mjs", ".cjs"}:
|
||||
return SpawnInfo(command="node", args=[str(path)])
|
||||
|
||||
return SpawnInfo(command=str(path), args=[])
|
||||
|
||||
|
||||
class ProcessTransport:
|
||||
def __init__(self, options: QueryOptions):
|
||||
self._options = options
|
||||
self._process: asyncio.subprocess.Process | None = None
|
||||
self._stderr_task: asyncio.Task[None] | None = None
|
||||
self._closed = False
|
||||
self._input_closed = False
|
||||
self._exit_error: Exception | None = None
|
||||
|
||||
@property
|
||||
def exit_error(self) -> Exception | None:
|
||||
return self._exit_error
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._closed:
|
||||
raise RuntimeError("Transport is closed")
|
||||
if self._process is not None:
|
||||
return
|
||||
|
||||
spawn_info = prepare_spawn_info(self._options.path_to_qwen_executable)
|
||||
args = [*spawn_info.args, *build_cli_arguments(self._options)]
|
||||
stderr_target = (
|
||||
asyncio.subprocess.PIPE
|
||||
if self._options.debug or self._options.stderr is not None
|
||||
else subprocess.DEVNULL
|
||||
)
|
||||
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
spawn_info.command,
|
||||
*args,
|
||||
cwd=self._options.cwd,
|
||||
env={**os.environ, **(self._options.env or {})},
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=stderr_target,
|
||||
)
|
||||
|
||||
if self._options.debug or self._options.stderr is not None:
|
||||
self._stderr_task = asyncio.create_task(self._forward_stderr())
|
||||
|
||||
async def _forward_stderr(self) -> None:
|
||||
if self._process is None or self._process.stderr is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
chunk = await self._process.stderr.readline()
|
||||
if not chunk:
|
||||
return
|
||||
text = chunk.decode("utf-8", errors="replace").rstrip("\n")
|
||||
try:
|
||||
if self._options.stderr is not None:
|
||||
self._options.stderr(text)
|
||||
elif self._options.debug:
|
||||
print(text, file=sys.stderr)
|
||||
except Exception:
|
||||
print(text, file=sys.stderr)
|
||||
|
||||
def write(self, data: str) -> None:
|
||||
if self._closed:
|
||||
raise RuntimeError("Transport is closed")
|
||||
if self._process is None or self._process.stdin is None:
|
||||
raise RuntimeError("Transport is not started")
|
||||
if self._input_closed:
|
||||
raise RuntimeError("Transport input is already closed")
|
||||
|
||||
self._process.stdin.write(data.encode("utf-8"))
|
||||
|
||||
async def drain(self) -> None:
|
||||
if self._process is None or self._process.stdin is None:
|
||||
return
|
||||
await self._process.stdin.drain()
|
||||
|
||||
def end_input(self) -> None:
|
||||
if self._closed or self._input_closed:
|
||||
return
|
||||
if self._process is None or self._process.stdin is None:
|
||||
return
|
||||
self._process.stdin.close()
|
||||
self._input_closed = True
|
||||
|
||||
async def read_messages(self) -> AsyncIterator[Any]:
|
||||
if self._process is None or self._process.stdout is None:
|
||||
raise RuntimeError("Transport is not started")
|
||||
|
||||
while True:
|
||||
line = await self._process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
|
||||
raw = line.decode("utf-8", errors="replace").strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
yield parse_json_line(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
await self._finalize_exit()
|
||||
|
||||
async def wait_for_exit(self) -> None:
|
||||
if self._process is None:
|
||||
return
|
||||
await self._finalize_exit()
|
||||
|
||||
async def _finalize_exit(self) -> None:
|
||||
if self._process is None:
|
||||
return
|
||||
|
||||
return_code = self._process.returncode
|
||||
if return_code is None:
|
||||
return_code = await self._process.wait()
|
||||
|
||||
if return_code != 0 and self._exit_error is None:
|
||||
self._exit_error = ProcessExitError(
|
||||
f"CLI process exited with code {return_code}",
|
||||
exit_code=return_code,
|
||||
)
|
||||
|
||||
if self._stderr_task is not None:
|
||||
await self._stderr_task
|
||||
self._stderr_task = None
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._process is None:
|
||||
return
|
||||
|
||||
if self._process.stdin is not None and not self._input_closed:
|
||||
self._process.stdin.close()
|
||||
self._input_closed = True
|
||||
|
||||
if self._process.returncode is None:
|
||||
self._process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._process.kill()
|
||||
await self._process.wait()
|
||||
|
||||
await self._finalize_exit()
|
||||
|
||||
|
||||
def build_cli_arguments(options: QueryOptions) -> list[str]:
|
||||
args: list[str] = [
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--channel=SDK",
|
||||
]
|
||||
|
||||
if options.model:
|
||||
args.extend(["--model", options.model])
|
||||
|
||||
if options.system_prompt:
|
||||
args.extend(["--system-prompt", options.system_prompt])
|
||||
|
||||
if options.append_system_prompt:
|
||||
args.extend(["--append-system-prompt", options.append_system_prompt])
|
||||
|
||||
if options.permission_mode:
|
||||
args.extend(["--approval-mode", options.permission_mode])
|
||||
|
||||
if options.max_session_turns is not None:
|
||||
args.extend(["--max-session-turns", str(options.max_session_turns)])
|
||||
|
||||
if options.core_tools:
|
||||
args.extend(["--core-tools", ",".join(options.core_tools)])
|
||||
|
||||
if options.exclude_tools:
|
||||
args.extend(["--exclude-tools", ",".join(options.exclude_tools)])
|
||||
|
||||
if options.allowed_tools:
|
||||
args.extend(["--allowed-tools", ",".join(options.allowed_tools)])
|
||||
|
||||
if options.auth_type:
|
||||
args.extend(["--auth-type", options.auth_type])
|
||||
|
||||
if options.include_partial_messages:
|
||||
args.append("--include-partial-messages")
|
||||
|
||||
if options.resume:
|
||||
args.extend(["--resume", options.resume])
|
||||
elif options.continue_session:
|
||||
args.append("--continue")
|
||||
elif options.session_id:
|
||||
args.extend(["--session-id", options.session_id])
|
||||
|
||||
return args
|
||||
323
packages/sdk-python/src/qwen_code_sdk/types.py
Normal file
323
packages/sdk-python/src/qwen_code_sdk/types.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
"""Public type definitions for qwen_code_sdk."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from inspect import Parameter, Signature, iscoroutinefunction, signature
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
PermissionMode: TypeAlias = Literal["default", "plan", "auto-edit", "yolo"]
|
||||
AuthType: TypeAlias = Literal[
|
||||
"openai",
|
||||
"anthropic",
|
||||
"qwen-oauth",
|
||||
"gemini",
|
||||
"vertex-ai",
|
||||
]
|
||||
|
||||
|
||||
class PermissionSuggestion(TypedDict):
|
||||
type: Literal["allow", "deny", "modify"]
|
||||
label: str
|
||||
description: NotRequired[str]
|
||||
modifiedInput: NotRequired[Any]
|
||||
|
||||
|
||||
class PermissionAllowResult(TypedDict):
|
||||
behavior: Literal["allow"]
|
||||
updatedInput: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class PermissionDenyResult(TypedDict):
|
||||
behavior: Literal["deny"]
|
||||
message: NotRequired[str]
|
||||
interrupt: NotRequired[bool]
|
||||
|
||||
|
||||
PermissionResult: TypeAlias = PermissionAllowResult | PermissionDenyResult
|
||||
|
||||
|
||||
class CanUseToolContext(TypedDict):
|
||||
cancel_event: Any
|
||||
suggestions: list[PermissionSuggestion] | None
|
||||
blocked_path: str | None
|
||||
|
||||
|
||||
CanUseTool: TypeAlias = Callable[
|
||||
[str, dict[str, Any], CanUseToolContext],
|
||||
Awaitable[PermissionResult],
|
||||
]
|
||||
|
||||
|
||||
class TimeoutOptionsDict(TypedDict, total=False):
|
||||
"""Timeout configuration. All values are in seconds."""
|
||||
|
||||
can_use_tool: float
|
||||
control_request: float
|
||||
stream_close: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TimeoutOptions:
|
||||
can_use_tool: float = 60.0
|
||||
control_request: float = 60.0
|
||||
stream_close: float = 60.0
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, value: Mapping[str, Any] | None) -> TimeoutOptions:
|
||||
if value is None:
|
||||
return cls()
|
||||
|
||||
def _read(name: str, default: float) -> float:
|
||||
raw = value.get(name, default)
|
||||
if isinstance(raw, bool) or not isinstance(raw, (int, float)):
|
||||
raise TypeError(f"timeout.{name} must be a positive number")
|
||||
if raw <= 0:
|
||||
raise ValueError(f"timeout.{name} must be a positive number")
|
||||
return float(raw)
|
||||
|
||||
return cls(
|
||||
can_use_tool=_read("can_use_tool", 60.0),
|
||||
control_request=_read("control_request", 60.0),
|
||||
stream_close=_read("stream_close", 60.0),
|
||||
)
|
||||
|
||||
|
||||
class QueryOptionsDict(TypedDict, total=False):
|
||||
cwd: str
|
||||
model: str
|
||||
path_to_qwen_executable: str
|
||||
permission_mode: PermissionMode
|
||||
can_use_tool: CanUseTool
|
||||
env: dict[str, str]
|
||||
system_prompt: str
|
||||
append_system_prompt: str
|
||||
debug: bool
|
||||
max_session_turns: int
|
||||
core_tools: list[str]
|
||||
exclude_tools: list[str]
|
||||
allowed_tools: list[str]
|
||||
auth_type: AuthType
|
||||
include_partial_messages: bool
|
||||
resume: str
|
||||
continue_session: bool
|
||||
session_id: str
|
||||
timeout: TimeoutOptionsDict
|
||||
mcp_servers: dict[str, dict[str, Any]]
|
||||
stderr: Callable[[str], None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryOptions:
|
||||
cwd: str | None = None
|
||||
model: str | None = None
|
||||
path_to_qwen_executable: str | None = None
|
||||
permission_mode: PermissionMode | None = None
|
||||
can_use_tool: CanUseTool | None = None
|
||||
env: dict[str, str] | None = None
|
||||
system_prompt: str | None = None
|
||||
append_system_prompt: str | None = None
|
||||
debug: bool = False
|
||||
max_session_turns: int | None = None
|
||||
core_tools: list[str] | None = None
|
||||
exclude_tools: list[str] | None = None
|
||||
allowed_tools: list[str] | None = None
|
||||
auth_type: AuthType | None = None
|
||||
include_partial_messages: bool = False
|
||||
resume: str | None = None
|
||||
continue_session: bool = False
|
||||
session_id: str | None = None
|
||||
timeout: TimeoutOptions = TimeoutOptions()
|
||||
mcp_servers: dict[str, dict[str, Any]] | None = None
|
||||
stderr: Callable[[str], None] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls, value: Mapping[str, Any] | None) -> QueryOptions:
|
||||
if value is None:
|
||||
return cls()
|
||||
|
||||
data: MutableMapping[str, Any] = dict(value)
|
||||
timeout = TimeoutOptions.from_mapping(data.get("timeout"))
|
||||
|
||||
return cls(
|
||||
cwd=_as_optional_str(data, "cwd"),
|
||||
model=_as_optional_str(data, "model"),
|
||||
path_to_qwen_executable=_as_optional_str(data, "path_to_qwen_executable"),
|
||||
permission_mode=cast(
|
||||
PermissionMode | None,
|
||||
_as_optional_str(data, "permission_mode"),
|
||||
),
|
||||
can_use_tool=cast(
|
||||
CanUseTool | None,
|
||||
_as_optional_callable(data, "can_use_tool"),
|
||||
),
|
||||
env=_as_optional_str_dict(data, "env"),
|
||||
system_prompt=_as_optional_str(data, "system_prompt"),
|
||||
append_system_prompt=_as_optional_str(data, "append_system_prompt"),
|
||||
debug=_as_optional_bool(data, "debug") or False,
|
||||
max_session_turns=_as_optional_int(data, "max_session_turns"),
|
||||
core_tools=_as_optional_str_list(data, "core_tools"),
|
||||
exclude_tools=_as_optional_str_list(data, "exclude_tools"),
|
||||
allowed_tools=_as_optional_str_list(data, "allowed_tools"),
|
||||
auth_type=cast(
|
||||
AuthType | None,
|
||||
_as_optional_str(data, "auth_type"),
|
||||
),
|
||||
include_partial_messages=_as_optional_bool(data, "include_partial_messages")
|
||||
or False,
|
||||
resume=_as_optional_str(data, "resume"),
|
||||
continue_session=_as_optional_bool(data, "continue_session") or False,
|
||||
session_id=_as_optional_str(data, "session_id"),
|
||||
timeout=timeout,
|
||||
mcp_servers=_as_optional_nested_dict(data, "mcp_servers"),
|
||||
stderr=cast(
|
||||
Callable[[str], None] | None,
|
||||
_as_optional_callable(data, "stderr"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _as_optional_str(data: Mapping[str, Any], key: str) -> str | None:
|
||||
raw = data.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, str):
|
||||
raise TypeError(f"{key} must be a string")
|
||||
return raw
|
||||
|
||||
|
||||
def _as_optional_int(data: Mapping[str, Any], key: str) -> int | None:
|
||||
raw = data.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, bool) or not isinstance(raw, int):
|
||||
raise TypeError(f"{key} must be an integer")
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _as_optional_bool(data: Mapping[str, Any], key: str) -> bool | None:
|
||||
raw = data.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, bool):
|
||||
raise TypeError(f"{key} must be a boolean")
|
||||
return raw
|
||||
|
||||
|
||||
def _as_optional_callable(
|
||||
data: Mapping[str, Any], key: str
|
||||
) -> Callable[..., Any] | None:
|
||||
raw = data.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
if not callable(raw):
|
||||
raise TypeError(f"{key} must be callable")
|
||||
if key == "can_use_tool":
|
||||
_validate_can_use_tool_callable(raw, error_type=TypeError)
|
||||
elif key == "stderr":
|
||||
_validate_stderr_callable(raw, error_type=TypeError)
|
||||
return cast(Callable[..., Any], raw)
|
||||
|
||||
|
||||
def _validate_can_use_tool_callable(value: object, error_type: type[Exception]) -> None:
|
||||
if not callable(value):
|
||||
raise error_type("can_use_tool must be callable")
|
||||
|
||||
if not iscoroutinefunction(value):
|
||||
raise error_type("can_use_tool must be an async callable")
|
||||
|
||||
try:
|
||||
sig = signature(value)
|
||||
except (TypeError, ValueError):
|
||||
return
|
||||
|
||||
if not _supports_argument_count(sig, 3):
|
||||
raise error_type("can_use_tool must accept exactly 3 positional arguments")
|
||||
|
||||
|
||||
def _validate_stderr_callable(value: object, error_type: type[Exception]) -> None:
|
||||
if not callable(value):
|
||||
raise error_type("stderr must be callable")
|
||||
|
||||
try:
|
||||
sig = signature(value)
|
||||
except (TypeError, ValueError):
|
||||
return
|
||||
|
||||
if not _supports_argument_count(sig, 1):
|
||||
raise error_type("stderr must accept exactly 1 positional argument")
|
||||
|
||||
|
||||
def _supports_argument_count(sig: Signature, count: int) -> bool:
|
||||
params = list(sig.parameters.values())
|
||||
positional_params = [
|
||||
param
|
||||
for param in params
|
||||
if param.kind
|
||||
in (
|
||||
Parameter.POSITIONAL_ONLY,
|
||||
Parameter.POSITIONAL_OR_KEYWORD,
|
||||
)
|
||||
]
|
||||
required_positional = [
|
||||
param for param in positional_params if param.default is Parameter.empty
|
||||
]
|
||||
has_var_positional = any(param.kind is Parameter.VAR_POSITIONAL for param in params)
|
||||
|
||||
if len(required_positional) > count:
|
||||
return False
|
||||
if has_var_positional:
|
||||
return True
|
||||
return len(positional_params) >= count
|
||||
|
||||
|
||||
def _as_optional_str_dict(data: Mapping[str, Any], key: str) -> dict[str, str] | None:
|
||||
raw = data.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, Mapping):
|
||||
raise TypeError(f"{key} must be a mapping of string to string")
|
||||
|
||||
parsed: dict[str, str] = {}
|
||||
for k, v in raw.items():
|
||||
if not isinstance(k, str) or not isinstance(v, str):
|
||||
raise TypeError(f"{key} must be a mapping of string to string")
|
||||
parsed[k] = v
|
||||
return parsed
|
||||
|
||||
|
||||
def _as_optional_str_list(data: Mapping[str, Any], key: str) -> list[str] | None:
|
||||
raw = data.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, list):
|
||||
raise TypeError(f"{key} must be a list of strings")
|
||||
if any(not isinstance(item, str) for item in raw):
|
||||
raise TypeError(f"{key} must be a list of strings")
|
||||
return list(raw)
|
||||
|
||||
|
||||
def _as_optional_nested_dict(
|
||||
data: Mapping[str, Any], key: str
|
||||
) -> dict[str, dict[str, Any]] | None:
|
||||
raw = data.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
if not isinstance(raw, Mapping):
|
||||
raise TypeError(f"{key} must be a mapping")
|
||||
|
||||
parsed: dict[str, dict[str, Any]] = {}
|
||||
for k, v in raw.items():
|
||||
if not isinstance(k, str) or not isinstance(v, Mapping):
|
||||
raise TypeError(f"{key} must be a mapping of string to mapping")
|
||||
parsed[k] = dict(v)
|
||||
return parsed
|
||||
94
packages/sdk-python/src/qwen_code_sdk/validation.py
Normal file
94
packages/sdk-python/src/qwen_code_sdk/validation.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
"""Validation helpers for query options."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from uuid import RFC_4122, UUID
|
||||
|
||||
from .errors import ValidationError
|
||||
from .types import (
|
||||
QueryOptions,
|
||||
_validate_can_use_tool_callable,
|
||||
_validate_stderr_callable,
|
||||
)
|
||||
|
||||
_VALID_PERMISSION_MODES = {"default", "plan", "auto-edit", "yolo"}
|
||||
_VALID_AUTH_TYPES = {"openai", "anthropic", "qwen-oauth", "gemini", "vertex-ai"}
|
||||
|
||||
|
||||
def validate_query_options(options: QueryOptions) -> None:
|
||||
if (
|
||||
options.permission_mode
|
||||
and options.permission_mode not in _VALID_PERMISSION_MODES
|
||||
):
|
||||
raise ValidationError(
|
||||
f"Invalid permission_mode: {options.permission_mode!r}. "
|
||||
"Expected one of: default, plan, auto-edit, yolo."
|
||||
)
|
||||
|
||||
if options.auth_type and options.auth_type not in _VALID_AUTH_TYPES:
|
||||
raise ValidationError(
|
||||
f"Invalid auth_type: {options.auth_type!r}. "
|
||||
"Expected one of: openai, anthropic, qwen-oauth, gemini, vertex-ai."
|
||||
)
|
||||
|
||||
_validate_optional_callable(options.can_use_tool, _validate_can_use_tool_callable)
|
||||
_validate_optional_callable(options.stderr, _validate_stderr_callable)
|
||||
|
||||
if options.resume and options.continue_session:
|
||||
raise ValidationError(
|
||||
"Cannot use resume together with continue_session. "
|
||||
"Use continue_session for latest session "
|
||||
"or resume for a specific session ID."
|
||||
)
|
||||
|
||||
if options.session_id and (options.resume or options.continue_session):
|
||||
raise ValidationError(
|
||||
"Cannot use session_id with resume or continue_session. "
|
||||
"session_id starts a new session, "
|
||||
"resume/continue_session restore existing sessions."
|
||||
)
|
||||
|
||||
if options.session_id:
|
||||
validate_session_id(options.session_id, "session_id")
|
||||
|
||||
if options.resume:
|
||||
validate_session_id(options.resume, "resume")
|
||||
|
||||
if options.max_session_turns is not None and options.max_session_turns < -1:
|
||||
raise ValidationError("max_session_turns must be -1 or a non-negative integer")
|
||||
|
||||
if (
|
||||
options.path_to_qwen_executable is not None
|
||||
and not options.path_to_qwen_executable.strip()
|
||||
):
|
||||
raise ValidationError("path_to_qwen_executable cannot be empty")
|
||||
|
||||
if options.mcp_servers:
|
||||
raise ValidationError(
|
||||
"mcp_servers is not supported in Python SDK v1. "
|
||||
"Remove the mcp_servers option or use the TypeScript SDK."
|
||||
)
|
||||
|
||||
|
||||
def _validate_optional_callable(
|
||||
value: object,
|
||||
validator: Callable[[object, type[ValidationError]], None],
|
||||
) -> None:
|
||||
if value is None:
|
||||
return
|
||||
validator(value, ValidationError)
|
||||
|
||||
|
||||
def validate_session_id(value: str, param_name: str) -> None:
|
||||
try:
|
||||
parsed = UUID(value)
|
||||
except ValueError as exc:
|
||||
raise ValidationError(
|
||||
f"Invalid {param_name}: {value!r}. Must be a valid UUID."
|
||||
) from exc
|
||||
|
||||
if parsed.variant != RFC_4122:
|
||||
raise ValidationError(
|
||||
f"Invalid {param_name}: {value!r}. UUID variant must be RFC 4122."
|
||||
)
|
||||
400
packages/sdk-python/tests/integration/conftest.py
Normal file
400
packages/sdk-python/tests/integration/conftest.py
Normal file
|
|
@ -0,0 +1,400 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import stat
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_qwen_path(tmp_path: Path) -> str:
|
||||
script_path = tmp_path / "fake_qwen.py"
|
||||
script_path.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
|
||||
def send(message):
|
||||
sys.stdout.write(json.dumps(message, separators=(",", ":")) + "\\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def parse_user_content(message):
|
||||
payload = message.get("message", {})
|
||||
content = payload.get("content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(str(block.get("text", "")))
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def build_system_message():
|
||||
return {
|
||||
"type": "system",
|
||||
"subtype": "init",
|
||||
"uuid": session_id,
|
||||
"session_id": session_id,
|
||||
"cwd": ".",
|
||||
"tools": ["Read", "Edit", "Bash"],
|
||||
"mcp_servers": [],
|
||||
"model": state["model"],
|
||||
"permission_mode": state["permission_mode"],
|
||||
"qwen_code_version": "fake-1.0.0",
|
||||
"capabilities": {
|
||||
"canSetModel": True,
|
||||
"canSetPermissionMode": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def build_assistant_message(text):
|
||||
return {
|
||||
"type": "assistant",
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"session_id": session_id,
|
||||
"message": {
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": state["model"],
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
},
|
||||
},
|
||||
"parent_tool_use_id": None,
|
||||
}
|
||||
|
||||
|
||||
def build_result_message(result_text):
|
||||
return {
|
||||
"type": "result",
|
||||
"subtype": "success",
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"session_id": session_id,
|
||||
"is_error": False,
|
||||
"duration_ms": 5,
|
||||
"duration_api_ms": 1,
|
||||
"num_turns": 1,
|
||||
"result": result_text,
|
||||
"usage": {
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
},
|
||||
"permission_denials": [],
|
||||
}
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--model")
|
||||
parser.add_argument("--approval-mode")
|
||||
parser.add_argument("--include-partial-messages", action="store_true")
|
||||
parser.add_argument("--session-id")
|
||||
parser.add_argument("--resume")
|
||||
parser.add_argument(
|
||||
"--continue",
|
||||
dest="continue_session",
|
||||
action="store_true",
|
||||
)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
session_id = (
|
||||
args.resume
|
||||
or args.session_id
|
||||
or (
|
||||
"aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"
|
||||
if args.continue_session
|
||||
else str(uuid.uuid4())
|
||||
)
|
||||
)
|
||||
state = {
|
||||
"model": args.model or "coder-model",
|
||||
"permission_mode": args.approval_mode or "default",
|
||||
"include_partial": bool(args.include_partial_messages),
|
||||
}
|
||||
|
||||
pending_permission = None
|
||||
pending_unknown_control = None
|
||||
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
message = json.loads(line)
|
||||
msg_type = message.get("type")
|
||||
|
||||
if msg_type == "control_request":
|
||||
request_id = message["request_id"]
|
||||
request = message["request"]
|
||||
subtype = request.get("subtype")
|
||||
|
||||
if subtype == "initialize":
|
||||
send(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
send(build_system_message())
|
||||
elif subtype == "set_model":
|
||||
state["model"] = request["model"]
|
||||
send(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
send(build_system_message())
|
||||
elif subtype == "set_permission_mode":
|
||||
state["permission_mode"] = request["mode"]
|
||||
send(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
send(build_system_message())
|
||||
elif subtype == "interrupt":
|
||||
send(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
elif subtype == "supported_commands":
|
||||
send(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": {
|
||||
"commands": [
|
||||
"initialize",
|
||||
"interrupt",
|
||||
"set_model",
|
||||
"set_permission_mode",
|
||||
]
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
elif subtype == "mcp_server_status":
|
||||
send(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": request_id,
|
||||
"response": {"servers": []},
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
send(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "error",
|
||||
"request_id": request_id,
|
||||
"error": f"unsupported request: {subtype}",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
elif msg_type == "user":
|
||||
prompt = parse_user_content(message)
|
||||
|
||||
if "exit nonzero" in prompt:
|
||||
sys.exit(9)
|
||||
|
||||
if "request unknown control" in prompt:
|
||||
request_id = str(uuid.uuid4())
|
||||
pending_unknown_control = {
|
||||
"request_id": request_id,
|
||||
"prompt": prompt,
|
||||
}
|
||||
send(
|
||||
{
|
||||
"type": "control_request",
|
||||
"request_id": request_id,
|
||||
"request": {
|
||||
"subtype": "something_new",
|
||||
"payload": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if "use tool" in prompt or "create file" in prompt:
|
||||
tool_use_id = str(uuid.uuid4())
|
||||
send(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"session_id": session_id,
|
||||
"message": {
|
||||
"id": str(uuid.uuid4()),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": state["model"],
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_use_id,
|
||||
"name": "write_file",
|
||||
"input": {
|
||||
"path": "demo.txt",
|
||||
"content": "hello",
|
||||
},
|
||||
}
|
||||
],
|
||||
"usage": {"input_tokens": 1, "output_tokens": 1},
|
||||
},
|
||||
"parent_tool_use_id": None,
|
||||
}
|
||||
)
|
||||
request_id = str(uuid.uuid4())
|
||||
pending_permission = {
|
||||
"request_id": request_id,
|
||||
"tool_use_id": tool_use_id,
|
||||
"prompt": prompt,
|
||||
}
|
||||
send(
|
||||
{
|
||||
"type": "control_request",
|
||||
"request_id": request_id,
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "write_file",
|
||||
"tool_use_id": tool_use_id,
|
||||
"input": {"path": "demo.txt", "content": "hello"},
|
||||
"permission_suggestions": [
|
||||
{"type": "allow", "label": "Allow write"}
|
||||
],
|
||||
"blocked_path": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if state["include_partial"]:
|
||||
send(
|
||||
{
|
||||
"type": "stream_event",
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"session_id": session_id,
|
||||
"event": {
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": "partial"},
|
||||
},
|
||||
"parent_tool_use_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
send(build_assistant_message(f"Echo: {prompt}"))
|
||||
send(build_result_message(f"done: {prompt}"))
|
||||
|
||||
elif msg_type == "control_response":
|
||||
payload = message.get("response", {})
|
||||
request_id = payload.get("request_id")
|
||||
|
||||
if (
|
||||
pending_unknown_control
|
||||
and request_id == pending_unknown_control["request_id"]
|
||||
):
|
||||
if payload.get("subtype") != "error":
|
||||
sys.exit(3)
|
||||
prompt = pending_unknown_control["prompt"]
|
||||
pending_unknown_control = None
|
||||
send(
|
||||
build_assistant_message(
|
||||
f"Unknown control handled for: {prompt}"
|
||||
)
|
||||
)
|
||||
send(build_result_message(f"unknown-control: {prompt}"))
|
||||
continue
|
||||
|
||||
if (
|
||||
pending_permission
|
||||
and request_id == pending_permission["request_id"]
|
||||
):
|
||||
prompt = pending_permission["prompt"]
|
||||
tool_use_id = pending_permission["tool_use_id"]
|
||||
pending_permission = None
|
||||
|
||||
behavior = "deny"
|
||||
if payload.get("subtype") == "success":
|
||||
response_payload = payload.get("response") or {}
|
||||
behavior = response_payload.get("behavior", "deny")
|
||||
|
||||
is_allowed = behavior == "allow"
|
||||
send(
|
||||
{
|
||||
"type": "user",
|
||||
"session_id": session_id,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": not is_allowed,
|
||||
"content": "ok" if is_allowed else "denied",
|
||||
}
|
||||
],
|
||||
},
|
||||
"parent_tool_use_id": tool_use_id,
|
||||
}
|
||||
)
|
||||
send(build_assistant_message(f"tool handled: {prompt}"))
|
||||
send(build_result_message(f"tool-result: {prompt}"))
|
||||
continue
|
||||
"""
|
||||
).strip()
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
script_path.chmod(script_path.stat().st_mode | stat.S_IEXEC)
|
||||
return str(script_path)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_history_expansion() -> None:
|
||||
# No-op fixture used as explicit marker for deterministic test env.
|
||||
os.environ.setdefault("PYTHONUTF8", "1")
|
||||
276
packages/sdk-python/tests/integration/test_async_query.py
Normal file
276
packages/sdk-python/tests/integration/test_async_query.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from qwen_code_sdk import (
|
||||
ProcessExitError,
|
||||
SDKUserMessage,
|
||||
is_sdk_assistant_message,
|
||||
is_sdk_partial_assistant_message,
|
||||
is_sdk_result_message,
|
||||
is_sdk_system_message,
|
||||
is_sdk_user_message,
|
||||
query,
|
||||
)
|
||||
|
||||
CONTINUED_SESSION_ID = "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
RESUME_UUID = "223e4567-e89b-12d3-a456-426614174000"
|
||||
|
||||
|
||||
async def _collect_messages(result: Any) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
async for message in result:
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
|
||||
async def _wait_for(predicate: Callable[[], bool], timeout: float = 2.0) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + timeout
|
||||
while loop.time() < deadline:
|
||||
if predicate():
|
||||
return
|
||||
await asyncio.sleep(0.01)
|
||||
raise AssertionError("timed out waiting for expected SDK state")
|
||||
|
||||
|
||||
def _tool_result_error_flag(message: dict[str, Any]) -> bool:
|
||||
content = message["message"]["content"]
|
||||
assert isinstance(content, list)
|
||||
return bool(content[0]["is_error"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_turn_query(fake_qwen_path: str) -> None:
|
||||
result = query(
|
||||
"hello world",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
},
|
||||
)
|
||||
messages = await _collect_messages(result)
|
||||
|
||||
assistant = next(
|
||||
message for message in messages if is_sdk_assistant_message(message)
|
||||
)
|
||||
final = next(message for message in messages if is_sdk_result_message(message))
|
||||
|
||||
assert assistant["message"]["content"][0]["text"] == "Echo: hello world"
|
||||
assert final["result"] == "done: hello world"
|
||||
await result.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_include_partial_messages(fake_qwen_path: str) -> None:
|
||||
result = query(
|
||||
"stream partial",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
"include_partial_messages": True,
|
||||
},
|
||||
)
|
||||
messages = await _collect_messages(result)
|
||||
|
||||
partial = next(
|
||||
message for message in messages if is_sdk_partial_assistant_message(message)
|
||||
)
|
||||
assert partial["event"]["type"] == "content_block_delta"
|
||||
await result.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_permission_callback_denies_tool_use(fake_qwen_path: str) -> None:
|
||||
result = query(
|
||||
"use tool now",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
},
|
||||
)
|
||||
messages = await _collect_messages(result)
|
||||
|
||||
tool_result = next(
|
||||
message
|
||||
for message in messages
|
||||
if is_sdk_user_message(message)
|
||||
and isinstance(message["message"]["content"], list)
|
||||
)
|
||||
assert _tool_result_error_flag(tool_result) is True
|
||||
await result.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_callback_can_allow_tool_use(fake_qwen_path: str) -> None:
|
||||
async def can_use_tool(
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
assert tool_name == "write_file"
|
||||
assert tool_input["path"] == "demo.txt"
|
||||
assert context["suggestions"][0]["type"] == "allow"
|
||||
return {"behavior": "allow", "updatedInput": tool_input}
|
||||
|
||||
result = query(
|
||||
"create file with use tool",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
"can_use_tool": can_use_tool,
|
||||
},
|
||||
)
|
||||
messages = await _collect_messages(result)
|
||||
|
||||
tool_result = next(
|
||||
message
|
||||
for message in messages
|
||||
if is_sdk_user_message(message)
|
||||
and isinstance(message["message"]["content"], list)
|
||||
)
|
||||
assert _tool_result_error_flag(tool_result) is False
|
||||
await result.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_control_requests_are_rejected(fake_qwen_path: str) -> None:
|
||||
result = query(
|
||||
"request unknown control",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
},
|
||||
)
|
||||
messages = await _collect_messages(result)
|
||||
|
||||
final = next(message for message in messages if is_sdk_result_message(message))
|
||||
assert final["result"] == "unknown-control: request unknown control"
|
||||
await result.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_controls_and_status(fake_qwen_path: str) -> None:
|
||||
release_input = asyncio.Event()
|
||||
|
||||
async def prompts() -> AsyncIterator[SDKUserMessage]:
|
||||
yield {
|
||||
"type": "user",
|
||||
"session_id": VALID_UUID,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": "first turn",
|
||||
},
|
||||
"parent_tool_use_id": None,
|
||||
}
|
||||
await release_input.wait()
|
||||
|
||||
result = query(
|
||||
prompts(),
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
"session_id": VALID_UUID,
|
||||
},
|
||||
)
|
||||
|
||||
messages: list[dict[str, Any]] = []
|
||||
|
||||
async def consume() -> list[dict[str, Any]]:
|
||||
async for message in result:
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
collector = asyncio.create_task(consume())
|
||||
await _wait_for(lambda: any(is_sdk_result_message(message) for message in messages))
|
||||
|
||||
assert await result.supported_commands() == {
|
||||
"commands": [
|
||||
"initialize",
|
||||
"interrupt",
|
||||
"set_model",
|
||||
"set_permission_mode",
|
||||
]
|
||||
}
|
||||
assert await result.mcp_server_status() == {"servers": []}
|
||||
|
||||
await result.set_model("new-model")
|
||||
await result.set_permission_mode("plan")
|
||||
release_input.set()
|
||||
await collector
|
||||
|
||||
system_messages = [
|
||||
message for message in messages if is_sdk_system_message(message)
|
||||
]
|
||||
assert any(message["model"] == "new-model" for message in system_messages)
|
||||
assert any(message["permission_mode"] == "plan" for message in system_messages)
|
||||
await result.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_id_resume_and_continue(fake_qwen_path: str) -> None:
|
||||
explicit = query(
|
||||
"hello explicit",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
"session_id": VALID_UUID,
|
||||
},
|
||||
)
|
||||
explicit_messages = await _collect_messages(explicit)
|
||||
assert explicit.get_session_id() == VALID_UUID
|
||||
assert all(message["session_id"] == VALID_UUID for message in explicit_messages)
|
||||
await explicit.close()
|
||||
|
||||
resumed = query(
|
||||
"hello resume",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
"resume": RESUME_UUID,
|
||||
},
|
||||
)
|
||||
resumed_messages = await _collect_messages(resumed)
|
||||
assert resumed.get_session_id() == RESUME_UUID
|
||||
assert all(message["session_id"] == RESUME_UUID for message in resumed_messages)
|
||||
await resumed.close()
|
||||
|
||||
continued = query(
|
||||
"hello continue",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
"continue_session": True,
|
||||
},
|
||||
)
|
||||
continued_messages = await _collect_messages(continued)
|
||||
assert continued.get_session_id() == CONTINUED_SESSION_ID
|
||||
assert any(
|
||||
message["session_id"] == CONTINUED_SESSION_ID for message in continued_messages
|
||||
)
|
||||
await continued.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_zero_process_exit_is_propagated(fake_qwen_path: str) -> None:
|
||||
result = query(
|
||||
"please exit nonzero",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ProcessExitError, match="code 9"):
|
||||
await _collect_messages(result)
|
||||
|
||||
await result.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(fake_qwen_path: str) -> None:
|
||||
async with query(
|
||||
"hello context",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
},
|
||||
) as result:
|
||||
messages = await _collect_messages(result)
|
||||
|
||||
assert result.is_closed()
|
||||
final = next(m for m in messages if is_sdk_result_message(m))
|
||||
assert final["result"] == "done: hello context"
|
||||
82
packages/sdk-python/tests/integration/test_sync_query.py
Normal file
82
packages/sdk-python/tests/integration/test_sync_query.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import qwen_code_sdk.sync_query as sync_query_module
|
||||
from qwen_code_sdk import is_sdk_result_message, query_sync
|
||||
from qwen_code_sdk.sync_query import SyncQuery
|
||||
|
||||
|
||||
def test_sync_query_single_turn(fake_qwen_path: str) -> None:
|
||||
result = query_sync(
|
||||
"hello sync",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
},
|
||||
)
|
||||
|
||||
commands = result.supported_commands()
|
||||
messages = list(result)
|
||||
|
||||
assert commands["commands"][0] == "initialize"
|
||||
assert any(
|
||||
is_sdk_result_message(message) and message["result"] == "done: hello sync"
|
||||
for message in messages
|
||||
)
|
||||
|
||||
result.close()
|
||||
result.close()
|
||||
|
||||
|
||||
def test_sync_query_bootstrap_failure_cleans_up_loop_thread(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
def raising_query(*args: object, **kwargs: object) -> object:
|
||||
raise RuntimeError("bootstrap failed")
|
||||
|
||||
monkeypatch.setattr(sync_query_module, "query", raising_query)
|
||||
|
||||
baseline_threads = {
|
||||
thread.ident
|
||||
for thread in threading.enumerate()
|
||||
if thread.name == "qwen-sdk-sync-loop"
|
||||
}
|
||||
|
||||
with pytest.raises(RuntimeError, match="bootstrap failed"):
|
||||
SyncQuery("hello")
|
||||
|
||||
deadline = time.time() + 1.0
|
||||
while time.time() < deadline:
|
||||
active_threads = {
|
||||
thread.ident
|
||||
for thread in threading.enumerate()
|
||||
if thread.name == "qwen-sdk-sync-loop"
|
||||
}
|
||||
if active_threads == baseline_threads:
|
||||
break
|
||||
time.sleep(0.01)
|
||||
|
||||
active_threads = {
|
||||
thread.ident
|
||||
for thread in threading.enumerate()
|
||||
if thread.name == "qwen-sdk-sync-loop"
|
||||
}
|
||||
assert active_threads == baseline_threads
|
||||
|
||||
|
||||
def test_sync_query_context_manager(fake_qwen_path: str) -> None:
|
||||
with query_sync(
|
||||
"hello context",
|
||||
{
|
||||
"path_to_qwen_executable": fake_qwen_path,
|
||||
},
|
||||
) as result:
|
||||
messages = list(result)
|
||||
assert any(
|
||||
is_sdk_result_message(m) and m["result"] == "done: hello context"
|
||||
for m in messages
|
||||
)
|
||||
|
||||
assert result.is_closed()
|
||||
612
packages/sdk-python/tests/unit/test_query_core.py
Normal file
612
packages/sdk-python/tests/unit/test_query_core.py
Normal file
|
|
@ -0,0 +1,612 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from qwen_code_sdk.errors import AbortError, ControlRequestTimeoutError
|
||||
from qwen_code_sdk.json_lines import parse_json_line
|
||||
from qwen_code_sdk.query import Query
|
||||
from qwen_code_sdk.types import QueryOptions, TimeoutOptions
|
||||
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
_EOF = object()
|
||||
|
||||
|
||||
class FakeTransport:
|
||||
def __init__(self) -> None:
|
||||
self.writes: list[dict[str, Any]] = []
|
||||
self.exit_error: Exception | None = None
|
||||
self.closed = False
|
||||
self.close_calls = 0
|
||||
self.input_closed = False
|
||||
self._queue: asyncio.Queue[dict[str, Any] | object] = asyncio.Queue()
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
def write(self, data: str) -> None:
|
||||
self.writes.append(parse_json_line(data))
|
||||
|
||||
async def drain(self) -> None:
|
||||
return None
|
||||
|
||||
def end_input(self) -> None:
|
||||
self.input_closed = True
|
||||
|
||||
async def read_messages(self): # type: ignore[no-untyped-def]
|
||||
while True:
|
||||
item = await self._queue.get()
|
||||
if item is _EOF:
|
||||
break
|
||||
yield item
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
self.close_calls += 1
|
||||
self.input_closed = True
|
||||
self._queue.put_nowait(_EOF)
|
||||
|
||||
def push(self, payload: dict[str, Any]) -> None:
|
||||
self._queue.put_nowait(payload)
|
||||
|
||||
|
||||
async def _wait_for(predicate: Callable[[], bool], timeout: float = 1.0) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + timeout
|
||||
while loop.time() < deadline:
|
||||
if predicate():
|
||||
return
|
||||
await asyncio.sleep(0.01)
|
||||
raise AssertionError("timed out waiting for test condition")
|
||||
|
||||
|
||||
async def _wait_for_request(
|
||||
transport: FakeTransport,
|
||||
subtype: str,
|
||||
timeout: float = 1.0,
|
||||
) -> dict[str, Any]:
|
||||
await _wait_for(
|
||||
lambda: any(
|
||||
payload.get("type") == "control_request"
|
||||
and payload.get("request", {}).get("subtype") == subtype
|
||||
for payload in transport.writes
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
for payload in transport.writes:
|
||||
if (
|
||||
payload.get("type") == "control_request"
|
||||
and payload.get("request", {}).get("subtype") == subtype
|
||||
):
|
||||
return payload
|
||||
raise AssertionError(f"missing control request: {subtype}")
|
||||
|
||||
|
||||
async def _wait_for_control_response(
|
||||
transport: FakeTransport,
|
||||
request_id: str,
|
||||
timeout: float = 1.0,
|
||||
) -> dict[str, Any]:
|
||||
await _wait_for(
|
||||
lambda: any(
|
||||
payload.get("type") == "control_response"
|
||||
and payload.get("response", {}).get("request_id") == request_id
|
||||
for payload in transport.writes
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
for payload in transport.writes:
|
||||
if (
|
||||
payload.get("type") == "control_response"
|
||||
and payload.get("response", {}).get("request_id") == request_id
|
||||
):
|
||||
return payload
|
||||
raise AssertionError(f"missing control response: {request_id}")
|
||||
|
||||
|
||||
async def _start_query(transport: FakeTransport) -> Query:
|
||||
query = Query(
|
||||
transport=transport, # type: ignore[arg-type]
|
||||
options=QueryOptions(
|
||||
timeout=TimeoutOptions(
|
||||
can_use_tool=0.05,
|
||||
control_request=0.05,
|
||||
stream_close=0.05,
|
||||
)
|
||||
),
|
||||
prompt="hello",
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
await query._ensure_started()
|
||||
|
||||
init_request = await _wait_for_request(transport, "initialize")
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": init_request["request_id"],
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
await _wait_for(
|
||||
lambda: any(payload.get("type") == "user" for payload in transport.writes)
|
||||
)
|
||||
return query
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_control_request_returns_error_response() -> None:
|
||||
transport = FakeTransport()
|
||||
query = await _start_query(transport)
|
||||
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_request",
|
||||
"request_id": "unknown-1",
|
||||
"request": {
|
||||
"subtype": "something_new",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = await _wait_for_control_response(transport, "unknown-1")
|
||||
|
||||
assert response["response"]["subtype"] == "error"
|
||||
assert "Unknown control request subtype" in response["response"]["error"]
|
||||
await query.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_request_times_out() -> None:
|
||||
transport = FakeTransport()
|
||||
query = await _start_query(transport)
|
||||
|
||||
with pytest.raises(ControlRequestTimeoutError, match="supported_commands"):
|
||||
await query.supported_commands()
|
||||
|
||||
await query.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_request_cancel_propagates_abort_error() -> None:
|
||||
transport = FakeTransport()
|
||||
query = await _start_query(transport)
|
||||
|
||||
task = asyncio.create_task(query.supported_commands())
|
||||
request = await _wait_for_request(transport, "supported_commands")
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_cancel_request",
|
||||
"request_id": request["request_id"],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(AbortError, match="Control request cancelled"):
|
||||
await task
|
||||
|
||||
await query.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_incoming_control_request_cancel_does_not_block_router() -> None:
|
||||
transport = FakeTransport()
|
||||
started = asyncio.Event()
|
||||
cancelled = asyncio.Event()
|
||||
captured_cancel_events: list[asyncio.Event] = []
|
||||
|
||||
async def can_use_tool(
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
assert tool_name == "write_file"
|
||||
assert tool_input["path"] == "demo.txt"
|
||||
cancel_event = cast(asyncio.Event, context["cancel_event"])
|
||||
captured_cancel_events.append(cancel_event)
|
||||
started.set()
|
||||
try:
|
||||
await cancel_event.wait()
|
||||
cancelled.set()
|
||||
return {"behavior": "deny", "message": "Cancelled"}
|
||||
except asyncio.CancelledError:
|
||||
if cancel_event.is_set():
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
query = Query(
|
||||
transport=transport, # type: ignore[arg-type]
|
||||
options=QueryOptions(
|
||||
can_use_tool=can_use_tool,
|
||||
timeout=TimeoutOptions(
|
||||
can_use_tool=1.0,
|
||||
control_request=0.2,
|
||||
stream_close=0.05,
|
||||
),
|
||||
),
|
||||
prompt="hello",
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
await query._ensure_started()
|
||||
|
||||
init_request = await _wait_for_request(transport, "initialize")
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": init_request["request_id"],
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
await _wait_for(
|
||||
lambda: any(payload.get("type") == "user" for payload in transport.writes)
|
||||
)
|
||||
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_request",
|
||||
"request_id": "incoming-1",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "write_file",
|
||||
"tool_use_id": "tool-1",
|
||||
"input": {"path": "demo.txt", "content": "hello"},
|
||||
"permission_suggestions": [],
|
||||
"blocked_path": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await _wait_for(lambda: started.is_set())
|
||||
assert captured_cancel_events[0] is not query._cancel_event
|
||||
|
||||
supported_commands_task = asyncio.create_task(query.supported_commands())
|
||||
supported_request = await _wait_for_request(transport, "supported_commands")
|
||||
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_cancel_request",
|
||||
"request_id": "incoming-1",
|
||||
}
|
||||
)
|
||||
await _wait_for(lambda: cancelled.is_set())
|
||||
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": supported_request["request_id"],
|
||||
"response": {"commands": ["supported_commands"]},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert await supported_commands_task == {"commands": ["supported_commands"]}
|
||||
assert all(
|
||||
not (
|
||||
payload.get("type") == "control_response"
|
||||
and payload.get("response", {}).get("request_id") == "incoming-1"
|
||||
)
|
||||
for payload in transport.writes
|
||||
)
|
||||
await query.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_request_passes_blocked_path_to_callback() -> None:
|
||||
transport = FakeTransport()
|
||||
captured_context: dict[str, Any] | None = None
|
||||
|
||||
async def can_use_tool(
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
nonlocal captured_context
|
||||
assert tool_name == "write_file"
|
||||
assert tool_input["path"] == "demo.txt"
|
||||
captured_context = context
|
||||
return {"behavior": "deny", "message": "blocked"}
|
||||
|
||||
query = Query(
|
||||
transport=transport, # type: ignore[arg-type]
|
||||
options=QueryOptions(
|
||||
can_use_tool=can_use_tool,
|
||||
timeout=TimeoutOptions(
|
||||
can_use_tool=1.0,
|
||||
control_request=0.2,
|
||||
stream_close=0.05,
|
||||
),
|
||||
),
|
||||
prompt="hello",
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
await query._ensure_started()
|
||||
|
||||
init_request = await _wait_for_request(transport, "initialize")
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": init_request["request_id"],
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
await _wait_for(
|
||||
lambda: any(payload.get("type") == "user" for payload in transport.writes)
|
||||
)
|
||||
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_request",
|
||||
"request_id": "incoming-2",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "write_file",
|
||||
"tool_use_id": "tool-2",
|
||||
"input": {"path": "demo.txt", "content": "hello"},
|
||||
"permission_suggestions": [],
|
||||
"blocked_path": "/tmp/demo.txt",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = await _wait_for_control_response(transport, "incoming-2")
|
||||
|
||||
assert captured_context is not None
|
||||
assert isinstance(captured_context["cancel_event"], asyncio.Event)
|
||||
assert captured_context["suggestions"] == []
|
||||
assert captured_context["blocked_path"] == "/tmp/demo.txt"
|
||||
assert response["response"]["subtype"] == "success"
|
||||
assert response["response"]["response"] == {
|
||||
"behavior": "deny",
|
||||
"message": "blocked",
|
||||
}
|
||||
await query.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_request_cancelled_callback_returns_deny() -> None:
|
||||
transport = FakeTransport()
|
||||
|
||||
async def can_use_tool(
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
assert tool_name == "write_file"
|
||||
assert tool_input["path"] == "demo.txt"
|
||||
assert isinstance(context["cancel_event"], asyncio.Event)
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
query = Query(
|
||||
transport=transport, # type: ignore[arg-type]
|
||||
options=QueryOptions(
|
||||
can_use_tool=can_use_tool,
|
||||
timeout=TimeoutOptions(
|
||||
can_use_tool=1.0,
|
||||
control_request=0.2,
|
||||
stream_close=0.05,
|
||||
),
|
||||
),
|
||||
prompt="hello",
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
await query._ensure_started()
|
||||
|
||||
init_request = await _wait_for_request(transport, "initialize")
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_response",
|
||||
"response": {
|
||||
"subtype": "success",
|
||||
"request_id": init_request["request_id"],
|
||||
"response": {},
|
||||
},
|
||||
}
|
||||
)
|
||||
await _wait_for(
|
||||
lambda: any(payload.get("type") == "user" for payload in transport.writes)
|
||||
)
|
||||
|
||||
transport.push(
|
||||
{
|
||||
"type": "control_request",
|
||||
"request_id": "incoming-3",
|
||||
"request": {
|
||||
"subtype": "can_use_tool",
|
||||
"tool_name": "write_file",
|
||||
"tool_use_id": "tool-3",
|
||||
"input": {"path": "demo.txt", "content": "hello"},
|
||||
"permission_suggestions": [],
|
||||
"blocked_path": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = await _wait_for_control_response(transport, "incoming-3")
|
||||
|
||||
assert response["response"]["subtype"] == "success"
|
||||
assert response["response"]["response"] == {
|
||||
"behavior": "deny",
|
||||
"message": "Permission check failed: callback cancelled",
|
||||
}
|
||||
await query.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_with_error_closes_transport_and_fails_pending_requests() -> None:
|
||||
transport = FakeTransport()
|
||||
query = await _start_query(transport)
|
||||
|
||||
supported_commands_task = asyncio.create_task(query.supported_commands())
|
||||
await _wait_for_request(transport, "supported_commands")
|
||||
|
||||
await query._finish_with_error(RuntimeError("boom"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await supported_commands_task
|
||||
|
||||
assert query.is_closed() is True
|
||||
assert transport.closed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_started_raises_after_close() -> None:
|
||||
transport = FakeTransport()
|
||||
query = await _start_query(transport)
|
||||
|
||||
await query.close()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Query is closed"):
|
||||
await query.supported_commands()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anext_after_exhaustion_raises_stop_async_iteration() -> None:
|
||||
"""After the async iterator is exhausted, subsequent __anext__ calls must
|
||||
raise StopAsyncIteration immediately instead of blocking."""
|
||||
transport = FakeTransport()
|
||||
query = await _start_query(transport)
|
||||
|
||||
# Deliver one assistant message, then a result to end the turn.
|
||||
transport.push(
|
||||
{
|
||||
"type": "assistant",
|
||||
"session_id": VALID_UUID,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hi"}],
|
||||
},
|
||||
}
|
||||
)
|
||||
transport.push(
|
||||
{
|
||||
"type": "result",
|
||||
"session_id": VALID_UUID,
|
||||
"result": "done",
|
||||
"is_error": False,
|
||||
"duration_ms": 10,
|
||||
"duration_api_ms": 5,
|
||||
"num_turns": 1,
|
||||
}
|
||||
)
|
||||
# Signal end of transport stream so the router finishes naturally.
|
||||
await transport.close()
|
||||
|
||||
# Consume all messages until exhaustion.
|
||||
messages: list[Any] = []
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
while True:
|
||||
messages.append(await query.__anext__())
|
||||
|
||||
assert len(messages) >= 1
|
||||
|
||||
# The iterator is now exhausted — a second call must raise immediately.
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await query.__anext__()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_failure_no_unhandled_task_exception(
|
||||
recwarn: pytest.WarningsChecker,
|
||||
) -> None:
|
||||
"""When _initialize fails, no 'Task exception was never retrieved' warning
|
||||
should appear — _finish_with_error already surfaces the error."""
|
||||
transport = FakeTransport()
|
||||
query = Query(
|
||||
transport=transport, # type: ignore[arg-type]
|
||||
options=QueryOptions(
|
||||
timeout=TimeoutOptions(
|
||||
can_use_tool=0.05,
|
||||
control_request=0.05,
|
||||
stream_close=0.05,
|
||||
)
|
||||
),
|
||||
prompt="hello",
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
await query._ensure_started()
|
||||
|
||||
# Let the initialize request time out — this triggers _finish_with_error
|
||||
# inside _initialize.
|
||||
init_request = await _wait_for_request(transport, "initialize")
|
||||
assert init_request is not None # init was sent
|
||||
|
||||
# Don't respond to initialize — let the control-request timeout fire.
|
||||
# The error propagates through _message_queue.
|
||||
with pytest.raises(ControlRequestTimeoutError):
|
||||
await query.__anext__()
|
||||
|
||||
await query.close()
|
||||
|
||||
# Give the event loop a moment to report any unhandled task exceptions.
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# No "Task exception was never retrieved" warnings should have appeared.
|
||||
task_warnings = [w for w in recwarn.list if "never retrieved" in str(w.message)]
|
||||
assert task_warnings == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager_closes_on_exit() -> None:
|
||||
transport = FakeTransport()
|
||||
query = Query(
|
||||
transport=transport, # type: ignore[arg-type]
|
||||
options=QueryOptions(
|
||||
timeout=TimeoutOptions(
|
||||
can_use_tool=0.05,
|
||||
control_request=0.05,
|
||||
stream_close=0.05,
|
||||
)
|
||||
),
|
||||
prompt="hello",
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
|
||||
async with query as q:
|
||||
assert q is query
|
||||
assert not q.is_closed()
|
||||
|
||||
assert query.is_closed() is True
|
||||
|
||||
|
||||
def test_sync_next_after_exhaustion_raises_stop_iteration() -> None:
|
||||
"""After the sync iterator is exhausted, subsequent __next__ calls must
|
||||
raise StopIteration immediately instead of blocking on queue.get()."""
|
||||
from queue import Queue
|
||||
|
||||
from qwen_code_sdk.sync_query import _STOP, SyncQuery
|
||||
|
||||
# Build a minimal SyncQuery without spawning the real event-loop thread.
|
||||
sq = object.__new__(SyncQuery)
|
||||
sq._queue = Queue()
|
||||
sq._exhausted = False
|
||||
|
||||
# Put one message then the sentinel.
|
||||
msg_payload = {
|
||||
"type": "assistant",
|
||||
"message": {"role": "assistant", "content": []},
|
||||
}
|
||||
sq._queue.put(msg_payload)
|
||||
sq._queue.put(_STOP)
|
||||
|
||||
# First call returns the message.
|
||||
msg = next(sq)
|
||||
assert msg["type"] == "assistant"
|
||||
|
||||
# Second call should exhaust.
|
||||
with pytest.raises(StopIteration):
|
||||
next(sq)
|
||||
|
||||
# Third call must raise immediately, not block.
|
||||
with pytest.raises(StopIteration):
|
||||
next(sq)
|
||||
291
packages/sdk-python/tests/unit/test_transport.py
Normal file
291
packages/sdk-python/tests/unit/test_transport.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from qwen_code_sdk.transport import build_cli_arguments, prepare_spawn_info
|
||||
from qwen_code_sdk.types import QueryOptions, TimeoutOptions
|
||||
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
|
||||
|
||||
class DummyProcess:
|
||||
def __init__(self) -> None:
|
||||
self.stdin = None
|
||||
self.stdout = None
|
||||
self.stderr = None
|
||||
self.returncode = 0
|
||||
|
||||
|
||||
def test_build_cli_arguments_maps_supported_options() -> None:
|
||||
args = build_cli_arguments(
|
||||
QueryOptions(
|
||||
model="qwen3-coder",
|
||||
system_prompt="system prompt",
|
||||
append_system_prompt="append prompt",
|
||||
permission_mode="auto-edit",
|
||||
max_session_turns=7,
|
||||
core_tools=["Read", "Edit"],
|
||||
exclude_tools=["Bash(rm *)"],
|
||||
allowed_tools=["Bash(git status)"],
|
||||
auth_type="openai",
|
||||
include_partial_messages=True,
|
||||
session_id=VALID_UUID,
|
||||
)
|
||||
)
|
||||
|
||||
assert args == [
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--channel=SDK",
|
||||
"--model",
|
||||
"qwen3-coder",
|
||||
"--system-prompt",
|
||||
"system prompt",
|
||||
"--append-system-prompt",
|
||||
"append prompt",
|
||||
"--approval-mode",
|
||||
"auto-edit",
|
||||
"--max-session-turns",
|
||||
"7",
|
||||
"--core-tools",
|
||||
"Read,Edit",
|
||||
"--exclude-tools",
|
||||
"Bash(rm *)",
|
||||
"--allowed-tools",
|
||||
"Bash(git status)",
|
||||
"--auth-type",
|
||||
"openai",
|
||||
"--include-partial-messages",
|
||||
"--session-id",
|
||||
VALID_UUID,
|
||||
]
|
||||
|
||||
|
||||
def test_cli_argument_precedence_prefers_resume_then_continue_then_session_id() -> None:
|
||||
args = build_cli_arguments(
|
||||
QueryOptions(
|
||||
resume=VALID_UUID,
|
||||
continue_session=True,
|
||||
session_id="223e4567-e89b-12d3-a456-426614174000",
|
||||
)
|
||||
)
|
||||
|
||||
assert "--resume" in args
|
||||
assert "--continue" not in args
|
||||
assert "--session-id" not in args
|
||||
|
||||
|
||||
def test_prepare_spawn_info_uses_runtime_for_python_scripts(tmp_path: Path) -> None:
|
||||
script_path = tmp_path / "fake-qwen.py"
|
||||
script_path.write_text("print('ok')\n", encoding="utf-8")
|
||||
|
||||
spawn_info = prepare_spawn_info(str(script_path))
|
||||
|
||||
assert spawn_info.command == sys.executable
|
||||
assert spawn_info.args == [str(script_path.resolve())]
|
||||
|
||||
|
||||
def test_prepare_spawn_info_uses_node_for_javascript_files(tmp_path: Path) -> None:
|
||||
script_path = tmp_path / "fake-qwen.js"
|
||||
script_path.write_text("console.log('ok');\n", encoding="utf-8")
|
||||
|
||||
spawn_info = prepare_spawn_info(str(script_path))
|
||||
|
||||
assert spawn_info.command == "node"
|
||||
assert spawn_info.args == [str(script_path.resolve())]
|
||||
|
||||
|
||||
def test_prepare_spawn_info_keeps_plain_command_names() -> None:
|
||||
spawn_info = prepare_spawn_info("qwen-custom")
|
||||
|
||||
assert spawn_info.command == "qwen-custom"
|
||||
assert spawn_info.args == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_discards_stderr_when_debug_is_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def fake_create_subprocess_exec(*args: Any, **kwargs: Any) -> DummyProcess:
|
||||
captured["args"] = args
|
||||
captured["kwargs"] = kwargs
|
||||
return DummyProcess()
|
||||
|
||||
monkeypatch.setattr(
|
||||
asyncio,
|
||||
"create_subprocess_exec",
|
||||
fake_create_subprocess_exec,
|
||||
)
|
||||
|
||||
transport_module = __import__(
|
||||
"qwen_code_sdk.transport",
|
||||
fromlist=["ProcessTransport"],
|
||||
)
|
||||
transport = transport_module.ProcessTransport(
|
||||
QueryOptions(timeout=TimeoutOptions())
|
||||
)
|
||||
|
||||
await transport.start()
|
||||
|
||||
assert captured["kwargs"]["stderr"] is subprocess.DEVNULL
|
||||
|
||||
|
||||
def test_prepare_spawn_info_defaults_to_qwen_when_none() -> None:
|
||||
spawn_info = prepare_spawn_info(None)
|
||||
|
||||
assert spawn_info.command == "qwen"
|
||||
assert spawn_info.args == []
|
||||
|
||||
|
||||
def test_prepare_spawn_info_uses_node_for_mjs_files(tmp_path: Path) -> None:
|
||||
script_path = tmp_path / "cli.mjs"
|
||||
script_path.write_text("export default {};\n", encoding="utf-8")
|
||||
|
||||
spawn_info = prepare_spawn_info(str(script_path))
|
||||
|
||||
assert spawn_info.command == "node"
|
||||
assert spawn_info.args == [str(script_path.resolve())]
|
||||
|
||||
|
||||
def test_prepare_spawn_info_uses_node_for_cjs_files(tmp_path: Path) -> None:
|
||||
script_path = tmp_path / "cli.cjs"
|
||||
script_path.write_text("module.exports = {};\n", encoding="utf-8")
|
||||
|
||||
spawn_info = prepare_spawn_info(str(script_path))
|
||||
|
||||
assert spawn_info.command == "node"
|
||||
assert spawn_info.args == [str(script_path.resolve())]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_start_raises_after_close(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
async def fake_create_subprocess_exec(*args: Any, **kwargs: Any) -> DummyProcess:
|
||||
return DummyProcess()
|
||||
|
||||
monkeypatch.setattr(
|
||||
asyncio,
|
||||
"create_subprocess_exec",
|
||||
fake_create_subprocess_exec,
|
||||
)
|
||||
|
||||
transport_module = __import__(
|
||||
"qwen_code_sdk.transport",
|
||||
fromlist=["ProcessTransport"],
|
||||
)
|
||||
transport = transport_module.ProcessTransport(
|
||||
QueryOptions(timeout=TimeoutOptions())
|
||||
)
|
||||
transport._closed = True
|
||||
|
||||
with pytest.raises(RuntimeError, match="Transport is closed"):
|
||||
await transport.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_messages_skips_malformed_json_lines() -> None:
|
||||
"""Malformed JSON lines should be skipped, not crash the stream."""
|
||||
|
||||
class FakeStdout:
|
||||
def __init__(self, lines: list[bytes]) -> None:
|
||||
self._lines = iter(lines)
|
||||
|
||||
async def readline(self) -> bytes:
|
||||
return next(self._lines, b"")
|
||||
|
||||
transport_module = __import__(
|
||||
"qwen_code_sdk.transport",
|
||||
fromlist=["ProcessTransport"],
|
||||
)
|
||||
transport = transport_module.ProcessTransport(
|
||||
QueryOptions(timeout=TimeoutOptions())
|
||||
)
|
||||
|
||||
class FakeProcess:
|
||||
returncode = 0
|
||||
stdin = None
|
||||
stderr = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.stdout = FakeStdout(
|
||||
[
|
||||
b"not valid json\n",
|
||||
b'{"type":"system","subtype":"init","uuid":"u","session_id":"s"}\n',
|
||||
b"also bad\n",
|
||||
b"",
|
||||
]
|
||||
)
|
||||
|
||||
async def wait(self) -> int:
|
||||
return 0
|
||||
|
||||
transport._process = FakeProcess()
|
||||
|
||||
messages: list[Any] = []
|
||||
async for msg in transport.read_messages():
|
||||
messages.append(msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["type"] == "system"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stderr_callback_exceptions_do_not_fail_transport() -> None:
|
||||
class FakeStdout:
|
||||
async def readline(self) -> bytes:
|
||||
return b""
|
||||
|
||||
class FakeStderr:
|
||||
def __init__(self) -> None:
|
||||
self._lines = iter([b"error message\n", b""])
|
||||
|
||||
async def readline(self) -> bytes:
|
||||
return next(self._lines, b"")
|
||||
|
||||
transport_module = __import__(
|
||||
"qwen_code_sdk.transport",
|
||||
fromlist=["ProcessTransport"],
|
||||
)
|
||||
|
||||
callback_calls = 0
|
||||
|
||||
def stderr_callback(text: str) -> None:
|
||||
nonlocal callback_calls
|
||||
callback_calls += 1
|
||||
assert text == "error message"
|
||||
raise RuntimeError("sink failed")
|
||||
|
||||
transport = transport_module.ProcessTransport(
|
||||
QueryOptions(
|
||||
stderr=stderr_callback,
|
||||
timeout=TimeoutOptions(),
|
||||
)
|
||||
)
|
||||
|
||||
class FakeProcess:
|
||||
returncode = 0
|
||||
stdin = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.stdout = FakeStdout()
|
||||
self.stderr = FakeStderr()
|
||||
|
||||
async def wait(self) -> int:
|
||||
return 0
|
||||
|
||||
transport._process = FakeProcess()
|
||||
transport._stderr_task = asyncio.create_task(transport._forward_stderr())
|
||||
|
||||
await transport.wait_for_exit()
|
||||
|
||||
assert callback_calls == 1
|
||||
174
packages/sdk-python/tests/unit/test_validation.py
Normal file
174
packages/sdk-python/tests/unit/test_validation.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from qwen_code_sdk.errors import ValidationError
|
||||
from qwen_code_sdk.types import QueryOptions, TimeoutOptions
|
||||
from qwen_code_sdk.validation import validate_query_options
|
||||
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
|
||||
|
||||
def test_rejects_resume_with_continue_session() -> None:
|
||||
with pytest.raises(ValidationError, match="resume together with continue_session"):
|
||||
validate_query_options(
|
||||
QueryOptions(
|
||||
resume=VALID_UUID,
|
||||
continue_session=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_rejects_session_id_with_resume() -> None:
|
||||
with pytest.raises(ValidationError, match="Cannot use session_id with resume"):
|
||||
validate_query_options(
|
||||
QueryOptions(
|
||||
session_id=VALID_UUID,
|
||||
resume="223e4567-e89b-12d3-a456-426614174000",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_rejects_invalid_session_id() -> None:
|
||||
with pytest.raises(ValidationError, match="Invalid session_id"):
|
||||
validate_query_options(QueryOptions(session_id="not-a-uuid"))
|
||||
|
||||
|
||||
def test_rejects_invalid_resume() -> None:
|
||||
with pytest.raises(ValidationError, match="Invalid resume"):
|
||||
validate_query_options(QueryOptions(resume="not-a-uuid"))
|
||||
|
||||
|
||||
def test_rejects_invalid_permission_mode() -> None:
|
||||
with pytest.raises(ValidationError, match="Invalid permission_mode"):
|
||||
validate_query_options(
|
||||
QueryOptions.from_mapping({"permission_mode": "unsafe-mode"})
|
||||
)
|
||||
|
||||
|
||||
def test_rejects_invalid_auth_type() -> None:
|
||||
with pytest.raises(ValidationError, match="Invalid auth_type"):
|
||||
validate_query_options(QueryOptions.from_mapping({"auth_type": "custom"}))
|
||||
|
||||
|
||||
def test_from_mapping_rejects_non_callable_can_use_tool() -> None:
|
||||
with pytest.raises(TypeError, match="can_use_tool must be callable"):
|
||||
QueryOptions.from_mapping({"can_use_tool": "bad"})
|
||||
|
||||
|
||||
def test_from_mapping_rejects_non_callable_stderr() -> None:
|
||||
with pytest.raises(TypeError, match="stderr must be callable"):
|
||||
QueryOptions.from_mapping({"stderr": "bad"})
|
||||
|
||||
|
||||
def test_validation_rejects_non_callable_can_use_tool() -> None:
|
||||
with pytest.raises(ValidationError, match="can_use_tool must be callable"):
|
||||
validate_query_options(QueryOptions(can_use_tool=cast(Any, "bad")))
|
||||
|
||||
|
||||
def test_validation_rejects_non_callable_stderr() -> None:
|
||||
with pytest.raises(ValidationError, match="stderr must be callable"):
|
||||
validate_query_options(QueryOptions(stderr=cast(Any, "bad")))
|
||||
|
||||
|
||||
def test_from_mapping_rejects_sync_can_use_tool() -> None:
|
||||
def can_use_tool( # type: ignore[no-untyped-def]
|
||||
tool_name, tool_input, context
|
||||
):
|
||||
return {"behavior": "deny", "message": "bad"}
|
||||
|
||||
with pytest.raises(TypeError, match="can_use_tool must be an async callable"):
|
||||
QueryOptions.from_mapping({"can_use_tool": can_use_tool})
|
||||
|
||||
|
||||
def test_validation_rejects_sync_can_use_tool() -> None:
|
||||
def can_use_tool( # type: ignore[no-untyped-def]
|
||||
tool_name, tool_input, context
|
||||
):
|
||||
return {"behavior": "deny", "message": "bad"}
|
||||
|
||||
with pytest.raises(ValidationError, match="can_use_tool must be an async callable"):
|
||||
validate_query_options(QueryOptions(can_use_tool=cast(Any, can_use_tool)))
|
||||
|
||||
|
||||
def test_from_mapping_rejects_can_use_tool_with_wrong_arity() -> None:
|
||||
async def can_use_tool(
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
) -> dict[str, str]:
|
||||
return {"behavior": "deny"}
|
||||
|
||||
with pytest.raises(
|
||||
TypeError, match="can_use_tool must accept exactly 3 positional arguments"
|
||||
):
|
||||
QueryOptions.from_mapping({"can_use_tool": can_use_tool})
|
||||
|
||||
|
||||
def test_validation_rejects_can_use_tool_with_wrong_arity() -> None:
|
||||
async def can_use_tool(
|
||||
tool_name: str,
|
||||
tool_input: dict[str, Any],
|
||||
) -> dict[str, str]:
|
||||
return {"behavior": "deny"}
|
||||
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match="can_use_tool must accept exactly 3 positional arguments",
|
||||
):
|
||||
validate_query_options(QueryOptions(can_use_tool=cast(Any, can_use_tool)))
|
||||
|
||||
|
||||
def test_from_mapping_rejects_stderr_with_wrong_arity() -> None:
|
||||
def stderr() -> None:
|
||||
return None
|
||||
|
||||
with pytest.raises(
|
||||
TypeError, match="stderr must accept exactly 1 positional argument"
|
||||
):
|
||||
QueryOptions.from_mapping({"stderr": stderr})
|
||||
|
||||
|
||||
def test_validation_rejects_stderr_with_wrong_arity() -> None:
|
||||
def stderr() -> None:
|
||||
return None
|
||||
|
||||
with pytest.raises(
|
||||
ValidationError, match="stderr must accept exactly 1 positional argument"
|
||||
):
|
||||
validate_query_options(QueryOptions(stderr=cast(Any, stderr)))
|
||||
|
||||
|
||||
def test_rejects_invalid_max_session_turns() -> None:
|
||||
with pytest.raises(ValidationError, match="max_session_turns"):
|
||||
validate_query_options(QueryOptions(max_session_turns=-2))
|
||||
|
||||
|
||||
def test_rejects_empty_qwen_executable_path() -> None:
|
||||
with pytest.raises(
|
||||
ValidationError, match="path_to_qwen_executable cannot be empty"
|
||||
):
|
||||
validate_query_options(QueryOptions(path_to_qwen_executable=" "))
|
||||
|
||||
|
||||
def test_timeout_rejects_non_numeric_value() -> None:
|
||||
with pytest.raises(TypeError, match=r"timeout\.can_use_tool must be a positive"):
|
||||
TimeoutOptions.from_mapping({"can_use_tool": "fast"})
|
||||
|
||||
|
||||
def test_timeout_rejects_negative_value() -> None:
|
||||
pattern = r"timeout\.control_request must be a positive"
|
||||
with pytest.raises(ValueError, match=pattern):
|
||||
TimeoutOptions.from_mapping({"control_request": -1})
|
||||
|
||||
|
||||
def test_timeout_rejects_boolean_value() -> None:
|
||||
with pytest.raises(TypeError, match=r"timeout\.stream_close must be a positive"):
|
||||
TimeoutOptions.from_mapping({"stream_close": True})
|
||||
|
||||
|
||||
def test_rejects_mcp_servers() -> None:
|
||||
with pytest.raises(ValidationError, match="mcp_servers is not supported"):
|
||||
validate_query_options(
|
||||
QueryOptions(mcp_servers={"my-server": {"command": "node", "args": []}})
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue