Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2026-04-28 11:49:51 +00:00)
Parent: e7d277d163
Commit: d8046e1bb4
65 changed files with 12111 additions and 2502 deletions

3  kt-kernel/python/cli/commands/__init__.py  Normal file
@@ -0,0 +1,3 @@
"""
Command modules for kt-cli.
"""

274  kt-kernel/python/cli/commands/bench.py  Normal file
@@ -0,0 +1,274 @@
"""
Bench commands for kt-cli.

Runs benchmarks for performance testing.
"""

import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import typer

from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    console,
    print_error,
    print_info,
    print_step,
    print_success,
)


class BenchType(str, Enum):
    """Benchmark type."""

    INFERENCE = "inference"
    MLA = "mla"
    MOE = "moe"
    LINEAR = "linear"
    ATTENTION = "attention"
    ALL = "all"


def bench(
    type: BenchType = typer.Option(
        BenchType.ALL,
        "--type",
        "-t",
        help="Benchmark type",
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model to benchmark",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="Output file for results (JSON)",
    ),
    iterations: int = typer.Option(
        10,
        "--iterations",
        "-n",
        help="Number of iterations",
    ),
) -> None:
    """Run full benchmark suite."""
    console.print()
    print_step(t("bench_starting"))
    print_info(t("bench_type", type=type.value))
    console.print()

    if type == BenchType.ALL:
        _run_all_benchmarks(model, output, iterations)
    elif type == BenchType.INFERENCE:
        _run_inference_benchmark(model, output, iterations)
    elif type == BenchType.MLA:
        _run_component_benchmark("mla", output, iterations)
    elif type == BenchType.MOE:
        _run_component_benchmark("moe", output, iterations)
    elif type == BenchType.LINEAR:
        _run_component_benchmark("linear", output, iterations)
    elif type == BenchType.ATTENTION:
        _run_component_benchmark("attention", output, iterations)

    console.print()
    print_success(t("bench_complete"))
    if output:
        console.print(f"  Results saved to: {output}")
    console.print()


def microbench(
    component: str = typer.Argument(
        "moe",
        help="Component to benchmark (moe, mla, linear, attention)",
    ),
    batch_size: int = typer.Option(
        1,
        "--batch-size",
        "-b",
        help="Batch size",
    ),
    seq_len: int = typer.Option(
        1,
        "--seq-len",
        "-s",
        help="Sequence length",
    ),
    iterations: int = typer.Option(
        100,
        "--iterations",
        "-n",
        help="Number of iterations",
    ),
    warmup: int = typer.Option(
        10,
        "--warmup",
        "-w",
        help="Warmup iterations",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="Output file for results (JSON)",
    ),
) -> None:
    """Run micro-benchmark for specific components."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)

    # Try to find the benchmark script
    kt_kernel_path = _find_kt_kernel_path()

    if kt_kernel_path is None:
        print_error("kt-kernel not found. Install with: kt install inference")
        raise typer.Exit(1)

    bench_dir = kt_kernel_path / "bench"

    # Map component to script
    component_scripts = {
        "moe": "bench_moe.py",
        "mla": "bench_mla.py",
        "linear": "bench_linear.py",
        "attention": "bench_attention.py",
        "mlp": "bench_mlp.py",
    }

    script_name = component_scripts.get(component.lower())
    if script_name is None:
        print_error(f"Unknown component: {component}")
        console.print(f"Available: {', '.join(component_scripts.keys())}")
        raise typer.Exit(1)

    script_path = bench_dir / script_name
    if not script_path.exists():
        print_error(f"Benchmark script not found: {script_path}")
        raise typer.Exit(1)

    # Run benchmark
    cmd = [
        sys.executable,
        str(script_path),
        "--batch-size",
        str(batch_size),
        "--seq-len",
        str(seq_len),
        "--iterations",
        str(iterations),
        "--warmup",
        str(warmup),
    ]

    if output:
        cmd.extend(["--output", str(output)])

    console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
    console.print()

    try:
        process = subprocess.run(cmd)

        if process.returncode == 0:
            console.print()
            print_success(t("bench_complete"))
            if output:
                console.print(f"  Results saved to: {output}")
        else:
            print_error(f"Benchmark failed with exit code {process.returncode}")
            raise typer.Exit(process.returncode)

    except FileNotFoundError as e:
        print_error(f"Failed to run benchmark: {e}")
        raise typer.Exit(1)


def _find_kt_kernel_path() -> Optional[Path]:
    """Find the kt-kernel installation path."""
    try:
        import kt_kernel

        return Path(kt_kernel.__file__).parent.parent
    except ImportError:
        pass

    # Check common locations
    possible_paths = [
        Path.home() / "Projects" / "ktransformers" / "kt-kernel",
        Path("/opt/ktransformers/kt-kernel"),
        Path.cwd() / "kt-kernel",
    ]

    for path in possible_paths:
        if path.exists() and (path / "bench").exists():
            return path

    return None


def _run_all_benchmarks(model: Optional[str], output: Optional[Path], iterations: int) -> None:
    """Run all benchmarks."""
    components = ["moe", "mla", "linear", "attention"]

    for component in components:
        console.print(f"\n[bold]Running {component} benchmark...[/bold]")
        _run_component_benchmark(component, None, iterations)


def _run_inference_benchmark(model: Optional[str], output: Optional[Path], iterations: int) -> None:
    """Run inference benchmark."""
    if model is None:
        print_error("Model required for inference benchmark. Use --model flag.")
        raise typer.Exit(1)

    print_info(f"Running inference benchmark on {model}...")
    console.print()
    console.print("[dim]This will start the server and run test requests.[/dim]")
    console.print()

    # TODO: Implement actual inference benchmarking
    print_error("Inference benchmarking not yet implemented.")


def _run_component_benchmark(component: str, output: Optional[Path], iterations: int) -> None:
    """Run a component benchmark."""
    kt_kernel_path = _find_kt_kernel_path()

    if kt_kernel_path is None:
        print_error("kt-kernel not found.")
        return

    bench_dir = kt_kernel_path / "bench"
    script_map = {
        "moe": "bench_moe.py",
        "mla": "bench_mla.py",
        "linear": "bench_linear.py",
        "attention": "bench_attention.py",
    }

    script_name = script_map.get(component)
    if script_name is None:
        print_error(f"Unknown component: {component}")
        return

    script_path = bench_dir / script_name
    if not script_path.exists():
        print_error(f"Script not found: {script_path}")
        return

    cmd = [sys.executable, str(script_path), "--iterations", str(iterations)]

    try:
        subprocess.run(cmd)
    except Exception as e:
        print_error(f"Benchmark failed: {e}")
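
For context, a minimal sketch of how bench and microbench could be wired into the `kt` entry point. The Typer registration below is an assumption for illustration; this commit only defines the command functions, not the app wiring.

# Hypothetical wiring (not part of this commit): register bench/microbench
# on a Typer app so they are reachable as `kt bench` / `kt microbench`.
import typer

from kt_kernel.cli.commands.bench import bench, microbench

app = typer.Typer()
app.command(name="bench")(bench)
app.command(name="microbench")(microbench)

if __name__ == "__main__":
    app()  # e.g. invoked as: bench --type moe -n 20 -o results.json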

437  kt-kernel/python/cli/commands/chat.py  Normal file
@@ -0,0 +1,437 @@
"""
Chat command for kt-cli.

Provides interactive chat interface with running model server.
"""

import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Prompt, Confirm

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    console,
    print_error,
    print_info,
    print_success,
    print_warning,
)

# Try to import OpenAI SDK
try:
    from openai import OpenAI

    HAS_OPENAI = True
except ImportError:
    HAS_OPENAI = False


def chat(
    host: Optional[str] = typer.Option(
        None,
        "--host",
        "-H",
        help="Server host address",
    ),
    port: Optional[int] = typer.Option(
        None,
        "--port",
        "-p",
        help="Server port",
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model name (if server hosts multiple models)",
    ),
    temperature: float = typer.Option(
        0.7,
        "--temperature",
        "-t",
        help="Sampling temperature (0.0 to 2.0)",
    ),
    max_tokens: int = typer.Option(
        2048,
        "--max-tokens",
        help="Maximum tokens to generate",
    ),
    system_prompt: Optional[str] = typer.Option(
        None,
        "--system",
        "-s",
        help="System prompt",
    ),
    save_history: bool = typer.Option(
        True,
        "--save-history/--no-save-history",
        help="Save conversation history",
    ),
    history_file: Optional[Path] = typer.Option(
        None,
        "--history-file",
        help="Path to save conversation history",
    ),
    stream: bool = typer.Option(
        True,
        "--stream/--no-stream",
        help="Enable streaming output",
    ),
) -> None:
    """Start interactive chat with a running model server.

    Examples:
        kt chat                          # Connect to default server
        kt chat --host 127.0.0.1 -p 8080 # Connect to specific server
        kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters
    """
    if not HAS_OPENAI:
        print_error("OpenAI Python SDK is required for chat functionality.")
        console.print()
        console.print("Install it with:")
        console.print("  pip install openai")
        raise typer.Exit(1)

    settings = get_settings()

    # Resolve server connection
    final_host = host or settings.get("server.host", "127.0.0.1")
    final_port = port or settings.get("server.port", 30000)

    # Construct base URL for OpenAI-compatible API
    base_url = f"http://{final_host}:{final_port}/v1"

    console.print()
    console.print(
        Panel.fit(
            f"[bold cyan]KTransformers Chat[/bold cyan]\n\n"
            f"Server: [yellow]{final_host}:{final_port}[/yellow]\n"
            f"Temperature: [cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n"
            f"[dim]Type '/help' for commands, '/quit' to exit[/dim]",
            border_style="cyan",
        )
    )
    console.print()

    # Check for proxy environment variables
    proxy_vars = ["HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy", "ALL_PROXY", "all_proxy"]
    detected_proxies = {var: os.environ.get(var) for var in proxy_vars if os.environ.get(var)}

    if detected_proxies:
        proxy_info = ", ".join(f"{k}={v}" for k, v in detected_proxies.items())
        console.print()
        print_warning(t("chat_proxy_detected"))
        console.print(f"  [dim]{proxy_info}[/dim]")
        console.print()

        use_proxy = Confirm.ask(t("chat_proxy_confirm"), default=False)

        if not use_proxy:
            # Temporarily disable proxy for this connection
            for var in proxy_vars:
                if var in os.environ:
                    del os.environ[var]
            print_info(t("chat_proxy_disabled"))
            console.print()

    # Initialize OpenAI client
    try:
        client = OpenAI(
            base_url=base_url,
            api_key="EMPTY",  # SGLang doesn't require API key
        )

        # Test connection
        print_info("Connecting to server...")
        models = client.models.list()
        available_models = [m.id for m in models.data]

        if not available_models:
            print_error("No models available on server")
            raise typer.Exit(1)

        # Select model
        if model:
            if model not in available_models:
                print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}")
                selected_model = available_models[0]
            else:
                selected_model = model
        else:
            selected_model = available_models[0]

        print_success(f"Connected to model: {selected_model}")
        console.print()

    except Exception as e:
        print_error(f"Failed to connect to server: {e}")
        console.print()
        console.print("Make sure the model server is running:")
        console.print("  kt run <model>")
        raise typer.Exit(1)

    # Initialize conversation history
    messages = []

    # Add system prompt if provided
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Setup history file
    if save_history:
        if history_file is None:
            history_dir = settings.config_dir / "chat_history"
            history_dir.mkdir(parents=True, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            history_file = history_dir / f"chat_{timestamp}.json"
        else:
            history_file = Path(history_file)
            history_file.parent.mkdir(parents=True, exist_ok=True)

    # Main chat loop
    try:
        while True:
            # Get user input
            try:
                user_input = Prompt.ask("[bold green]You[/bold green]")
            except (EOFError, KeyboardInterrupt):
                console.print()
                print_info("Goodbye!")
                break

            if not user_input.strip():
                continue

            # Handle special commands
            if user_input.startswith("/"):
                if _handle_command(user_input, messages, temperature, max_tokens):
                    continue
                else:
                    break  # Exit command

            # Add user message to history
            messages.append({"role": "user", "content": user_input})

            # Generate response
            console.print()
            console.print("[bold cyan]Assistant[/bold cyan]")

            try:
                if stream:
                    # Streaming response
                    response_content = _stream_response(client, selected_model, messages, temperature, max_tokens)
                else:
                    # Non-streaming response
                    response_content = _generate_response(client, selected_model, messages, temperature, max_tokens)

                # Add assistant response to history
                messages.append({"role": "assistant", "content": response_content})

                console.print()

            except Exception as e:
                print_error(f"Error generating response: {e}")
                # Remove the user message that caused the error
                messages.pop()
                continue

            # Save history if enabled
            if save_history:
                _save_history(history_file, messages, selected_model)

    except KeyboardInterrupt:
        console.print()
        console.print()
        print_info("Chat interrupted. Goodbye!")

    # Final history save
    if save_history and messages:
        _save_history(history_file, messages, selected_model)
        console.print(f"[dim]History saved to: {history_file}[/dim]")
    console.print()


def _stream_response(
    client: "OpenAI",
    model: str,
    messages: list,
    temperature: float,
    max_tokens: int,
) -> str:
    """Generate streaming response and display in real-time."""
    response_content = ""

    try:
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )

        for chunk in stream:
            if chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                response_content += content
                console.print(content, end="")

        console.print()  # Newline after streaming

    except Exception as e:
        raise Exception(f"Streaming error: {e}")

    return response_content


def _generate_response(
    client: "OpenAI",
    model: str,
    messages: list,
    temperature: float,
    max_tokens: int,
) -> str:
    """Generate non-streaming response."""
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
        )

        content = response.choices[0].message.content

        # Display as markdown
        md = Markdown(content)
        console.print(md)

        return content

    except Exception as e:
        raise Exception(f"Generation error: {e}")


def _handle_command(command: str, messages: list, temperature: float, max_tokens: int) -> bool:
    """Handle special commands. Returns True to continue chat, False to exit."""
    cmd = command.lower().strip()

    if cmd in ["/quit", "/exit", "/q"]:
        console.print()
        print_info("Goodbye!")
        return False

    elif cmd in ["/help", "/h"]:
        console.print()
        console.print(
            Panel(
                "[bold]Available Commands:[/bold]\n\n"
                "/help, /h - Show this help message\n"
                "/quit, /exit, /q - Exit chat\n"
                "/clear, /c - Clear conversation history\n"
                "/history, /hist - Show conversation history\n"
                "/info, /i - Show current settings\n"
                "/retry, /r - Regenerate last response",
                title="Help",
                border_style="cyan",
            )
        )
        console.print()
        return True

    elif cmd in ["/clear", "/c"]:
        messages.clear()
        console.print()
        print_success("Conversation history cleared")
        console.print()
        return True

    elif cmd in ["/history", "/hist"]:
        console.print()
        if not messages:
            print_info("No conversation history")
        else:
            console.print(
                Panel(
                    _format_history(messages),
                    title=f"History ({len(messages)} messages)",
                    border_style="cyan",
                )
            )
        console.print()
        return True

    elif cmd in ["/info", "/i"]:
        console.print()
        console.print(
            Panel(
                f"[bold]Current Settings:[/bold]\n\n"
                f"Temperature: [cyan]{temperature}[/cyan]\n"
                f"Max tokens: [cyan]{max_tokens}[/cyan]\n"
                f"Messages: [cyan]{len(messages)}[/cyan]",
                title="Info",
                border_style="cyan",
            )
        )
        console.print()
        return True

    elif cmd in ["/retry", "/r"]:
        if len(messages) >= 2 and messages[-1]["role"] == "assistant":
            # Remove last assistant response
            messages.pop()
            print_info("Retrying last response...")
            console.print()
        else:
            print_warning("No previous response to retry")
            console.print()
        return True

    else:
        print_warning(f"Unknown command: {command}")
        console.print("[dim]Type /help for available commands[/dim]")
        console.print()
        return True


def _format_history(messages: list) -> str:
    """Format conversation history for display."""
    lines = []
    for i, msg in enumerate(messages, 1):
        role = msg["role"].capitalize()
        content = msg["content"]

        # Truncate long messages
        if len(content) > 200:
            content = content[:200] + "..."

        lines.append(f"[bold]{i}. {role}:[/bold] {content}")

    return "\n\n".join(lines)


def _save_history(file_path: Path, messages: list, model: str) -> None:
    """Save conversation history to file."""
    try:
        history_data = {
            "model": model,
            "timestamp": datetime.now().isoformat(),
            "messages": messages,
        }

        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(history_data, f, indent=2, ensure_ascii=False)

    except Exception as e:
        print_warning(f"Failed to save history: {e}")
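
The chat loop above is essentially the OpenAI SDK pointed at a local OpenAI-compatible endpoint. A minimal standalone sketch of the same connect-and-stream pattern, assuming a server already started (host/port are the code's defaults, the prompt is illustrative):

# Standalone sketch of chat()'s core pattern (values illustrative).
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
model_id = client.models.list().data[0].id  # first model the server reports

stream = client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.7,
    max_tokens=128,
    stream=True,
)
for chunk in stream:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)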

167  kt-kernel/python/cli/commands/config.py  Normal file
@@ -0,0 +1,167 @@
"""
Config command for kt-cli.

Manages kt-cli configuration.
"""

from typing import Optional

import typer
import yaml
from rich.syntax import Syntax

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import confirm, console, print_error, print_success

app = typer.Typer(help="Manage kt-cli configuration")


@app.command(name="init")
def init() -> None:
    """Initialize or re-run the first-time setup wizard."""
    from kt_kernel.cli.main import _show_first_run_setup
    from kt_kernel.cli.config.settings import get_settings

    settings = get_settings()
    _show_first_run_setup(settings)


@app.command(name="show")
def show(
    key: Optional[str] = typer.Argument(None, help="Configuration key to show (e.g., server.port)"),
) -> None:
    """Show current configuration."""
    settings = get_settings()

    if key:
        value = settings.get(key)
        if value is not None:
            if isinstance(value, (dict, list)):
                console.print(yaml.dump({key: value}, default_flow_style=False, allow_unicode=True))
            else:
                console.print(t("config_get_value", key=key, value=value))
        else:
            print_error(t("config_get_not_found", key=key))
            raise typer.Exit(1)
    else:
        console.print(f"\n[bold]{t('config_show_title')}[/bold]\n")
        console.print(f"[dim]{t('config_file_location', path=str(settings.config_path))}[/dim]\n")

        config_yaml = yaml.dump(settings.get_all(), default_flow_style=False, allow_unicode=True)
        syntax = Syntax(config_yaml, "yaml", theme="monokai", line_numbers=False)
        console.print(syntax)


@app.command(name="set")
def set_config(
    key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"),
    value: str = typer.Argument(..., help="Value to set"),
) -> None:
    """Set a configuration value."""
    settings = get_settings()

    # Try to parse value as JSON/YAML for complex types
    parsed_value = _parse_value(value)

    settings.set(key, parsed_value)
    print_success(t("config_set_success", key=key, value=parsed_value))


@app.command(name="get")
def get_config(
    key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"),
) -> None:
    """Get a configuration value."""
    settings = get_settings()
    value = settings.get(key)

    if value is not None:
        if isinstance(value, (dict, list)):
            console.print(yaml.dump(value, default_flow_style=False, allow_unicode=True))
        else:
            console.print(str(value))
    else:
        print_error(t("config_get_not_found", key=key))
        raise typer.Exit(1)


@app.command(name="reset")
def reset(
    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation"),
) -> None:
    """Reset configuration to defaults."""
    if not yes:
        if not confirm(t("config_reset_confirm"), default=False):
            raise typer.Abort()

    settings = get_settings()
    settings.reset()
    print_success(t("config_reset_success"))


@app.command(name="path")
def path() -> None:
    """Show configuration file path."""
    settings = get_settings()
    console.print(str(settings.config_path))


@app.command(name="model-path-list", deprecated=True, hidden=True)
def model_path_list() -> None:
    """[Deprecated] Use 'kt model path-list' instead."""
    console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-list' instead.[/yellow]\n")
    import subprocess

    subprocess.run(["kt", "model", "path-list"])


@app.command(name="model-path-add", deprecated=True, hidden=True)
def model_path_add(
    path: str = typer.Argument(..., help="Path to add"),
) -> None:
    """[Deprecated] Use 'kt model path-add' instead."""
    console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-add' instead.[/yellow]\n")
    import subprocess

    subprocess.run(["kt", "model", "path-add", path])


@app.command(name="model-path-remove", deprecated=True, hidden=True)
def model_path_remove(
    path: str = typer.Argument(..., help="Path to remove"),
) -> None:
    """[Deprecated] Use 'kt model path-remove' instead."""
    console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-remove' instead.[/yellow]\n")
    import subprocess

    subprocess.run(["kt", "model", "path-remove", path])


def _parse_value(value: str):
    """Parse a string value into appropriate Python type."""
    # Try boolean
    if value.lower() in ("true", "yes", "on", "1"):
        return True
    if value.lower() in ("false", "no", "off", "0"):
        return False

    # Try integer
    try:
        return int(value)
    except ValueError:
        pass

    # Try float
    try:
        return float(value)
    except ValueError:
        pass

    # Try YAML/JSON parsing for lists/dicts
    try:
        parsed = yaml.safe_load(value)
        if isinstance(parsed, (dict, list)):
            return parsed
    except yaml.YAMLError:
        pass

    # Return as string
    return value
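
Note the coercion order in _parse_value: the boolean branch runs before the integer branch, so "1" and "0" become True/False rather than ints. A few doctest-style checks, derivable directly from the code above (illustrative, assumed to run inside this module):

# Coercions performed by _parse_value:
assert _parse_value("true") is True           # boolean words: true/yes/on
assert _parse_value("1") is True              # "1"/"0" hit the bool branch, not int
assert _parse_value("8080") == 8080           # integer
assert _parse_value("0.5") == 0.5             # float
assert _parse_value("[a, b]") == ["a", "b"]   # YAML list
assert _parse_value("hello") == "hello"       # fallback: plain string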

394  kt-kernel/python/cli/commands/doctor.py  Normal file
@@ -0,0 +1,394 @@
"""
Doctor command for kt-cli.

Diagnoses environment issues and provides recommendations.
"""

import platform
import shutil
from pathlib import Path
from typing import Optional

import typer
from rich.table import Table

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console, print_error, print_info, print_success, print_warning
from kt_kernel.cli.utils.environment import (
    check_docker,
    detect_available_ram_gb,
    detect_cpu_info,
    detect_cuda_version,
    detect_disk_space_gb,
    detect_env_managers,
    detect_gpus,
    detect_memory_info,
    detect_ram_gb,
    get_installed_package_version,
)


def doctor(
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed diagnostics"),
) -> None:
    """Diagnose environment issues."""
    console.print(f"\n[bold]{t('doctor_title')}[/bold]\n")

    issues_found = False
    checks = []

    # 1. Python version
    python_version = platform.python_version()
    python_ok = _check_python_version(python_version)
    checks.append(
        {
            "name": t("doctor_check_python"),
            "status": "ok" if python_ok else "error",
            "value": python_version,
            "hint": "Python 3.10+ required" if not python_ok else None,
        }
    )
    if not python_ok:
        issues_found = True

    # 2. CUDA availability
    cuda_version = detect_cuda_version()
    checks.append(
        {
            "name": t("doctor_check_cuda"),
            "status": "ok" if cuda_version else "warning",
            "value": cuda_version or t("version_cuda_not_found"),
            "hint": "CUDA is optional but recommended for GPU acceleration" if not cuda_version else None,
        }
    )

    # 3. GPU detection
    gpus = detect_gpus()
    if gpus:
        gpu_names = ", ".join(g.name for g in gpus)
        total_vram = sum(g.vram_gb for g in gpus)
        checks.append(
            {
                "name": t("doctor_check_gpu"),
                "status": "ok",
                "value": t("doctor_gpu_found", count=len(gpus), names=gpu_names),
                "hint": f"Total VRAM: {total_vram}GB",
            }
        )
    else:
        checks.append(
            {
                "name": t("doctor_check_gpu"),
                "status": "warning",
                "value": t("doctor_gpu_not_found"),
                "hint": "GPU recommended for best performance",
            }
        )

    # 4. CPU information
    cpu_info = detect_cpu_info()
    checks.append(
        {
            "name": t("doctor_check_cpu"),
            "status": "ok",
            "value": t("doctor_cpu_info", name=cpu_info.name, cores=cpu_info.cores, threads=cpu_info.threads),
            "hint": None,
        }
    )

    # 5. CPU instruction sets (critical for kt-kernel)
    isa_list = cpu_info.instruction_sets
    # Check for recommended instruction sets
    recommended_isa = {"AVX2", "AVX512F", "AMX-INT8"}
    has_recommended = bool(set(isa_list) & recommended_isa)
    has_avx2 = "AVX2" in isa_list
    has_avx512 = any(isa.startswith("AVX512") for isa in isa_list)
    has_amx = any(isa.startswith("AMX") for isa in isa_list)

    # Determine status and build display string
    if has_amx:
        isa_status = "ok"
        isa_hint = "AMX available - best performance for INT4/INT8"
    elif has_avx512:
        isa_status = "ok"
        isa_hint = "AVX512 available - good performance"
    elif has_avx2:
        isa_status = "warning"
        isa_hint = "AVX2 only - consider upgrading CPU for better performance"
    else:
        isa_status = "error"
        isa_hint = "AVX2 required for kt-kernel"

    # Show top instruction sets (prioritize important ones)
    display_isa = isa_list[:8] if len(isa_list) > 8 else isa_list
    isa_display = ", ".join(display_isa)
    if len(isa_list) > 8:
        isa_display += f" (+{len(isa_list) - 8} more)"

    checks.append(
        {
            "name": t("doctor_check_cpu_isa"),
            "status": isa_status,
            "value": isa_display if isa_display else "None detected",
            "hint": isa_hint,
        }
    )

    # 6. NUMA topology
    numa_detail = []
    for node, cpus in sorted(cpu_info.numa_info.items()):
        if len(cpus) > 6:
            cpu_str = f"{cpus[0]}-{cpus[-1]}"
        else:
            cpu_str = ",".join(str(c) for c in cpus)
        numa_detail.append(f"{node}: {cpu_str}")

    numa_value = t("doctor_numa_info", nodes=cpu_info.numa_nodes)
    if verbose and numa_detail:
        numa_value += " (" + "; ".join(numa_detail) + ")"

    checks.append(
        {
            "name": t("doctor_check_numa"),
            "status": "ok",
            "value": numa_value,
            "hint": f"{cpu_info.threads // cpu_info.numa_nodes} threads per node" if cpu_info.numa_nodes > 1 else None,
        }
    )

    # 7. System memory (with frequency if available)
    mem_info = detect_memory_info()
    if mem_info.frequency_mhz and mem_info.type:
        mem_value = t(
            "doctor_memory_freq",
            available=f"{mem_info.available_gb}GB",
            total=f"{mem_info.total_gb}GB",
            freq=mem_info.frequency_mhz,
            type=mem_info.type,
        )
    else:
        mem_value = t("doctor_memory_info", available=f"{mem_info.available_gb}GB", total=f"{mem_info.total_gb}GB")

    ram_ok = mem_info.total_gb >= 32
    checks.append(
        {
            "name": t("doctor_check_memory"),
            "status": "ok" if ram_ok else "warning",
            "value": mem_value,
            "hint": "32GB+ RAM recommended for large models" if not ram_ok else None,
        }
    )

    # 8. Disk space - check all model paths
    settings = get_settings()
    model_paths = settings.get_model_paths()

    # Check all configured model paths
    for i, disk_path in enumerate(model_paths):
        available_disk, total_disk = detect_disk_space_gb(str(disk_path))
        disk_ok = available_disk >= 100

        # For multiple paths, add index to name
        path_label = f"Model Path {i+1}" if len(model_paths) > 1 else t("doctor_check_disk")

        checks.append(
            {
                "name": path_label,
                "status": "ok" if disk_ok else "warning",
                "value": t("doctor_disk_info", available=f"{available_disk}GB", path=str(disk_path)),
                "hint": "100GB+ free space recommended for model storage" if not disk_ok else None,
            }
        )

    # 9. Required packages
    packages = [
        ("kt-kernel", ">=0.4.0", False),  # name, version_req, required
        ("ktransformers", ">=0.4.0", False),
        ("sglang", ">=0.4.0", False),
        ("torch", ">=2.4.0", True),
        ("transformers", ">=4.45.0", True),
    ]

    package_issues = []
    for pkg_name, version_req, required in packages:
        version = get_installed_package_version(pkg_name)
        if version:
            package_issues.append((pkg_name, version, "ok"))
        elif required:
            package_issues.append((pkg_name, t("version_not_installed"), "error"))
            issues_found = True
        else:
            package_issues.append((pkg_name, t("version_not_installed"), "warning"))

    if verbose:
        checks.append(
            {
                "name": t("doctor_check_packages"),
                "status": "ok" if not any(p[2] == "error" for p in package_issues) else "error",
                "value": f"{sum(1 for p in package_issues if p[2] == 'ok')}/{len(package_issues)} installed",
                "packages": package_issues,
            }
        )

    # 10. SGLang installation source check
    from kt_kernel.cli.utils.sglang_checker import check_sglang_installation, check_sglang_kt_kernel_support

    sglang_info = check_sglang_installation()

    if sglang_info["installed"]:
        if sglang_info["from_source"]:
            if sglang_info["git_info"]:
                git_remote = sglang_info["git_info"].get("remote", "unknown")
                git_branch = sglang_info["git_info"].get("branch", "unknown")
                sglang_source_value = f"Source (GitHub: {git_remote}, branch: {git_branch})"
                sglang_source_status = "ok"
                sglang_source_hint = None
            else:
                sglang_source_value = "Source (editable)"
                sglang_source_status = "ok"
                sglang_source_hint = None
        else:
            sglang_source_value = "PyPI (not recommended)"
            sglang_source_status = "warning"
            sglang_source_hint = t("sglang_pypi_hint")
    else:
        sglang_source_value = "Not installed"
        sglang_source_status = "warning"
        sglang_source_hint = t("sglang_install_hint")

    checks.append(
        {
            "name": "SGLang Source",
            "status": sglang_source_status,
            "value": sglang_source_value,
            "hint": sglang_source_hint,
        }
    )

    # 10b. SGLang kt-kernel support check (only if SGLang is installed)
    kt_kernel_support = {"supported": True}  # Default to True if not checked
    if sglang_info["installed"]:
        # Use cache=False to force re-check in doctor, but silent=True since we show in table
        kt_kernel_support = check_sglang_kt_kernel_support(use_cache=False, silent=True)

        if kt_kernel_support["supported"]:
            kt_kernel_value = t("sglang_kt_kernel_supported")
            kt_kernel_status = "ok"
            kt_kernel_hint = None
        else:
            kt_kernel_value = t("sglang_kt_kernel_not_supported")
            kt_kernel_status = "error"
            kt_kernel_hint = 'Reinstall SGLang from: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"'
            issues_found = True

        checks.append(
            {
                "name": "SGLang kt-kernel",
                "status": kt_kernel_status,
                "value": kt_kernel_value,
                "hint": kt_kernel_hint,
            }
        )

    # 11. Environment managers
    env_managers = detect_env_managers()
    docker = check_docker()
    env_list = [f"{m.name} {m.version}" for m in env_managers]
    if docker:
        env_list.append(f"docker {docker.version}")

    checks.append(
        {
            "name": "Environment Managers",
            "status": "ok" if env_list else "warning",
            "value": ", ".join(env_list) if env_list else "None found",
            "hint": "conda or docker recommended for installation" if not env_list else None,
        }
    )

    # Display results
    _display_results(checks, verbose)

    # Show SGLang installation instructions if not installed
    if not sglang_info["installed"]:
        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions

        console.print()
        print_sglang_install_instructions()
    # Show kt-kernel installation instructions if SGLang is installed but doesn't support kt-kernel
    elif sglang_info["installed"] and not kt_kernel_support.get("supported", True):
        from kt_kernel.cli.utils.sglang_checker import print_sglang_kt_kernel_instructions

        console.print()
        print_sglang_kt_kernel_instructions()

    # Summary
    console.print()
    if issues_found:
        print_warning(t("doctor_has_issues"))
    else:
        print_success(t("doctor_all_ok"))
    console.print()


def _check_python_version(version: str) -> bool:
    """Check if Python version meets requirements."""
    parts = version.split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
        return major >= 3 and minor >= 10
    except (IndexError, ValueError):
        return False


def _display_results(checks: list[dict], verbose: bool) -> None:
    """Display diagnostic results."""
    table = Table(show_header=True, header_style="bold")
    table.add_column("Check", style="bold")
    table.add_column("Status", width=8)
    table.add_column("Value")
    if verbose:
        table.add_column("Notes", style="dim")

    for check in checks:
        status = check["status"]
        if status == "ok":
            status_str = f"[green]{t('doctor_status_ok')}[/green]"
        elif status == "warning":
            status_str = f"[yellow]{t('doctor_status_warning')}[/yellow]"
        else:
            status_str = f"[red]{t('doctor_status_error')}[/red]"

        if verbose:
            table.add_row(
                check["name"],
                status_str,
                check["value"],
                check.get("hint", ""),
            )
        else:
            table.add_row(
                check["name"],
                status_str,
                check["value"],
            )

        # Show package details if verbose
        if verbose and "packages" in check:
            for pkg_name, pkg_version, pkg_status in check["packages"]:
                if pkg_status == "ok":
                    pkg_status_str = "[green]✓[/green]"
                elif pkg_status == "warning":
                    pkg_status_str = "[yellow]○[/yellow]"
                else:
                    pkg_status_str = "[red]✗[/red]"

                table.add_row(
                    f"  └─ {pkg_name}",
                    pkg_status_str,
                    pkg_version,
                    "",
                )

    console.print(table)
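
Each diagnostic is a plain dict that _display_results renders as one table row; the optional "packages" key adds indented sub-rows in verbose mode. Shape of a single entry, with illustrative values:

# Illustrative check entry consumed by _display_results:
check = {
    "name": "Python",                 # first column
    "status": "error",                # "ok" | "warning" | "error" -> colored marker
    "value": "3.9.18",                # detected value
    "hint": "Python 3.10+ required",  # notes column, shown with --verbose
}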

409  kt-kernel/python/cli/commands/model.py  Normal file
@@ -0,0 +1,409 @@
"""
Model command for kt-cli.

Manages models: download, list, and storage paths.
"""

import os
from pathlib import Path
from typing import Optional

import typer

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    confirm,
    console,
    print_error,
    print_info,
    print_success,
    print_warning,
    prompt_choice,
)

app = typer.Typer(
    help="Manage models and storage paths",
    invoke_without_command=True,
    no_args_is_help=False,
)


@app.callback()
def callback(ctx: typer.Context) -> None:
    """
    Model management commands.

    Run without arguments to see available models.
    """
    # If no subcommand is provided, show the model list
    if ctx.invoked_subcommand is None:
        show_model_list()


def show_model_list() -> None:
    """Display available models with their status and paths."""
    from rich.table import Table
    from kt_kernel.cli.utils.model_registry import get_registry
    from kt_kernel.cli.i18n import get_lang

    registry = get_registry()
    settings = get_settings()

    console.print()
    console.print(f"[bold cyan]{t('model_supported_title')}[/bold cyan]\n")

    # Get local models mapping
    local_models = {m.name: p for m, p in registry.find_local_models()}

    # Create table
    table = Table(show_header=True, header_style="bold")
    table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
    table.add_column(t("model_column_status"), justify="center")

    all_models = registry.list_all()
    for model in all_models:
        if model.name in local_models:
            status = f"[green]✓ {t('model_status_local')}[/green]"
        else:
            status = "[dim]-[/dim]"

        table.add_row(model.name, status)

    console.print(table)
    console.print()

    # Usage instructions
    console.print(f"[bold]{t('model_usage_title')}:[/bold]")
    console.print(f"  • {t('model_usage_download')} [cyan]kt model download <model-name>[/cyan]")
    console.print(f"  • {t('model_usage_list_local')} [cyan]kt model list --local[/cyan]")
    console.print(f"  • {t('model_usage_search')} [cyan]kt model search <query>[/cyan]")
    console.print()

    # Show model storage paths
    model_paths = settings.get_model_paths()
    console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]")
    for path in model_paths:
        marker = "[green]✓[/green]" if path.exists() else "[dim]✗[/dim]"
        console.print(f"  {marker} {path}")
    console.print()


@app.command(name="download")
def download(
    model: Optional[str] = typer.Argument(
        None,
        help="Model name or HuggingFace repo (e.g., deepseek-v3, Qwen/Qwen3-30B)",
    ),
    path: Optional[Path] = typer.Option(
        None,
        "--path",
        "-p",
        help="Custom download path",
    ),
    list_models: bool = typer.Option(
        False,
        "--list",
        "-l",
        help="List available models",
    ),
    resume: bool = typer.Option(
        True,
        "--resume/--no-resume",
        help="Resume incomplete downloads",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation prompts",
    ),
) -> None:
    """Download model weights from HuggingFace."""
    import subprocess
    from kt_kernel.cli.i18n import get_lang
    from kt_kernel.cli.utils.console import print_model_table, print_step
    from kt_kernel.cli.utils.model_registry import get_registry

    settings = get_settings()
    registry = get_registry()

    console.print()

    # List mode
    if list_models or model is None:
        print_step(t("download_list_title"))
        console.print()

        models = registry.list_all()
        model_dicts = []
        for m in models:
            lang = get_lang()
            desc = m.description_zh if lang == "zh" and m.description_zh else m.description
            model_dicts.append(
                {
                    "name": m.name,
                    "hf_repo": m.hf_repo,
                    "type": m.type,
                    "gpu_vram_gb": m.gpu_vram_gb,
                    "cpu_ram_gb": m.cpu_ram_gb,
                }
            )

        print_model_table(model_dicts)
        console.print()

        if model is None:
            console.print(f"[dim]{t('model_download_usage_hint')}[/dim]")
            console.print()
            return

    # Search for model
    print_step(t("download_searching", name=model))

    # Check if it's a direct HuggingFace repo path
    if "/" in model:
        hf_repo = model
        model_info = None
        model_name = model.split("/")[-1]
    else:
        matches = registry.search(model)

        if not matches:
            print_error(t("run_model_not_found", name=model))
            console.print()
            console.print(t("model_download_list_hint"))
            console.print(t("model_download_hf_hint"))
            raise typer.Exit(1)

        if len(matches) == 1:
            model_info = matches[0]
        else:
            console.print()
            print_info(t("download_multiple_found"))
            choices = [f"{m.name} ({m.hf_repo})" for m in matches]
            selected = prompt_choice(t("download_select"), choices)
            idx = choices.index(selected)
            model_info = matches[idx]

        hf_repo = model_info.hf_repo
        model_name = model_info.name

    print_success(t("download_found", name=hf_repo))

    # Determine download path
    if path is None:
        download_path = settings.models_dir / model_name.replace(" ", "-")
    else:
        download_path = path

    console.print()
    print_info(t("download_destination", path=str(download_path)))

    # Check if already exists
    if download_path.exists() and (download_path / "config.json").exists():
        print_warning(t("download_already_exists", path=str(download_path)))
        if not yes:
            if not confirm(t("download_overwrite_prompt"), default=False):
                raise typer.Abort()

    # Confirm download
    if not yes:
        console.print()
        if not confirm(t("prompt_continue")):
            raise typer.Abort()

    # Download using huggingface-cli
    console.print()
    print_step(t("download_starting"))

    cmd = [
        "huggingface-cli",
        "download",
        hf_repo,
        "--local-dir",
        str(download_path),
    ]

    if resume:
        cmd.append("--resume-download")

    # Add mirror if configured
    mirror = settings.get("download.mirror", "")
    if mirror:
        cmd.extend(["--endpoint", mirror])

    try:
        process = subprocess.run(cmd, check=True)

        console.print()
        print_success(t("download_complete"))
        console.print()
        console.print(f"  {t('model_saved_to', path=download_path)}")
        console.print()
        console.print(f"  {t('model_start_with', name=model_name)}")
        console.print()

    except subprocess.CalledProcessError as e:
        print_error(t("model_download_failed", error=str(e)))
        raise typer.Exit(1)
    except FileNotFoundError:
        print_error(t("model_hf_cli_not_found"))
        raise typer.Exit(1)


@app.command(name="list")
def list_models(
    local_only: bool = typer.Option(False, "--local", help="Show only locally downloaded models"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed info including paths"),
) -> None:
    """List available models."""
    from rich.table import Table
    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    console.print()

    if local_only:
        # Show only local models
        local_models = registry.find_local_models()

        if not local_models:
            print_warning(t("model_no_local_models"))
            console.print()
            console.print(f"  {t('model_download_hint')} [cyan]kt model download <model-name>[/cyan]")
            console.print()
            return

        table = Table(title=t("model_local_models_title"), show_header=True, header_style="bold")
        table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
        if verbose:
            table.add_column(t("model_column_local_path"), style="dim")

        for model_info, model_path in local_models:
            if verbose:
                table.add_row(model_info.name, str(model_path))
            else:
                table.add_row(model_info.name)

        console.print(table)
    else:
        # Show all registered models
        all_models = registry.list_all()
        local_models_dict = {m.name: p for m, p in registry.find_local_models()}

        table = Table(title=t("model_available_models_title"), show_header=True, header_style="bold")
        table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
        table.add_column(t("model_column_status"), justify="center")
        if verbose:
            table.add_column(t("model_column_local_path"), style="dim")

        for model in all_models:
            if model.name in local_models_dict:
                status = f"[green]✓ {t('model_status_local')}[/green]"
                local_path = str(local_models_dict[model.name])
            else:
                status = "[dim]-[/dim]"
                local_path = f"[dim]{t('model_status_not_downloaded')}[/dim]"

            if verbose:
                table.add_row(model.name, status, local_path)
            else:
                table.add_row(model.name, status)

        console.print(table)

    console.print()


@app.command(name="path-list")
def path_list() -> None:
    """List all configured model storage paths."""
    settings = get_settings()
    model_paths = settings.get_model_paths()

    console.print()
    console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]\n")

    for i, path in enumerate(model_paths, 1):
        marker = "[green]✓[/green]" if path.exists() else "[red]✗[/red]"
        console.print(f"  {marker} [{i}] {path}")

    console.print()


@app.command(name="path-add")
def path_add(
    path: str = typer.Argument(..., help="Path to add"),
) -> None:
    """Add a new model storage path."""
    # Expand user home directory
    path = os.path.expanduser(path)

    # Check if path exists or can be created
    path_obj = Path(path)
    if not path_obj.exists():
        console.print(f"[yellow]{t('model_path_not_exist', path=path)}[/yellow]")
        if confirm(t("model_create_directory", path=path), default=True):
            try:
                path_obj.mkdir(parents=True, exist_ok=True)
                console.print(f"[green]✓[/green] {t('model_created_directory', path=path)}")
            except (OSError, PermissionError) as e:
                print_error(t("model_create_dir_failed", error=str(e)))
                raise typer.Exit(1)
        else:
            raise typer.Abort()

    # Add to configuration
    settings = get_settings()
    settings.add_model_path(path)
    print_success(t("model_path_added", path=path))


@app.command(name="path-remove")
def path_remove(
    path: str = typer.Argument(..., help="Path to remove"),
) -> None:
    """Remove a model storage path from configuration."""
    # Expand user home directory
    path = os.path.expanduser(path)

    settings = get_settings()
    if settings.remove_model_path(path):
        print_success(t("model_path_removed", path=path))
    else:
        print_error(t("model_path_not_found", path=path))
        raise typer.Exit(1)


@app.command(name="search")
def search(
    query: str = typer.Argument(..., help="Search query (model name or keyword)"),
) -> None:
    """Search for models in the registry."""
    from rich.table import Table
    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    matches = registry.search(query)

    console.print()

    if not matches:
        print_warning(t("model_search_no_results", query=query))
        console.print()
        return

    table = Table(title=t("model_search_results_title", query=query), show_header=True)
    table.add_column(t("model_column_name"), style="cyan")
    table.add_column(t("model_column_hf_repo"), style="dim")
    table.add_column(t("model_column_aliases"), style="yellow")

    for model in matches:
        aliases = ", ".join(model.aliases[:3])
        if len(model.aliases) > 3:
            aliases += f" +{len(model.aliases) - 3} more"
        table.add_row(model.name, model.hf_repo, aliases)

    console.print(table)
    console.print()
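
The download command is a thin wrapper over huggingface-cli: it resolves the registry entry, builds the argument list, and shells out. Roughly the equivalent invocation, shown with placeholders rather than literal values; the --endpoint flag only appears when download.mirror is set in the settings:

# Equivalent shell invocation built by download() above (placeholders kept):
#   huggingface-cli download <hf_repo> \
#       --local-dir <models_dir>/<model_name> \
#       --resume-download \
#       [--endpoint <mirror>]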
239
kt-kernel/python/cli/commands/quant.py
Normal file
239
kt-kernel/python/cli/commands/quant.py
Normal file
|
|
@ -0,0 +1,239 @@
"""
Quant command for kt-cli.

Quantizes model weights for CPU inference.
"""

import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import typer

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    confirm,
    console,
    create_progress,
    print_error,
    print_info,
    print_step,
    print_success,
    print_warning,
)
from kt_kernel.cli.utils.environment import detect_cpu_info


class QuantMethod(str, Enum):
    """Quantization method."""

    INT4 = "int4"
    INT8 = "int8"


def quant(
    model: str = typer.Argument(
        ...,
        help="Model name or path to quantize",
    ),
    method: QuantMethod = typer.Option(
        QuantMethod.INT4,
        "--method",
        "-m",
        help="Quantization method",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="Output path for quantized weights",
    ),
    input_type: str = typer.Option(
        "fp8",
        "--input-type",
        "-i",
        help="Input weight type (fp8, fp16, bf16)",
    ),
    cpu_threads: Optional[int] = typer.Option(
        None,
        "--cpu-threads",
        help="Number of CPU threads for quantization",
    ),
    numa_nodes: Optional[int] = typer.Option(
        None,
        "--numa-nodes",
        help="Number of NUMA nodes",
    ),
    no_merge: bool = typer.Option(
        False,
        "--no-merge",
        help="Don't merge safetensor files",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation prompts",
    ),
) -> None:
    """Quantize model weights for CPU inference."""
    settings = get_settings()
    console.print()

    # Resolve input path
    input_path = _resolve_input_path(model, settings)
    if input_path is None:
        print_error(t("quant_input_not_found", path=model))
        raise typer.Exit(1)

    print_info(t("quant_input_path", path=str(input_path)))

    # Resolve output path
    if output is None:
        output = input_path.parent / f"{input_path.name}-{method.value.upper()}"

    print_info(t("quant_output_path", path=str(output)))
    print_info(t("quant_method", method=method.value.upper()))

    # Detect CPU configuration
    cpu = detect_cpu_info()
    final_cpu_threads = cpu_threads or cpu.cores
    final_numa_nodes = numa_nodes or cpu.numa_nodes

    print_info(f"CPU threads: {final_cpu_threads}")
    print_info(f"NUMA nodes: {final_numa_nodes}")

    # Check if output exists
    if output.exists():
        print_warning(f"Output path already exists: {output}")
        if not yes:
            if not confirm("Overwrite?", default=False):
                raise typer.Abort()

    # Confirm
    if not yes:
        console.print()
        console.print("[bold]Quantization Settings:[/bold]")
        console.print(f"  Input: {input_path}")
        console.print(f"  Output: {output}")
        console.print(f"  Method: {method.value.upper()}")
        console.print(f"  Input type: {input_type}")
        console.print()
        print_warning("Quantization may take 30-60 minutes depending on model size.")
        console.print()

        if not confirm(t("prompt_continue")):
            raise typer.Abort()

    # Find conversion script
    kt_kernel_path = _find_kt_kernel_path()
    if kt_kernel_path is None:
        print_error("kt-kernel not found. Install with: kt install inference")
        raise typer.Exit(1)

    script_path = kt_kernel_path / "scripts" / "convert_cpu_weights.py"
    if not script_path.exists():
        print_error(f"Conversion script not found: {script_path}")
        raise typer.Exit(1)

    # Build command
    cmd = [
        sys.executable, str(script_path),
        "--input-path", str(input_path),
        "--input-type", input_type,
        "--output", str(output),
        "--quant-method", method.value,
        "--cpuinfer-threads", str(final_cpu_threads),
        "--threadpool-count", str(final_numa_nodes),
    ]

    if no_merge:
        cmd.append("--no-merge-safetensor")

    # Run quantization
    console.print()
    print_step(t("quant_starting"))
    console.print()
    console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
    console.print()

    try:
        process = subprocess.run(cmd)

        if process.returncode == 0:
            console.print()
            print_success(t("quant_complete"))
            console.print()
            console.print(f"  Quantized weights saved to: {output}")
            console.print()
            console.print("  Use with:")
            console.print(f"    kt run {model} --weights-path {output}")
            console.print()
        else:
            print_error(f"Quantization failed with exit code {process.returncode}")
            raise typer.Exit(process.returncode)

    except FileNotFoundError as e:
        print_error(f"Failed to run quantization: {e}")
        raise typer.Exit(1)
    except KeyboardInterrupt:
        console.print()
        print_warning("Quantization interrupted.")
        raise typer.Exit(130)


def _resolve_input_path(model: str, settings) -> Optional[Path]:
    """Resolve the input model path."""
    # Check if it's already a path
    path = Path(model)
    if path.exists() and (path / "config.json").exists():
        return path

    # Search in models directory
    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    matches = registry.search(model)

    if matches:
        model_info = matches[0]
        # Try to find in all configured model directories
        model_paths = settings.get_model_paths()

        for models_dir in model_paths:
            possible_paths = [
                models_dir / model_info.name,
                models_dir / model_info.name.lower(),
                models_dir / model_info.hf_repo.split("/")[-1],
            ]

            for p in possible_paths:
                if p.exists() and (p / "config.json").exists():
                    return p

    return None


def _find_kt_kernel_path() -> Optional[Path]:
    """Find the kt-kernel installation path."""
    try:
        import kt_kernel
        return Path(kt_kernel.__file__).parent.parent
    except ImportError:
        pass

    # Check common locations
    possible_paths = [
        Path.home() / "Projects" / "ktransformers" / "kt-kernel",
        Path.cwd().parent / "kt-kernel",
        Path.cwd() / "kt-kernel",
    ]

    for path in possible_paths:
        if path.exists() and (path / "scripts").exists():
            return path

    return None
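For reference, the assembled conversion call ends up looking like the sketch below for a hypothetical DeepSeek-V3 checkpoint quantized to INT4 on a 64-thread, 2-NUMA-node host. Paths and counts are illustrative, not taken from a real run; the flag names mirror the `cmd` list built in `quant()` above.

import sys

# Sketch only: illustrative paths and counts; flags mirror quant() above.
cmd = [
    sys.executable, "/opt/ktransformers/kt-kernel/scripts/convert_cpu_weights.py",
    "--input-path", "/data/models/DeepSeek-V3",   # resolved by _resolve_input_path()
    "--input-type", "fp8",                        # default --input-type
    "--output", "/data/models/DeepSeek-V3-INT4",  # input name + method suffix
    "--quant-method", "int4",                     # QuantMethod.INT4.value
    "--cpuinfer-threads", "64",                   # --cpu-threads or detected cores
    "--threadpool-count", "2",                    # --numa-nodes or detected NUMA nodes
]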
831
kt-kernel/python/cli/commands/run.py
Normal file
@@ -0,0 +1,831 @@
"""
Run command for kt-cli.

Starts the model inference server using SGLang + kt-kernel.
"""

import os
import subprocess
import sys
from pathlib import Path
from typing import Optional

import typer

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    confirm,
    console,
    print_api_info,
    print_error,
    print_info,
    print_server_info,
    print_step,
    print_success,
    print_warning,
    prompt_choice,
)
from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb
from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry


def run(
    model: Optional[str] = typer.Argument(
        None,
        help="Model name or path (e.g., deepseek-v3, qwen3-30b). If not specified, shows interactive selection.",
    ),
    host: Optional[str] = typer.Option(
        None,
        "--host",
        "-H",
        help="Server host address",
    ),
    port: Optional[int] = typer.Option(
        None,
        "--port",
        "-p",
        help="Server port",
    ),
    # CPU/GPU configuration
    gpu_experts: Optional[int] = typer.Option(
        None,
        "--gpu-experts",
        help="Number of GPU experts per layer",
    ),
    cpu_threads: Optional[int] = typer.Option(
        None,
        "--cpu-threads",
        help="Number of CPU inference threads (kt-cpuinfer, defaults to 80% of CPU cores)",
    ),
    numa_nodes: Optional[int] = typer.Option(
        None,
        "--numa-nodes",
        help="Number of NUMA nodes",
    ),
    tensor_parallel_size: Optional[int] = typer.Option(
        None,
        "--tensor-parallel-size",
        "--tp",
        help="Tensor parallel size (number of GPUs)",
    ),
    # Model paths
    model_path: Optional[Path] = typer.Option(
        None,
        "--model-path",
        help="Custom model path",
    ),
    weights_path: Optional[Path] = typer.Option(
        None,
        "--weights-path",
        help="Custom quantized weights path",
    ),
    # KT-kernel options
    kt_method: Optional[str] = typer.Option(
        None,
        "--kt-method",
        help="KT quantization method (AMXINT4, RAWFP8, etc.)",
    ),
    kt_gpu_prefill_token_threshold: Optional[int] = typer.Option(
        None,
        "--kt-gpu-prefill-threshold",
        help="GPU prefill token threshold for kt-kernel",
    ),
    # SGLang options
    attention_backend: Optional[str] = typer.Option(
        None,
        "--attention-backend",
        help="Attention backend (triton, flashinfer)",
    ),
    max_total_tokens: Optional[int] = typer.Option(
        None,
        "--max-total-tokens",
        help="Maximum total tokens",
    ),
    max_running_requests: Optional[int] = typer.Option(
        None,
        "--max-running-requests",
        help="Maximum running requests",
    ),
    chunked_prefill_size: Optional[int] = typer.Option(
        None,
        "--chunked-prefill-size",
        help="Chunked prefill size",
    ),
    mem_fraction_static: Optional[float] = typer.Option(
        None,
        "--mem-fraction-static",
        help="Memory fraction for static allocation",
    ),
    watchdog_timeout: Optional[int] = typer.Option(
        None,
        "--watchdog-timeout",
        help="Watchdog timeout in seconds",
    ),
    served_model_name: Optional[str] = typer.Option(
        None,
        "--served-model-name",
        help="Custom model name for API responses",
    ),
    # Performance flags
    disable_shared_experts_fusion: Optional[bool] = typer.Option(
        None,
        "--disable-shared-experts-fusion/--enable-shared-experts-fusion",
        help="Disable/enable shared experts fusion",
    ),
    # Other options
    quantize: bool = typer.Option(
        False,
        "--quantize",
        "-q",
        help="Quantize model if weights not found",
    ),
    advanced: bool = typer.Option(
        False,
        "--advanced",
        help="Show advanced options",
    ),
    dry_run: bool = typer.Option(
        False,
        "--dry-run",
        help="Show command without executing",
    ),
) -> None:
    """Start model inference server."""
    # Check if SGLang is installed before proceeding
    from kt_kernel.cli.utils.sglang_checker import (
        check_sglang_installation,
        check_sglang_kt_kernel_support,
        print_sglang_install_instructions,
        print_sglang_kt_kernel_instructions,
    )

    sglang_info = check_sglang_installation()
    if not sglang_info["installed"]:
        console.print()
        print_error(t("sglang_not_found"))
        console.print()
        print_sglang_install_instructions()
        raise typer.Exit(1)

    # Check if SGLang supports kt-kernel (has --kt-gpu-prefill-token-threshold parameter)
    kt_kernel_support = check_sglang_kt_kernel_support()
    if not kt_kernel_support["supported"]:
        console.print()
        print_error(t("sglang_kt_kernel_not_supported"))
        console.print()
        print_sglang_kt_kernel_instructions()
        raise typer.Exit(1)

    settings = get_settings()
    registry = get_registry()

    console.print()

    # If no model specified, show interactive selection
    if model is None:
        model = _interactive_model_selection(registry, settings)
        if model is None:
            raise typer.Exit(0)

    # Step 1: Detect hardware
    print_step(t("run_detecting_hardware"))
    gpus = detect_gpus()
    cpu = detect_cpu_info()
    ram = detect_ram_gb()

    if gpus:
        gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
        if len(gpus) > 1:
            gpu_info += f" + {len(gpus) - 1} more"
        print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
    else:
        print_warning(t("doctor_gpu_not_found"))
        gpu_info = "None"

    print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
    print_info(t("run_ram_info", total=int(ram)))

    # Step 2: Resolve model
    console.print()
    print_step(t("run_checking_model"))

    model_info = None
    resolved_model_path = model_path

    # Check if model is a path
    if Path(model).exists():
        resolved_model_path = Path(model)
        print_info(t("run_model_path", path=str(resolved_model_path)))

        # Try to infer model type from path to use default configurations.
        # Check directory name against known models.
        dir_name = resolved_model_path.name.lower()
        for registered_model in registry.list_all():
            # Check if directory name matches model name or aliases
            if dir_name == registered_model.name.lower():
                model_info = registered_model
                print_info(f"Detected model type: {registered_model.name}")
                break
            for alias in registered_model.aliases:
                if dir_name == alias.lower() or alias.lower() in dir_name:
                    model_info = registered_model
                    print_info(f"Detected model type: {registered_model.name}")
                    break
            if model_info:
                break

        # Also check HuggingFace repo format (org--model)
        if not model_info:
            for registered_model in registry.list_all():
                repo_slug = registered_model.hf_repo.replace("/", "--").lower()
                if repo_slug in dir_name or dir_name in repo_slug:
                    model_info = registered_model
                    print_info(f"Detected model type: {registered_model.name}")
                    break

        if not model_info:
            print_warning("Could not detect model type from path. Using default parameters.")
            console.print("  [dim]Tip: Use model name (e.g., 'kt run m2') to apply optimized configurations[/dim]")
    else:
        # Search in registry
        matches = registry.search(model)

        if not matches:
            print_error(t("run_model_not_found", name=model))
            console.print()
            console.print("Available models:")
            for m in registry.list_all()[:5]:
                console.print(f"  - {m.name} ({', '.join(m.aliases[:2])})")
            raise typer.Exit(1)

        if len(matches) == 1:
            model_info = matches[0]
        else:
            # Multiple matches - prompt user
            console.print()
            print_info(t("run_multiple_matches"))
            choices = [f"{m.name} ({m.hf_repo})" for m in matches]
            selected = prompt_choice(t("run_select_model"), choices)
            idx = choices.index(selected)
            model_info = matches[idx]

        # Find model path
        if model_path is None:
            resolved_model_path = _find_model_path(model_info, settings)
            if resolved_model_path is None:
                print_error(t("run_model_not_found", name=model_info.name))
                console.print()
                console.print(
                    f"  Download with: kt download {model_info.aliases[0] if model_info.aliases else model_info.name}"
                )
                raise typer.Exit(1)

        print_info(t("run_model_path", path=str(resolved_model_path)))

    # Step 3: Check quantized weights (only if explicitly requested)
    resolved_weights_path = None

    # Only use quantized weights if explicitly specified by user
    if weights_path is not None:
        # User explicitly specified weights path
        resolved_weights_path = weights_path
        if not resolved_weights_path.exists():
            print_error(t("run_weights_not_found"))
            console.print(f"  Path: {resolved_weights_path}")
            raise typer.Exit(1)
        print_info(f"Using quantized weights: {resolved_weights_path}")
    elif quantize:
        # User requested quantization
        console.print()
        print_step(t("run_quantizing"))
        # TODO: Implement quantization
        print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
        raise typer.Exit(1)
    else:
        # Default: use original precision model without quantization
        console.print()
        print_info("Using original precision model (no quantization)")

    # Step 4: Build command
    # Resolve all parameters (CLI > model defaults > config > auto-detect)
    final_host = host or settings.get("server.host", "0.0.0.0")
    final_port = port or settings.get("server.port", 30000)

    # Get defaults from model info if available
    model_defaults = model_info.default_params if model_info else {}

    # Determine tensor parallel size first (needed for GPU expert calculation).
    # Priority: CLI > model defaults > config > auto-detect (with model constraints).

    # Check if explicitly specified by user or configuration
    explicitly_specified = (
        tensor_parallel_size  # CLI argument (highest priority)
        or model_defaults.get("tensor-parallel-size")  # Model defaults
        or settings.get("inference.tensor_parallel_size")  # Config file
    )

    if explicitly_specified:
        # Use explicitly specified value
        requested_tensor_parallel_size = explicitly_specified
    else:
        # Auto-detect from GPUs, considering model's max constraint
        detected_gpu_count = len(gpus) if gpus else 1
        if model_info and model_info.max_tensor_parallel_size is not None:
            # Automatically limit to model's maximum to use as many GPUs as possible
            requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size)
        else:
            requested_tensor_parallel_size = detected_gpu_count

    # Apply model's max_tensor_parallel_size constraint if explicitly specified value exceeds it
    final_tensor_parallel_size = requested_tensor_parallel_size
    if model_info and model_info.max_tensor_parallel_size is not None:
        if requested_tensor_parallel_size > model_info.max_tensor_parallel_size:
            console.print()
            print_warning(
                f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way "
                f"tensor parallelism, but {requested_tensor_parallel_size} was requested. "
                f"Reducing to {model_info.max_tensor_parallel_size}."
            )
            final_tensor_parallel_size = model_info.max_tensor_parallel_size

    # CPU/GPU configuration with smart defaults.
    # kt-cpuinfer: default to 80% of total CPU threads (cores * NUMA nodes).
    total_threads = cpu.cores * cpu.numa_nodes
    final_cpu_threads = (
        cpu_threads
        or model_defaults.get("kt-cpuinfer")
        or settings.get("inference.cpu_threads")
        or int(total_threads * 0.8)
    )

    # kt-threadpool-count: default to NUMA node count
    final_numa_nodes = (
        numa_nodes
        or model_defaults.get("kt-threadpool-count")
        or settings.get("inference.numa_nodes")
        or cpu.numa_nodes
    )

    # kt-num-gpu-experts: use model-specific computation if available and not explicitly set
    if gpu_experts is not None:
        # User explicitly set it
        final_gpu_experts = gpu_experts
    elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus:
        # Use model-specific computation function (only if GPUs detected)
        vram_per_gpu = gpus[0].vram_gb
        compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name]
        final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu)
        console.print()
        print_info(
            f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} (TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)"
        )
    else:
        # Fall back to defaults
        final_gpu_experts = model_defaults.get("kt-num-gpu-experts") or settings.get("inference.gpu_experts", 1)

    # KT-kernel options
    final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4")
    final_kt_gpu_prefill_threshold = (
        kt_gpu_prefill_token_threshold
        or model_defaults.get("kt-gpu-prefill-token-threshold")
        or settings.get("inference.kt_gpu_prefill_token_threshold", 4096)
    )

    # SGLang options
    final_attention_backend = (
        attention_backend
        or model_defaults.get("attention-backend")
        or settings.get("inference.attention_backend", "triton")
    )
    final_max_total_tokens = (
        max_total_tokens or model_defaults.get("max-total-tokens") or settings.get("inference.max_total_tokens", 40000)
    )
    final_max_running_requests = (
        max_running_requests
        or model_defaults.get("max-running-requests")
        or settings.get("inference.max_running_requests", 32)
    )
    final_chunked_prefill_size = (
        chunked_prefill_size
        or model_defaults.get("chunked-prefill-size")
        or settings.get("inference.chunked_prefill_size", 4096)
    )
    final_mem_fraction_static = (
        mem_fraction_static
        or model_defaults.get("mem-fraction-static")
        or settings.get("inference.mem_fraction_static", 0.98)
    )
    final_watchdog_timeout = (
        watchdog_timeout or model_defaults.get("watchdog-timeout") or settings.get("inference.watchdog_timeout", 3000)
    )
    final_served_model_name = (
        served_model_name or model_defaults.get("served-model-name") or settings.get("inference.served_model_name", "")
    )

    # Performance flags
    if disable_shared_experts_fusion is not None:
        final_disable_shared_experts_fusion = disable_shared_experts_fusion
    elif "disable-shared-experts-fusion" in model_defaults:
        final_disable_shared_experts_fusion = model_defaults["disable-shared-experts-fusion"]
    else:
        final_disable_shared_experts_fusion = settings.get("inference.disable_shared_experts_fusion", False)

    # Pass all model default params to handle any extra parameters
    extra_params = model_defaults if model_info else {}

    cmd = _build_sglang_command(
        model_path=resolved_model_path,
        weights_path=resolved_weights_path,
        model_info=model_info,
        host=final_host,
        port=final_port,
        gpu_experts=final_gpu_experts,
        cpu_threads=final_cpu_threads,
        numa_nodes=final_numa_nodes,
        tensor_parallel_size=final_tensor_parallel_size,
        kt_method=final_kt_method,
        kt_gpu_prefill_threshold=final_kt_gpu_prefill_threshold,
        attention_backend=final_attention_backend,
        max_total_tokens=final_max_total_tokens,
        max_running_requests=final_max_running_requests,
        chunked_prefill_size=final_chunked_prefill_size,
        mem_fraction_static=final_mem_fraction_static,
        watchdog_timeout=final_watchdog_timeout,
        served_model_name=final_served_model_name,
        disable_shared_experts_fusion=final_disable_shared_experts_fusion,
        settings=settings,
        extra_model_params=extra_params,
    )

    # Prepare environment variables
    env = os.environ.copy()
    # Add environment variables from advanced.env
    env.update(settings.get_env_vars())
    # Add environment variables from inference.env
    inference_env = settings.get("inference.env", {})
    if isinstance(inference_env, dict):
        env.update({k: str(v) for k, v in inference_env.items()})

    # Step 5: Show configuration summary
    console.print()
    print_step("Configuration")

    # Model info
    if model_info:
        console.print(f"  Model: [bold]{model_info.name}[/bold]")
    else:
        console.print(f"  Model: [bold]{resolved_model_path.name}[/bold]")

    console.print(f"  Path: [dim]{resolved_model_path}[/dim]")

    # Key parameters
    console.print()
    console.print(f"  GPU Experts: [cyan]{final_gpu_experts}[/cyan] per layer")
    console.print(f"  CPU Threads (kt-cpuinfer): [cyan]{final_cpu_threads}[/cyan]")
    console.print(f"  NUMA Nodes (kt-threadpool-count): [cyan]{final_numa_nodes}[/cyan]")
    console.print(f"  Tensor Parallel: [cyan]{final_tensor_parallel_size}[/cyan]")
    console.print(f"  Method: [cyan]{final_kt_method}[/cyan]")
    console.print(f"  Attention: [cyan]{final_attention_backend}[/cyan]")

    # Weights info
    if resolved_weights_path:
        console.print()
        console.print(f"  Quantized weights: [yellow]{resolved_weights_path}[/yellow]")

    console.print()
    console.print(f"  Server: [green]http://{final_host}:{final_port}[/green]")
    console.print()

    # Step 6: Show or execute
    if dry_run:
        console.print()
        console.print("[bold]Command:[/bold]")
        console.print()
        console.print(f"  [dim]{' '.join(cmd)}[/dim]")
        console.print()
        return

    # Execute with prepared environment variables.
    # Don't print "Server started" or API info here - let sglang's logs speak for themselves;
    # the actual startup takes time and those messages would be misleading.

    # Print the command being executed
    console.print()
    console.print("[bold]Launching server with command:[/bold]")
    console.print()
    console.print(f"  [dim]{' '.join(cmd)}[/dim]")
    console.print()

    try:
        # Execute directly without intercepting output or signals.
        # This allows direct output to the terminal and lets Ctrl+C work naturally.
        process = subprocess.run(cmd, env=env)
        sys.exit(process.returncode)

    except FileNotFoundError:
        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions

        print_error(t("sglang_not_found"))
        console.print()
        print_sglang_install_instructions()
        raise typer.Exit(1)
    except Exception as e:
        print_error(f"Failed to start server: {e}")
        raise typer.Exit(1)

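# Note on the parameter resolution in run() above: every knob follows the same
# precedence chain via `or`-chaining, for example:
#
#   final_max_total_tokens = (
#       max_total_tokens                                       # 1. CLI flag
#       or model_defaults.get("max-total-tokens")              # 2. model registry defaults
#       or settings.get("inference.max_total_tokens", 40000)   # 3. config file, else built-in default
#   )
#
# One caveat of `or`-chaining: a falsy-but-valid value (e.g. 0) at a higher
# level falls through to the next one, so 0 cannot be selected explicitly.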
def _find_model_path(model_info: ModelInfo, settings) -> Optional[Path]:
    """Find the model path on disk by searching all configured model paths."""
    model_paths = settings.get_model_paths()

    # Search in all configured model directories
    for models_dir in model_paths:
        # Check common path patterns
        possible_paths = [
            models_dir / model_info.name,
            models_dir / model_info.name.lower(),
            models_dir / model_info.name.replace(" ", "-"),
            models_dir / model_info.hf_repo.split("/")[-1],
            models_dir / model_info.hf_repo.replace("/", "--"),
        ]

        # Add alias-based paths
        for alias in model_info.aliases:
            possible_paths.append(models_dir / alias)
            possible_paths.append(models_dir / alias.lower())

        for path in possible_paths:
            if path.exists() and (path / "config.json").exists():
                return path

    return None

def _find_weights_path(model_info: ModelInfo, settings) -> Optional[Path]:
    """Find the quantized weights path on disk by searching all configured paths."""
    model_paths = settings.get_model_paths()
    weights_dir = settings.weights_dir

    # Check common patterns
    base_names = [
        model_info.name,
        model_info.name.lower(),
        model_info.hf_repo.split("/")[-1],
    ]

    suffixes = ["-INT4", "-int4", "_INT4", "_int4", "-quant", "-quantized"]

    # Prepare search directories
    search_dirs = [weights_dir] if weights_dir else []
    search_dirs.extend(model_paths)

    for base in base_names:
        for suffix in suffixes:
            for dir_path in search_dirs:
                if dir_path:
                    path = dir_path / f"{base}{suffix}"
                    if path.exists():
                        return path

    return None

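# For example, for a registry entry named "DeepSeek-V3" the search above probes
# candidates such as <weights_dir>/DeepSeek-V3-INT4 and <models_dir>/deepseek-v3_int4,
# returning the first existing hit. (Directory names here are illustrative; the
# real candidates are derived from the entry's name and hf_repo plus the suffix list.)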
def _build_sglang_command(
    model_path: Path,
    weights_path: Optional[Path],
    model_info: Optional[ModelInfo],
    host: str,
    port: int,
    gpu_experts: int,
    cpu_threads: int,
    numa_nodes: int,
    tensor_parallel_size: int,
    kt_method: str,
    kt_gpu_prefill_threshold: int,
    attention_backend: str,
    max_total_tokens: int,
    max_running_requests: int,
    chunked_prefill_size: int,
    mem_fraction_static: float,
    watchdog_timeout: int,
    served_model_name: str,
    disable_shared_experts_fusion: bool,
    settings,
    extra_model_params: Optional[dict] = None,  # Additional params from model defaults
) -> list[str]:
    """Build the SGLang launch command."""
    cmd = [
        sys.executable,
        "-m",
        "sglang.launch_server",
        "--host",
        host,
        "--port",
        str(port),
        "--model",
        str(model_path),
    ]

    # Add kt-kernel options.
    # kt-kernel is needed for:
    #   1. Quantized models (when weights_path is provided)
    #   2. MoE models with CPU offloading (when kt-cpuinfer > 0 or kt-num-gpu-experts is configured)
    use_kt_kernel = False

    # Check if we should use kt-kernel
    if weights_path:
        # Quantized model - always use kt-kernel
        use_kt_kernel = True
    elif cpu_threads > 0 or gpu_experts > 1:
        # CPU offloading configured - use kt-kernel
        use_kt_kernel = True
    elif model_info and model_info.type == "moe":
        # MoE model - likely needs kt-kernel for expert offloading
        use_kt_kernel = True

    if use_kt_kernel:
        # kt-weight-path: use quantized weights if available, otherwise the model path
        weight_path_to_use = weights_path if weights_path else model_path

        # Add kt-kernel configuration
        cmd.extend(
            [
                "--kt-weight-path",
                str(weight_path_to_use),
                "--kt-cpuinfer",
                str(cpu_threads),
                "--kt-threadpool-count",
                str(numa_nodes),
                "--kt-num-gpu-experts",
                str(gpu_experts),
                "--kt-method",
                kt_method,
                "--kt-gpu-prefill-token-threshold",
                str(kt_gpu_prefill_threshold),
            ]
        )

    # Add SGLang options
    cmd.extend(
        [
            "--attention-backend",
            attention_backend,
            "--trust-remote-code",
            "--mem-fraction-static",
            str(mem_fraction_static),
            "--chunked-prefill-size",
            str(chunked_prefill_size),
            "--max-running-requests",
            str(max_running_requests),
            "--max-total-tokens",
            str(max_total_tokens),
            "--watchdog-timeout",
            str(watchdog_timeout),
            "--enable-mixed-chunk",
            "--tensor-parallel-size",
            str(tensor_parallel_size),
            "--enable-p2p-check",
        ]
    )

    # Add served model name if specified
    if served_model_name:
        cmd.extend(["--served-model-name", served_model_name])

    # Add performance flags
    if disable_shared_experts_fusion:
        cmd.append("--disable-shared-experts-fusion")

    # Add any extra parameters from model defaults that weren't explicitly handled
    if extra_model_params:
        # List of parameters already handled above
        handled_params = {
            "kt-num-gpu-experts",
            "kt-cpuinfer",
            "kt-threadpool-count",
            "kt-method",
            "kt-gpu-prefill-token-threshold",
            "attention-backend",
            "tensor-parallel-size",
            "max-total-tokens",
            "max-running-requests",
            "chunked-prefill-size",
            "mem-fraction-static",
            "watchdog-timeout",
            "served-model-name",
            "disable-shared-experts-fusion",
        }

        for key, value in extra_model_params.items():
            if key in handled_params:
                continue
            # Add unhandled parameters dynamically
            if isinstance(value, bool):
                # Boolean flags don't take a value; False means "omit the flag entirely"
                if value:
                    cmd.append(f"--{key}")
            else:
                cmd.extend([f"--{key}", str(value)])

    # Add extra args from settings
    extra_args = settings.get("advanced.sglang_args", [])
    if extra_args:
        cmd.extend(extra_args)

    return cmd

def _interactive_model_selection(registry, settings) -> Optional[str]:
    """Show interactive model selection interface.

    Returns:
        Selected model name or None if cancelled.
    """
    from rich.panel import Panel
    from rich.prompt import Prompt
    from rich.table import Table

    from kt_kernel.cli.i18n import get_lang

    lang = get_lang()

    # Find local models first
    local_models = registry.find_local_models()

    # Get all registered models
    all_models = registry.list_all()

    console.print()
    console.print(
        Panel.fit(
            t("run_select_model_title"),
            border_style="cyan",
        )
    )
    console.print()

    # Build choices list
    choices = []
    choice_map = {}  # index -> model name

    # Section 1: Local models (downloaded)
    if local_models:
        console.print(f"[bold green]{t('run_local_models')}[/bold green]")
        console.print()

        for i, (model_info, path) in enumerate(local_models, 1):
            desc = model_info.description_zh if lang == "zh" else model_info.description
            short_desc = desc[:50] + "..." if len(desc) > 50 else desc
            console.print(f"  [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
            console.print(f"      [dim]{short_desc}[/dim]")
            console.print(f"      [dim]{path}[/dim]")
            choices.append(str(i))
            choice_map[str(i)] = model_info.name

        console.print()

    # Section 2: All registered models (for reference)
    start_idx = len(local_models) + 1
    console.print(f"[bold yellow]{t('run_registered_models')}[/bold yellow]")
    console.print()

    # Filter out already shown local models
    local_model_names = {m.name for m, _ in local_models}

    for i, model_info in enumerate(all_models, start_idx):
        if model_info.name in local_model_names:
            continue

        desc = model_info.description_zh if lang == "zh" else model_info.description
        short_desc = desc[:50] + "..." if len(desc) > 50 else desc
        console.print(f"  [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
        console.print(f"      [dim]{short_desc}[/dim]")
        console.print(f"      [dim]{model_info.hf_repo}[/dim]")
        choices.append(str(i))
        choice_map[str(i)] = model_info.name

    console.print()

    # Add cancel option
    cancel_idx = str(len(choices) + 1)
    console.print(f"  [cyan][{cancel_idx}][/cyan] [dim]{t('cancel')}[/dim]")
    choices.append(cancel_idx)
    console.print()

    # Prompt for selection
    try:
        selection = Prompt.ask(
            t("run_select_model_prompt"),
            choices=choices,
            default="1" if choices else cancel_idx,
        )
    except KeyboardInterrupt:
        console.print()
        return None

    if selection == cancel_idx:
        return None

    return choice_map.get(selection)
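To make the plumbing concrete: with `--dry-run`, `kt run` prints the assembled command and exits. For a hypothetical quantized MoE model served on 2 GPUs it would look roughly like the line below (paths, ports, and counts are illustrative, not captured output; `sys.executable` is rendered as `python`, and every flag corresponds to one added in `_build_sglang_command` above):

python -m sglang.launch_server --host 0.0.0.0 --port 30000 \
    --model /data/models/DeepSeek-V3 \
    --kt-weight-path /data/weights/DeepSeek-V3-INT4 \
    --kt-cpuinfer 76 --kt-threadpool-count 2 --kt-num-gpu-experts 1 \
    --kt-method AMXINT4 --kt-gpu-prefill-token-threshold 4096 \
    --attention-backend triton --trust-remote-code \
    --mem-fraction-static 0.98 --chunked-prefill-size 4096 \
    --max-running-requests 32 --max-total-tokens 40000 \
    --watchdog-timeout 3000 --enable-mixed-chunk \
    --tensor-parallel-size 2 --enable-p2p-check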
52
kt-kernel/python/cli/commands/sft.py
Normal file
@@ -0,0 +1,52 @@
"""
SFT command for kt-cli.

Fine-tuning with LlamaFactory integration.
"""

import typer

from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console

app = typer.Typer(help="Fine-tuning with LlamaFactory (coming soon)")


@app.callback(invoke_without_command=True)
def callback(ctx: typer.Context) -> None:
    """Fine-tuning commands (coming soon)."""
    if ctx.invoked_subcommand is None:
        console.print()
        console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
        console.print()
        console.print("[dim]kt sft train - Train a model[/dim]")
        console.print("[dim]kt sft chat - Chat with a trained model[/dim]")
        console.print("[dim]kt sft export - Export a trained model[/dim]")
        console.print()


@app.command(name="train")
def train() -> None:
    """Train a model using LlamaFactory (coming soon)."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)


@app.command(name="chat")
def chat() -> None:
    """Chat with a trained model using LlamaFactory (coming soon)."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)


@app.command(name="export")
def export() -> None:
    """Export a trained model using LlamaFactory (coming soon)."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)
118
kt-kernel/python/cli/commands/version.py
Normal file
@@ -0,0 +1,118 @@
"""
Version command for kt-cli.

Displays version information for kt-cli and related packages.
"""

import platform
from typing import Optional

import typer

from kt_kernel.cli import __version__
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console, print_version_table
from kt_kernel.cli.utils.environment import detect_cuda_version, get_installed_package_version


def _get_sglang_info() -> str:
    """Get sglang version and installation source information."""
    try:
        import sglang

        version = getattr(sglang, "__version__", None)

        if not version:
            version = get_installed_package_version("sglang")

        if not version:
            return t("version_not_installed")

        # Try to detect installation source
        import subprocess
        from pathlib import Path

        if hasattr(sglang, "__file__") and sglang.__file__:
            location = Path(sglang.__file__).parent.parent
            git_dir = location / ".git"

            if git_dir.exists():
                # Installed from git (editable install)
                try:
                    # Get remote URL
                    result = subprocess.run(
                        ["git", "remote", "get-url", "origin"],
                        cwd=location,
                        capture_output=True,
                        text=True,
                        timeout=2,
                    )
                    if result.returncode == 0:
                        remote_url = result.stdout.strip()
                        # Simplify GitHub URLs
                        if "github.com" in remote_url:
                            repo_name = remote_url.split("/")[-1].replace(".git", "")
                            owner = remote_url.split("/")[-2]
                            return f"{version} [dim](GitHub: {owner}/{repo_name})[/dim]"
                        return f"{version} [dim](Git: {remote_url})[/dim]"
                except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
                    pass

        # Default: installed from PyPI
        return f"{version} [dim](PyPI)[/dim]"

    except ImportError:
        return t("version_not_installed")


def version(
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed version info"),
) -> None:
    """Show version information."""
    console.print(f"\n[bold]{t('version_info')}[/bold] v{__version__}\n")

    # Basic info
    versions = {
        t("version_python"): platform.python_version(),
        t("version_platform"): f"{platform.system()} {platform.release()}",
    }

    # CUDA version
    cuda_version = detect_cuda_version()
    versions[t("version_cuda")] = cuda_version or t("version_cuda_not_found")

    print_version_table(versions)

    # Always show key packages with installation source
    console.print("\n[bold]Packages:[/bold]\n")

    sglang_info = _get_sglang_info()
    key_packages = {
        t("version_kt_kernel"): get_installed_package_version("kt-kernel") or t("version_not_installed"),
        t("version_sglang"): sglang_info,
    }

    print_version_table(key_packages)

    # Show SGLang installation hint if not installed
    if sglang_info == t("version_not_installed"):
        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions

        console.print()
        print_sglang_install_instructions()

    if verbose:
        console.print("\n[bold]Additional Packages:[/bold]\n")

        package_versions = {
            t("version_ktransformers"): get_installed_package_version("ktransformers") or t("version_not_installed"),
            t("version_llamafactory"): get_installed_package_version("llamafactory") or t("version_not_installed"),
            "typer": get_installed_package_version("typer") or t("version_not_installed"),
            "rich": get_installed_package_version("rich") or t("version_not_installed"),
            "torch": get_installed_package_version("torch") or t("version_not_installed"),
            "transformers": get_installed_package_version("transformers") or t("version_not_installed"),
        }

        print_version_table(package_versions)

    console.print()
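As a quick way to exercise this command in isolation, the sketch below mounts `version()` on a throwaway Typer app and invokes it with Typer's test runner. This assumes only that the module is importable as `kt_kernel.cli.commands.version`, which matches the file path and the import style used elsewhere in this diff; the real `kt` entry point registers the command differently.

import typer
from typer.testing import CliRunner

from kt_kernel.cli.commands.version import version  # assumed import path

# Throwaway single-command app; `kt version` itself is wired up elsewhere.
app = typer.Typer()
app.command()(version)

result = CliRunner().invoke(app, ["--verbose"])
print(result.output)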