mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-29 04:09:52 +00:00
parent
e7d277d163
commit
d8046e1bb4
65 changed files with 12111 additions and 2502 deletions
3
kt-kernel/python/cli/utils/__init__.py
Normal file
3
kt-kernel/python/cli/utils/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
Utility modules for kt-cli.
|
||||
"""
|
||||
249
kt-kernel/python/cli/utils/console.py
Normal file
249
kt-kernel/python/cli/utils/console.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
"""
|
||||
Console utilities for kt-cli.
|
||||
|
||||
Provides Rich-based console output helpers for consistent formatting.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
from rich.prompt import Confirm, Prompt
|
||||
from rich.table import Table
|
||||
from rich.theme import Theme
|
||||
|
||||
from kt_kernel.cli.i18n import t
|
||||
|
||||
# Custom theme for kt-cli: semantic style names referenced by the Rich
# markup used throughout this module (e.g. "[info]...[/info]").
KT_THEME = Theme(
    {
        "info": "cyan",
        "warning": "yellow",
        "error": "bold red",
        "success": "bold green",
        "highlight": "bold magenta",
        "muted": "dim",
    }
)

# Global console instance shared by every helper in this module.
console = Console(theme=KT_THEME)
|
||||
|
||||
|
||||
def print_info(message: str, **kwargs) -> None:
    """Write *message* to the shared console with a cyan info marker."""
    marker = "[info]ℹ[/info]"
    console.print(f"{marker} {message}", **kwargs)
|
||||
|
||||
|
||||
def print_success(message: str, **kwargs) -> None:
    """Write *message* to the shared console with a green check marker."""
    marker = "[success]✓[/success]"
    console.print(f"{marker} {message}", **kwargs)
|
||||
|
||||
|
||||
def print_warning(message: str, **kwargs) -> None:
    """Write *message* to the shared console with a yellow warning marker."""
    marker = "[warning]⚠[/warning]"
    console.print(f"{marker} {message}", **kwargs)
|
||||
|
||||
|
||||
def print_error(message: str, **kwargs) -> None:
    """Write *message* to the shared console with a red cross marker."""
    marker = "[error]✗[/error]"
    console.print(f"{marker} {message}", **kwargs)
|
||||
|
||||
|
||||
def print_step(message: str, **kwargs) -> None:
    """Write *message* to the shared console with an arrow step marker."""
    marker = "[highlight]→[/highlight]"
    console.print(f"{marker} {message}", **kwargs)
|
||||
|
||||
|
||||
def print_header(title: str, subtitle: Optional[str] = None) -> None:
    """Render *title* (and an optional muted *subtitle*) inside a Rich panel."""
    lines = [f"[bold]{title}[/bold]"]
    if subtitle:
        lines.append(f"[muted]{subtitle}[/muted]")
    console.print(Panel("\n".join(lines), expand=False))
|
||||
|
||||
|
||||
def print_version_table(versions: dict[str, Optional[str]]) -> None:
    """Render component/version pairs as a borderless two-column table.

    A falsy version is shown as the localized "not installed" text.
    """
    table = Table(show_header=False, box=None, padding=(0, 2))
    table.add_column("Component", style="bold")
    table.add_column("Version")

    for component, ver in versions.items():
        cell = (
            f"[success]{ver}[/success]"
            if ver
            else f"[muted]{t('version_not_installed')}[/muted]"
        )
        table.add_row(component, cell)

    console.print(table)
|
||||
|
||||
|
||||
def print_dependency_table(deps: list[dict]) -> None:
    """Render a name/current/required/status table for dependency checks.

    Each dict in *deps* may carry "name", "installed", "required" and
    "status" ("ok", "outdated", anything else = missing).
    """
    table = Table(title=t("install_checking_deps"))
    table.add_column(t("version_info"), style="bold")
    table.add_column("Current")
    table.add_column("Required")
    table.add_column("Status")

    for dep in deps:
        status = dep.get("status", "ok")
        # Map known statuses to (style, label); anything unknown is "missing".
        markup = {
            "ok": ("success", t("install_dep_ok")),
            "outdated": ("warning", t("install_dep_outdated")),
        }.get(status, ("error", t("install_dep_missing")))
        style_name, label = markup
        table.add_row(
            dep["name"],
            dep.get("installed", "-"),
            dep.get("required", "-"),
            f"[{style_name}]{label}[/{style_name}]",
        )

    console.print(table)
|
||||
|
||||
|
||||
def confirm(message: str, default: bool = True) -> bool:
    """Ask the user a yes/no question via Rich's Confirm prompt.

    Args:
        message: Question text shown to the user.
        default: Value used when the user just presses Enter.

    Returns:
        True if the user confirmed, False otherwise.
    """
    return Confirm.ask(message, default=default, console=console)
|
||||
|
||||
|
||||
def prompt_choice(message: str, choices: list[str], default: Optional[str] = None) -> str:
    """Prompt the user to pick one entry from *choices*.

    The options are shown as a numbered list; the user may answer with
    either the number or the exact choice text. Re-prompts until a valid
    selection is made.

    Args:
        message: Headline displayed above the numbered list.
        choices: Options to choose from.
        default: Optional pre-selected choice. Ignored when it is not
            actually present in *choices*.

    Returns:
        The selected choice string.
    """
    # Display numbered choices
    console.print(f"\n[bold]{message}[/bold]")
    for i, choice in enumerate(choices, 1):
        console.print(f" [highlight][{i}][/highlight] {choice}")

    # Fix: the original called choices.index(default) whenever default was
    # truthy, raising ValueError for a default not present in choices.
    default_str = str(choices.index(default) + 1) if default in choices else None

    while True:
        response = Prompt.ask(
            "\n" + t("prompt_select"),
            console=console,
            default=default_str,
        )
        try:
            idx = int(response) - 1
            if 0 <= idx < len(choices):
                return choices[idx]
        except ValueError:
            # Allow typing the choice text verbatim instead of its number.
            if response in choices:
                return response

        print_error(f"Please enter a number between 1 and {len(choices)}")
|
||||
|
||||
|
||||
def prompt_text(message: str, default: Optional[str] = None) -> str:
    """Prompt for a free-form text answer on the shared console.

    Args:
        message: Prompt text shown to the user.
        default: Value used when the user just presses Enter.

    Returns:
        The entered (or defaulted) text.
    """
    return Prompt.ask(message, console=console, default=default)
|
||||
|
||||
|
||||
def create_progress() -> Progress:
    """Build a spinner + bar + elapsed-time progress renderer for generic tasks."""
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
    )
    return Progress(*columns, console=console)
|
||||
|
||||
|
||||
def create_download_progress() -> Progress:
    """Build a progress renderer with size/speed/ETA columns for downloads."""
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    )
    return Progress(*columns, console=console)
|
||||
|
||||
|
||||
def print_model_table(models: list[dict]) -> None:
    """Render the model catalogue: name, HF repo, type and hardware needs.

    Each dict may carry "name", "hf_repo", "type", "gpu_vram_gb" and
    "cpu_ram_gb"; zero/missing requirements render as "-".
    """
    table = Table(title=t("download_list_title"))
    table.add_column("Name", style="bold")
    table.add_column("Repository")
    table.add_column("Type")
    table.add_column("Requirements")

    for model in models:
        requirements = [
            f"{label}: {model[key]}GB"
            for key, label in (("gpu_vram_gb", "GPU"), ("cpu_ram_gb", "RAM"))
            if model.get(key)
        ]
        table.add_row(
            model.get("name", ""),
            model.get("hf_repo", ""),
            model.get("type", ""),
            ", ".join(requirements) or "-",
        )

    console.print(table)
|
||||
|
||||
|
||||
def print_hardware_info(gpu_info: str, cpu_info: str, ram_info: str) -> None:
    """Show GPU/CPU/RAM summary lines inside a "Hardware" panel."""
    table = Table(show_header=False, box=None)
    table.add_column("Icon", width=3)
    table.add_column("Info")

    for icon, detail in (("🖥️", gpu_info), ("💻", cpu_info), ("🧠", ram_info)):
        table.add_row(icon, detail)

    console.print(Panel(table, title="Hardware", expand=False))
|
||||
|
||||
|
||||
def print_server_info(
    mode: str, host: str, port: int, gpu_experts: int, cpu_threads: int
) -> None:
    """Summarise a just-started server in a green-bordered key/value panel."""
    table = Table(show_header=False, box=None)
    table.add_column("Key", style="bold")
    table.add_column("Value")

    # Localized labels carry a trailing ":<placeholder>" — keep only the label.
    rows = (
        (t("run_server_mode").split(":")[0], mode),
        ("Host", host),
        ("Port", str(port)),
        (t("run_gpu_experts").split(":")[0], f"{gpu_experts}/layer"),
        (t("run_cpu_threads").split(":")[0], str(cpu_threads)),
    )
    for key, value in rows:
        table.add_row(key, value)

    console.print(Panel(table, title=t("run_server_started"), expand=False, border_style="green"))
|
||||
|
||||
|
||||
def print_api_info(host: str, port: int) -> None:
    """Print API endpoint information: URLs, a sample curl command, stop hint.

    Fixes: dropped the unused ``docs_url`` local (the docs line is rendered
    via the i18n template instead) and the spurious f-string prefix on the
    constant "Test command" line.

    Args:
        host: Hostname or address the server is bound to.
        port: TCP port the server is listening on.
    """
    api_url = f"http://{host}:{port}"

    console.print()
    console.print(f" {t('run_api_url', host=host, port=port)}")
    console.print(f" {t('run_docs_url', host=host, port=port)}")
    console.print()
    console.print(" [muted]Test command:[/muted]")
    console.print(
        f" [dim]curl {api_url}/v1/chat/completions -H 'Content-Type: application/json' "
        f"-d '{{\"model\": \"default\", \"messages\": [{{\"role\": \"user\", \"content\": \"Hello\"}}]}}'[/dim]"
    )
    console.print()
    console.print(f" [muted]{t('run_stop_hint')}[/muted]")
|
||||
1108
kt-kernel/python/cli/utils/environment.py
Normal file
1108
kt-kernel/python/cli/utils/environment.py
Normal file
File diff suppressed because it is too large
Load diff
374
kt-kernel/python/cli/utils/model_registry.py
Normal file
374
kt-kernel/python/cli/utils/model_registry.py
Normal file
|
|
@ -0,0 +1,374 @@
|
|||
"""
|
||||
Model registry for kt-cli.
|
||||
|
||||
Provides a registry of supported models with fuzzy matching capabilities.
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
|
||||
|
||||
@dataclass
class ModelInfo:
    """Information about a supported model."""

    name: str  # Canonical display name, e.g. "DeepSeek-V3-0324"
    hf_repo: str  # Hugging Face repository id, "org/name"
    aliases: list[str] = field(default_factory=list)  # Shorthand lookup names
    type: str = "moe"  # moe, dense
    gpu_vram_gb: float = 0  # Minimum GPU VRAM requirement in GB (0 = unspecified)
    cpu_ram_gb: float = 0  # Minimum CPU RAM requirement in GB (0 = unspecified)
    default_params: dict = field(default_factory=dict)  # Default launch flags (kebab-case keys)
    description: str = ""  # English description
    description_zh: str = ""  # Chinese description
    max_tensor_parallel_size: Optional[int] = None  # Maximum tensor parallel size for this model
|
||||
|
||||
|
||||
# Built-in model registry: models shipped with the CLI. User-defined entries
# from registry.yaml are loaded afterwards and may override these by name.
BUILTIN_MODELS: list[ModelInfo] = [
    ModelInfo(
        name="DeepSeek-V3-0324",
        hf_repo="deepseek-ai/DeepSeek-V3-0324",
        aliases=["deepseek-v3-0324", "deepseek-v3", "dsv3", "deepseek3", "v3-0324"],
        type="moe",
        default_params={
            "kt-num-gpu-experts": 1,
            "attention-backend": "triton",
            "disable-shared-experts-fusion": True,
            "kt-method": "AMXINT4",
        },
        description="DeepSeek V3-0324 685B MoE model (March 2025, improved benchmarks)",
        description_zh="DeepSeek V3-0324 685B MoE 模型（2025年3月，改进的基准测试）",
    ),
    ModelInfo(
        name="DeepSeek-V3.2",
        hf_repo="deepseek-ai/DeepSeek-V3.2",
        aliases=["deepseek-v3.2", "dsv3.2", "deepseek3.2", "v3.2"],
        type="moe",
        default_params={
            "kt-method": "FP8",
            "kt-gpu-prefill-token-threshold": 4096,
            "attention-backend": "flashinfer",
            "fp8-gemm-backend": "triton",
            "max-total-tokens": 100000,
            "max-running-requests": 16,
            "chunked-prefill-size": 32768,
            "mem-fraction-static": 0.80,
            "watchdog-timeout": 3000,
            "served-model-name": "DeepSeek-V3.2",
            "disable-shared-experts-fusion": True,
        },
        description="DeepSeek V3.2 671B MoE model (latest)",
        description_zh="DeepSeek V3.2 671B MoE 模型（最新）",
    ),
    ModelInfo(
        name="DeepSeek-R1-0528",
        hf_repo="deepseek-ai/DeepSeek-R1-0528",
        aliases=["deepseek-r1-0528", "deepseek-r1", "dsr1", "r1", "r1-0528"],
        type="moe",
        default_params={
            "kt-num-gpu-experts": 1,
            "attention-backend": "triton",
            "disable-shared-experts-fusion": True,
            "kt-method": "AMXINT4",
        },
        description="DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)",
        description_zh="DeepSeek R1-0528 推理模型（2025年5月，改进的推理深度）",
    ),
    ModelInfo(
        name="Kimi-K2-Thinking",
        hf_repo="moonshotai/Kimi-K2-Thinking",
        aliases=["kimi-k2-thinking", "kimi-thinking", "k2-thinking", "kimi", "k2"],
        type="moe",
        default_params={
            "kt-method": "RAWINT4",
            "kt-gpu-prefill-token-threshold": 400,
            "attention-backend": "flashinfer",
            "max-total-tokens": 100000,
            "max-running-requests": 16,
            "chunked-prefill-size": 32768,
            "mem-fraction-static": 0.80,
            "watchdog-timeout": 3000,
            "served-model-name": "Kimi-K2-Thinking",
            "disable-shared-experts-fusion": True,
        },
        description="Moonshot Kimi K2 Thinking MoE model",
        description_zh="月之暗面 Kimi K2 Thinking MoE 模型",
    ),
    ModelInfo(
        name="MiniMax-M2",
        hf_repo="MiniMaxAI/MiniMax-M2",
        aliases=["minimax-m2", "m2"],
        type="moe",
        default_params={
            "kt-method": "FP8",
            "kt-gpu-prefill-token-threshold": 4096,
            "attention-backend": "flashinfer",
            "fp8-gemm-backend": "triton",
            "max-total-tokens": 100000,
            "max-running-requests": 16,
            "chunked-prefill-size": 32768,
            "mem-fraction-static": 0.80,
            "watchdog-timeout": 3000,
            "served-model-name": "MiniMax-M2",
            "disable-shared-experts-fusion": True,
            "tool-call-parser": "minimax-m2",
            "reasoning-parser": "minimax-append-think",
        },
        description="MiniMax M2 MoE model",
        description_zh="MiniMax M2 MoE 模型",
        max_tensor_parallel_size=4,  # M2 only supports up to 4-way tensor parallelism
    ),
    ModelInfo(
        name="MiniMax-M2.1",
        hf_repo="MiniMaxAI/MiniMax-M2.1",
        aliases=["minimax-m2.1", "m2.1"],
        type="moe",
        default_params={
            "kt-method": "FP8",
            "kt-gpu-prefill-token-threshold": 4096,
            "attention-backend": "flashinfer",
            "fp8-gemm-backend": "triton",
            "max-total-tokens": 100000,
            "max-running-requests": 16,
            "chunked-prefill-size": 32768,
            "mem-fraction-static": 0.80,
            "watchdog-timeout": 3000,
            "served-model-name": "MiniMax-M2.1",
            "disable-shared-experts-fusion": True,
            "tool-call-parser": "minimax-m2",
            "reasoning-parser": "minimax-append-think",
        },
        description="MiniMax M2.1 MoE model (enhanced multi-language programming)",
        description_zh="MiniMax M2.1 MoE 模型（增强多语言编程能力）",
        max_tensor_parallel_size=4,  # M2.1 only supports up to 4-way tensor parallelism
    ),
]
|
||||
|
||||
|
||||
class ModelRegistry:
    """Registry of supported models with fuzzy matching.

    Built-in models are loaded first, then user entries from
    ``<config_dir>/registry.yaml`` — so a user entry with the same
    (lowercased) name overrides a built-in one.
    """

    def __init__(self):
        """Initialize the model registry from built-ins and user config."""
        self._models: dict[str, ModelInfo] = {}  # lowercased name -> ModelInfo
        self._aliases: dict[str, str] = {}  # lowercased alias -> lowercased name
        self._load_builtin_models()
        self._load_user_models()

    def _load_builtin_models(self) -> None:
        """Load built-in models."""
        for model in BUILTIN_MODELS:
            self._register(model)

    def _load_user_models(self) -> None:
        """Load user-defined models from ``registry.yaml`` in the config dir.

        Best-effort: malformed YAML or I/O failures are silently ignored so a
        broken user file never prevents the CLI from starting.
        """
        settings = get_settings()
        registry_file = settings.config_dir / "registry.yaml"

        if registry_file.exists():
            try:
                with open(registry_file, "r", encoding="utf-8") as f:
                    data = yaml.safe_load(f) or {}

                # Expected schema: {"models": {name: {hf_repo, aliases, ...}}}
                for name, info in data.get("models", {}).items():
                    model = ModelInfo(
                        name=name,
                        hf_repo=info.get("hf_repo", ""),
                        aliases=info.get("aliases", []),
                        type=info.get("type", "moe"),
                        gpu_vram_gb=info.get("gpu_vram_gb", 0),
                        cpu_ram_gb=info.get("cpu_ram_gb", 0),
                        default_params=info.get("default_params", {}),
                        description=info.get("description", ""),
                        description_zh=info.get("description_zh", ""),
                        max_tensor_parallel_size=info.get("max_tensor_parallel_size"),
                    )
                    self._register(model)
            except (yaml.YAMLError, OSError):
                pass

    def _register(self, model: ModelInfo) -> None:
        """Register a model under its lowercased name and all its aliases."""
        self._models[model.name.lower()] = model

        # Register aliases (later registrations overwrite earlier ones)
        for alias in model.aliases:
            self._aliases[alias.lower()] = model.name.lower()

    def get(self, name: str) -> Optional[ModelInfo]:
        """Get a model by exact (case-insensitive) name or alias.

        Returns:
            The matching ModelInfo, or None if nothing matches.
        """
        name_lower = name.lower()

        # Check direct match
        if name_lower in self._models:
            return self._models[name_lower]

        # Check aliases
        if name_lower in self._aliases:
            return self._models[self._aliases[name_lower]]

        return None

    def search(self, query: str, limit: int = 10) -> list[ModelInfo]:
        """Search for models using fuzzy matching.

        Args:
            query: Search query
            limit: Maximum number of results

        Returns:
            List of matching models, sorted by relevance
        """
        query_lower = query.lower()
        results: list[tuple[float, ModelInfo]] = []

        for model in self._models.values():
            score = self._match_score(query_lower, model)
            if score > 0:
                results.append((score, model))

        # Sort by score descending (key on the score only — ModelInfo is not orderable)
        results.sort(key=lambda x: x[0], reverse=True)

        return [model for _, model in results[:limit]]

    def _match_score(self, query: str, model: ModelInfo) -> float:
        """Calculate match score for a model.

        Returns a score between 0 and 1, where 1 is an exact match.
        Tiers: exact name (1.0) > exact alias (0.95) > name substring (0.8)
        > alias substring (0.7) > repo substring (0.6) > token overlap (<=0.5).
        """
        # Check exact match
        if query == model.name.lower():
            return 1.0

        # Check alias exact match
        for alias in model.aliases:
            if query == alias.lower():
                return 0.95

        # Check if query is contained in name
        if query in model.name.lower():
            return 0.8

        # Check if query is contained in aliases
        for alias in model.aliases:
            if query in alias.lower():
                return 0.7

        # Check if query is contained in hf_repo
        if query in model.hf_repo.lower():
            return 0.6

        # Fuzzy matching - check if all query parts are present
        query_parts = re.split(r"[-_.\s]", query)
        name_lower = model.name.lower()

        matches = sum(1 for part in query_parts if part and part in name_lower)
        if matches > 0:
            return 0.5 * (matches / len(query_parts))

        return 0.0

    def list_all(self) -> list[ModelInfo]:
        """List all registered models."""
        return list(self._models.values())

    def find_local_models(self) -> list[tuple[ModelInfo, Path]]:
        """Find models that are downloaded locally in any configured model path.

        A candidate directory only counts when it contains a ``config.json``.
        The first hit per model wins; other locations are not reported.

        Returns:
            List of (ModelInfo, path) tuples for local models
        """
        settings = get_settings()
        model_paths = settings.get_model_paths()
        results = []

        for model in self._models.values():
            found = False
            # Search in all configured model directories
            for models_dir in model_paths:
                if not models_dir.exists():
                    continue

                # Check common path patterns (the "org--name" form presumably
                # matches HF-cache-style directory naming — verify with downloader)
                possible_paths = [
                    models_dir / model.name,
                    models_dir / model.name.lower(),
                    models_dir / model.hf_repo.split("/")[-1],
                    models_dir / model.hf_repo.replace("/", "--"),
                ]

                for path in possible_paths:
                    if path.exists() and (path / "config.json").exists():
                        results.append((model, path))
                        found = True
                        break

                if found:
                    break

        return results
|
||||
|
||||
|
||||
# Global registry instance, created lazily by get_registry().
_registry: Optional[ModelRegistry] = None


def get_registry() -> ModelRegistry:
    """Get the global model registry instance.

    Lazily constructs the ModelRegistry on first call (which reads the
    user registry.yaml) and returns the same instance afterwards.
    """
    global _registry
    if _registry is None:
        _registry = ModelRegistry()
    return _registry
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Model-specific parameter computation functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for DeepSeek V3-family models.

    Reserves 16 GB per GPU, then divides the remaining pooled VRAM by 3
    (the per-expert GB budget implied by the original divisor).

    Args:
        tensor_parallel_size: Number of GPUs in the tensor-parallel group.
        vram_per_gpu_gb: Usable VRAM per GPU, in GB.

    Returns:
        Number of experts to place on GPU; 0 when per-GPU VRAM is below
        the 16 GB reserve.
    """
    per_gpu_reserve_gb = 16
    if vram_per_gpu_gb < per_gpu_reserve_gb:
        return 0  # was `int(0)` — redundant cast removed
    spare_vram_gb = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_reserve_gb))
    return spare_vram_gb // 3
|
||||
|
||||
|
||||
def compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for Kimi K2 Thinking.

    Reserves 16 GB per GPU, then allots two experts per 3 GB of remaining
    pooled VRAM (the ``* 2 // 3`` ratio from the original implementation).

    Args:
        tensor_parallel_size: Number of GPUs in the tensor-parallel group.
        vram_per_gpu_gb: Usable VRAM per GPU, in GB.

    Returns:
        Number of experts to place on GPU; 0 when per-GPU VRAM is below
        the 16 GB reserve.
    """
    per_gpu_reserve_gb = 16
    if vram_per_gpu_gb < per_gpu_reserve_gb:
        return 0  # was `int(0)` — redundant cast removed
    spare_vram_gb = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_reserve_gb))
    return spare_vram_gb * 2 // 3
|
||||
|
||||
|
||||
def compute_minimax_m2_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for MiniMax M2/M2.1.

    Reserves 16 GB per GPU; every remaining GB of pooled VRAM holds one
    expert (the original ``// 1`` was a no-op and has been removed).

    Args:
        tensor_parallel_size: Number of GPUs in the tensor-parallel group.
        vram_per_gpu_gb: Usable VRAM per GPU, in GB.

    Returns:
        Number of experts to place on GPU; 0 when per-GPU VRAM is below
        the 16 GB reserve.
    """
    per_gpu_reserve_gb = 16
    if vram_per_gpu_gb < per_gpu_reserve_gb:
        return 0  # was `int(0)` — redundant cast removed
    spare_vram_gb = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_reserve_gb))
    return spare_vram_gb
|
||||
|
||||
|
||||
# Model name to computation function mapping.
# Keys are the canonical ModelInfo.name values; each function maps
# (tensor_parallel_size, vram_per_gpu_gb) -> kt-num-gpu-experts.
MODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = {
    "DeepSeek-V3-0324": compute_deepseek_v3_gpu_experts,
    "DeepSeek-V3.2": compute_deepseek_v3_gpu_experts,  # Same as V3-0324
    "DeepSeek-R1-0528": compute_deepseek_v3_gpu_experts,  # Same as V3-0324
    "Kimi-K2-Thinking": compute_kimi_k2_thinking_gpu_experts,
    "MiniMax-M2": compute_minimax_m2_gpu_experts,
    "MiniMax-M2.1": compute_minimax_m2_gpu_experts,  # Same as M2
}
|
||||
407
kt-kernel/python/cli/utils/sglang_checker.py
Normal file
407
kt-kernel/python/cli/utils/sglang_checker.py
Normal file
|
|
@ -0,0 +1,407 @@
|
|||
"""
|
||||
SGLang installation checker and installation instructions provider.
|
||||
|
||||
This module provides utilities to:
|
||||
- Check if SGLang is installed and get its metadata
|
||||
- Provide installation instructions when SGLang is not found
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from kt_kernel.cli.i18n import t
|
||||
from kt_kernel.cli.utils.console import console
|
||||
|
||||
|
||||
def check_sglang_installation() -> dict:
    """Check if SGLang is installed and get its metadata.

    Detection strategy: import sglang, then query ``pip show sglang`` for the
    install location (falling back to the module's own path if pip fails), and
    finally walk up to two parent directories looking for a ``.git`` folder to
    decide whether this is a source checkout.

    Returns:
        dict with keys:
            - installed: bool
            - version: str or None
            - location: str or None (installation path)
            - editable: bool (whether installed in editable mode)
            - git_info: dict or None (git remote and branch if available)
            - from_source: bool (whether installed from source repository)
    """
    try:
        # Try to import sglang — ImportError here is the "not installed" case.
        import sglang

        version = getattr(sglang, "__version__", None)

        # Use pip show to get detailed package information
        location = None
        editable = False
        git_info = None
        from_source = False

        try:
            # Get pip show output (same interpreter as the CLI)
            result = subprocess.run(
                [sys.executable, "-m", "pip", "show", "sglang"],
                capture_output=True,
                text=True,
                timeout=10,
            )

            if result.returncode == 0:
                # Parse "Key: value" lines from pip's output.
                pip_info = {}
                for line in result.stdout.split("\n"):
                    if ":" in line:
                        key, value = line.split(":", 1)
                        pip_info[key.strip()] = value.strip()

                location = pip_info.get("Location")
                editable_location = pip_info.get("Editable project location")

                # An editable install's checkout path supersedes site-packages.
                if editable_location:
                    editable = True
                    location = editable_location
        except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
            # Fallback to module location
            if hasattr(sglang, "__file__") and sglang.__file__:
                location = str(Path(sglang.__file__).parent.parent)

        # Check if it's installed from source (has .git directory)
        if location:
            git_root = None
            check_path = Path(location)

            # Check current directory and up to 2 parent directories
            for _ in range(3):
                git_dir = check_path / ".git"
                if git_dir.exists():
                    git_root = check_path
                    from_source = True
                    break
                if check_path.parent == check_path:  # Reached root
                    break
                check_path = check_path.parent

            if from_source and git_root:
                # Try to get git remote and branch info (best-effort)
                try:
                    # Get remote URL
                    result = subprocess.run(
                        ["git", "remote", "get-url", "origin"],
                        cwd=git_root,
                        capture_output=True,
                        text=True,
                        timeout=5,
                    )
                    remote_url = result.stdout.strip() if result.returncode == 0 else None

                    # Extract org/repo from URL
                    remote_short = None
                    if remote_url:
                        # Handle both https and git@ URLs
                        if "github.com" in remote_url:
                            parts = remote_url.rstrip("/").replace(".git", "").split("github.com")[-1]
                            remote_short = parts.lstrip("/").lstrip(":")

                    # Get current branch (empty output on detached HEAD)
                    result = subprocess.run(
                        ["git", "branch", "--show-current"],
                        cwd=git_root,
                        capture_output=True,
                        text=True,
                        timeout=5,
                    )
                    branch = result.stdout.strip() if result.returncode == 0 else None

                    if remote_url or branch:
                        git_info = {
                            "remote": remote_short or remote_url,
                            "branch": branch,
                        }
                except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
                    pass

        return {
            "installed": True,
            "version": version,
            "location": location,
            "editable": editable,
            "git_info": git_info,
            "from_source": from_source,
        }
    except ImportError:
        return {
            "installed": False,
            "version": None,
            "location": None,
            "editable": False,
            "git_info": None,
            "from_source": False,
        }
|
||||
|
||||
|
||||
def get_sglang_install_instructions(lang: Optional[str] = None) -> str:
    """Get SGLang installation instructions.

    Args:
        lang: Language code ('en' or 'zh'). If None, uses current language setting.

    Returns:
        Formatted installation instructions string (Rich markup), in Chinese
        when lang == 'zh', English otherwise.
    """
    # Imported here to avoid a module-level import cycle with the i18n package.
    from kt_kernel.cli.i18n import get_lang

    if lang is None:
        lang = get_lang()

    if lang == "zh":
        return """
[bold yellow]SGLang \u672a\u5b89\u88c5[/bold yellow]

\u8bf7\u6309\u7167\u4ee5\u4e0b\u6b65\u9aa4\u5b89\u88c5 SGLang:

[bold]1. \u514b\u9686\u4ed3\u5e93:[/bold]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[bold]2. \u5b89\u88c5 (\u4e8c\u9009\u4e00):[/bold]

   [cyan]\u65b9\u5f0f A - pip \u5b89\u88c5 (\u63a8\u8350):[/cyan]
   pip install -e "python[all]"

   [cyan]\u65b9\u5f0f B - uv \u5b89\u88c5 (\u66f4\u5feb):[/cyan]
   pip install uv
   uv pip install -e "python[all]"

[dim]\u6ce8\u610f: \u8bf7\u786e\u4fdd\u5728\u6b63\u786e\u7684 Python \u73af\u5883\u4e2d\u6267\u884c\u4ee5\u4e0a\u547d\u4ee4[/dim]
"""
    else:
        return """
[bold yellow]SGLang is not installed[/bold yellow]

Please follow these steps to install SGLang:

[bold]1. Clone the repository:[/bold]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[bold]2. Install (choose one):[/bold]

   [cyan]Option A - pip install (recommended):[/cyan]
   pip install -e "python[all]"

   [cyan]Option B - uv install (faster):[/cyan]
   pip install uv
   uv pip install -e "python[all]"

[dim]Note: Make sure to run these commands in the correct Python environment[/dim]
"""
|
||||
|
||||
|
||||
def print_sglang_install_instructions() -> None:
    """Print SGLang installation instructions to console."""
    console.print(get_sglang_install_instructions())
|
||||
|
||||
|
||||
def check_sglang_and_warn() -> bool:
    """Check if SGLang is installed, print warning if not.

    When SGLang is installed but does not look like a source checkout
    (likely a PyPI install), a non-fatal warning recommending the source
    build is printed.

    Returns:
        True if SGLang is installed, False otherwise.
    """
    info = check_sglang_installation()

    if not info["installed"]:
        print_sglang_install_instructions()
        return False

    # `installed` is guaranteed True past the early return, so only
    # `from_source` needs testing (the original re-checked `installed`).
    if not info["from_source"]:
        from kt_kernel.cli.utils.console import print_warning

        print_warning(t("sglang_pypi_warning"))
        console.print()
        console.print("[dim]" + t("sglang_recommend_source") + "[/dim]")
        console.print()

    return True
|
||||
|
||||
|
||||
def _get_sglang_kt_kernel_cache_path() -> Path:
|
||||
"""Get the path to the sglang kt-kernel support cache file."""
|
||||
cache_dir = Path.home() / ".ktransformers" / "cache"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir / "sglang_kt_kernel_supported"
|
||||
|
||||
|
||||
def _is_sglang_kt_kernel_cache_valid() -> bool:
    """Return True when a previous support check was cached as passing.

    The cache counts as valid only when the file exists and its contents
    (stripped, case-insensitive) equal "true". Read errors count as invalid.
    """
    cache_path = _get_sglang_kt_kernel_cache_path()
    if not cache_path.exists():
        return False
    try:
        # IOError is an alias of OSError on Python 3, so one clause suffices.
        return cache_path.read_text().strip().lower() == "true"
    except OSError:
        return False
|
||||
|
||||
|
||||
def _save_sglang_kt_kernel_cache(supported: bool) -> None:
    """Persist the support-check outcome ("true"/"false") to the cache file."""
    payload = "true" if supported else "false"
    try:
        _get_sglang_kt_kernel_cache_path().write_text(payload)
    except OSError:
        # Best-effort cache; write failures are non-fatal.
        pass
|
||||
|
||||
|
||||
def clear_sglang_kt_kernel_cache() -> None:
    """Delete the cached support result so the next check runs afresh."""
    try:
        # missing_ok mirrors the original's exists() pre-check.
        _get_sglang_kt_kernel_cache_path().unlink(missing_ok=True)
    except OSError:
        # Best-effort cleanup; deletion failures are non-fatal.
        pass
|
||||
|
||||
|
||||
def check_sglang_kt_kernel_support(use_cache: bool = True, silent: bool = False) -> dict:
    """Check if SGLang supports kt-kernel parameters (--kt-gpu-prefill-token-threshold).

    This function runs `python -m sglang.launch_server --help` and checks if the
    output contains the `--kt-gpu-prefill-token-threshold` parameter. This parameter
    is only available in the kvcache-ai/sglang fork, not in the official sglang.

    The result is cached after the first successful check to avoid repeated checks.
    Note that only positive results are cached; an unsupported result is re-checked
    on every call.

    Args:
        use_cache: If True, use cached result if available. Default is True.
        silent: If True, don't print checking message. Default is False.

    Returns:
        dict with keys:
            - supported: bool - True if kt-kernel parameters are supported
            - help_output: str or None - The help output from sglang.launch_server
            - error: str or None - Error message if check failed
            - from_cache: bool - True if result was from cache
    """
    # Imported lazily to avoid a circular import with the console module.
    from kt_kernel.cli.utils.console import print_step

    # Check cache first (cache only ever stores positive results)
    if use_cache and _is_sglang_kt_kernel_cache_valid():
        return {
            "supported": True,
            "help_output": None,
            "error": None,
            "from_cache": True,
        }

    # Print checking message
    if not silent:
        print_step(t("sglang_checking_kt_kernel_support"))

    try:
        # --help exits quickly but can be slow to import; allow 30s.
        result = subprocess.run(
            [sys.executable, "-m", "sglang.launch_server", "--help"],
            capture_output=True,
            text=True,
            timeout=30,
        )

        # Argparse may emit to stdout or stderr depending on context.
        help_output = result.stdout + result.stderr

        # Check if --kt-gpu-prefill-token-threshold is in the help output
        supported = "--kt-gpu-prefill-token-threshold" in help_output

        # Save to cache if supported
        if supported:
            _save_sglang_kt_kernel_cache(True)

        return {
            "supported": supported,
            "help_output": help_output,
            "error": None,
            "from_cache": False,
        }

    except subprocess.TimeoutExpired:
        return {
            "supported": False,
            "help_output": None,
            "error": "Timeout while checking sglang.launch_server --help",
            "from_cache": False,
        }
    except FileNotFoundError:
        return {
            "supported": False,
            "help_output": None,
            "error": "Python interpreter not found",
            "from_cache": False,
        }
    except Exception as e:
        # Broad catch: any unexpected failure is reported, never raised.
        return {
            "supported": False,
            "help_output": None,
            "error": str(e),
            "from_cache": False,
        }
|
||||
|
||||
|
||||
def print_sglang_kt_kernel_instructions() -> None:
    """Print instructions for installing the kvcache-ai fork of SGLang with kt-kernel support.

    Chooses the Chinese or English text based on the current i18n language.
    """
    # Imported lazily to avoid a circular import with the i18n package.
    from kt_kernel.cli.i18n import get_lang

    lang = get_lang()

    if lang == "zh":
        instructions = """
[bold red]SGLang 不支持 kt-kernel[/bold red]

您当前安装的 SGLang 不包含 kt-kernel 支持。
kt-kernel 需要使用 kvcache-ai 维护的 SGLang 分支。

[bold]请按以下步骤重新安装 SGLang:[/bold]

[cyan]1. 卸载当前的 SGLang:[/cyan]
   pip uninstall sglang -y

[cyan]2. 克隆 kvcache-ai 的 SGLang 仓库:[/cyan]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[cyan]3. 安装 SGLang:[/cyan]
   pip install -e "python[all]"

[dim]注意: 请确保在正确的 Python 环境中执行以上命令[/dim]
"""
    else:
        instructions = """
[bold red]SGLang does not support kt-kernel[/bold red]

Your current SGLang installation does not include kt-kernel support.
kt-kernel requires the kvcache-ai maintained fork of SGLang.

[bold]Please reinstall SGLang with the following steps:[/bold]

[cyan]1. Uninstall current SGLang:[/cyan]
   pip uninstall sglang -y

[cyan]2. Clone the kvcache-ai SGLang repository:[/cyan]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[cyan]3. Install SGLang:[/cyan]
   pip install -e "python[all]"

[dim]Note: Make sure to run these commands in the correct Python environment[/dim]
"""
    console.print(instructions)
|
||||
Loading…
Add table
Add a link
Reference in a new issue