kvcache-ai-ktransformers/kt-kernel/python/cli/commands/model.py
ErvinXie d8046e1bb4
Kt minimax (#1742)
[feat]: fp8 kernel and kt-cli support
2025-12-24 15:39:44 +08:00

409 lines
12 KiB
Python

"""
Model command for kt-cli.
Manages models: download, list, and storage paths.
"""
import os
from pathlib import Path
from typing import Optional
import typer
from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
confirm,
console,
print_error,
print_info,
print_success,
print_warning,
prompt_choice,
)
# Typer sub-application for the "model" command group.  It is allowed to run
# without a subcommand (and without falling back to the help screen) so that
# the group callback can decide what to show in that case.
app = typer.Typer(
    invoke_without_command=True,
    no_args_is_help=False,
    help="Manage models and storage paths",
)
@app.callback()
def callback(ctx: typer.Context) -> None:
    """
    Model management commands.
    Run without arguments to see available models.
    """
    # A subcommand was requested: let it run, nothing to do at group level.
    if ctx.invoked_subcommand is not None:
        return

    # Bare invocation of the group: fall back to the model overview.
    show_model_list()
def show_model_list() -> None:
    """Print the model overview shown when the command group runs bare.

    The output has three parts, in this order:

    1. a table of every model returned by the registry's ``list_all()``,
       with a check mark next to those also reported by
       ``find_local_models()``;
    2. usage hints for the ``download``, ``list --local`` and ``search``
       subcommands;
    3. the storage paths from ``settings.get_model_paths()``, each marked
       according to whether the path currently exists.
    """
    from rich.table import Table

    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    settings = get_settings()

    console.print()
    console.print(f"[bold cyan]{t('model_supported_title')}[/bold cyan]\n")

    # Only membership is tested below, so the names alone are enough; the
    # local paths that come with them are not displayed in this view.
    local_names = {info.name for info, _ in registry.find_local_models()}

    # Create table
    table = Table(show_header=True, header_style="bold")
    table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
    table.add_column(t("model_column_status"), justify="center")

    for model in registry.list_all():
        if model.name in local_names:
            status = f"[green]✓ {t('model_status_local')}[/green]"
        else:
            status = "[dim]-[/dim]"
        table.add_row(model.name, status)

    console.print(table)
    console.print()

    # Usage instructions
    console.print(f"[bold]{t('model_usage_title')}:[/bold]")
    console.print(f"{t('model_usage_download')} [cyan]kt model download <model-name>[/cyan]")
    console.print(f"{t('model_usage_list_local')} [cyan]kt model list --local[/cyan]")
    console.print(f"{t('model_usage_search')} [cyan]kt model search <query>[/cyan]")
    console.print()

    # Show model storage paths
    console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]")
    for path in settings.get_model_paths():
        marker = "[green]✓[/green]" if path.exists() else "[dim]✗[/dim]"
        console.print(f" {marker} {path}")
    console.print()
@app.command(name="download")
def download(
    model: Optional[str] = typer.Argument(
        None,
        help="Model name or HuggingFace repo (e.g., deepseek-v3, Qwen/Qwen3-30B)",
    ),
    path: Optional[Path] = typer.Option(
        None,
        "--path",
        "-p",
        help="Custom download path",
    ),
    list_models: bool = typer.Option(
        False,
        "--list",
        "-l",
        help="List available models",
    ),
    resume: bool = typer.Option(
        True,
        "--resume/--no-resume",
        help="Resume incomplete downloads",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation prompts",
    ),
) -> None:
    """Download model weights from HuggingFace."""
    # Typer can surface the docstring as the command's help text, so the
    # detailed notes are kept in comments instead of the docstring.
    #
    # Flow:
    #   * ``--list`` or no ``model`` argument: print the registry table and
    #     return without downloading anything.
    #   * ``model`` containing "/" is used verbatim as the repo id; anything
    #     else is looked up with ``registry.search`` and, when several entries
    #     match, the user picks one interactively.
    #   * The destination is ``--path`` when given, otherwise a directory
    #     under ``settings.models_dir`` named after the model.
    #   * The transfer is delegated to the ``huggingface-cli`` executable.
    #
    # Exits:
    #   * ``typer.Exit(1)`` when the model is unknown, the download command
    #     fails, or ``huggingface-cli`` cannot be found.
    #   * ``typer.Abort`` when the user declines a confirmation prompt.
    import subprocess

    from kt_kernel.cli.utils.console import print_model_table, print_step
    from kt_kernel.cli.utils.model_registry import get_registry

    settings = get_settings()
    registry = get_registry()

    console.print()

    # List mode
    if list_models or model is None:
        print_step(t("download_list_title"))
        console.print()

        model_dicts = [
            {
                "name": m.name,
                "hf_repo": m.hf_repo,
                "type": m.type,
                "gpu_vram_gb": m.gpu_vram_gb,
                "cpu_ram_gb": m.cpu_ram_gb,
            }
            for m in registry.list_all()
        ]
        print_model_table(model_dicts)
        console.print()

        if model is None:
            console.print(f"[dim]{t('model_download_usage_hint')}[/dim]")
            console.print()
        return

    # Search for model
    print_step(t("download_searching", name=model))

    if "/" in model:
        # Looks like "<owner>/<repo>": treat it as a direct HuggingFace repo id
        # and use the last path component as the local model name.
        hf_repo = model
        model_name = model.split("/")[-1]
    else:
        matches = registry.search(model)
        if not matches:
            print_error(t("run_model_not_found", name=model))
            console.print()
            console.print(t("model_download_list_hint"))
            console.print(t("model_download_hf_hint"))
            raise typer.Exit(1)

        if len(matches) == 1:
            model_info = matches[0]
        else:
            # Ambiguous query: let the user choose among the candidates.
            console.print()
            print_info(t("download_multiple_found"))
            choices = [f"{m.name} ({m.hf_repo})" for m in matches]
            selected = prompt_choice(t("download_select"), choices)
            model_info = matches[choices.index(selected)]

        hf_repo = model_info.hf_repo
        model_name = model_info.name

    print_success(t("download_found", name=hf_repo))

    # Determine download path
    if path is None:
        download_path = settings.models_dir / model_name.replace(" ", "-")
    else:
        download_path = path

    console.print()
    print_info(t("download_destination", path=str(download_path)))

    # Check if already exists (a config.json in the target directory is used
    # as the marker for an earlier download).
    if download_path.exists() and (download_path / "config.json").exists():
        print_warning(t("download_already_exists", path=str(download_path)))
        if not yes and not confirm(t("download_overwrite_prompt"), default=False):
            raise typer.Abort()

    # Confirm download
    if not yes:
        console.print()
        if not confirm(t("prompt_continue")):
            raise typer.Abort()

    # Download using huggingface-cli
    console.print()
    print_step(t("download_starting"))

    cmd = [
        "huggingface-cli",
        "download",
        hf_repo,
        "--local-dir",
        str(download_path),
    ]
    if resume:
        # NOTE(review): recent huggingface_hub releases are believed to resume
        # by default and to treat this flag as deprecated; confirm against the
        # version this project pins before removing it.
        cmd.append("--resume-download")

    # A configured mirror is handed to the child process through HF_ENDPOINT.
    # `huggingface-cli download` does not appear to have an `--endpoint`
    # option, whereas huggingface_hub reads the endpoint from this variable.
    # With no mirror, env stays None and the child inherits the parent
    # environment unchanged.
    env = None
    mirror = settings.get("download.mirror", "")
    if mirror:
        env = {**os.environ, "HF_ENDPOINT": str(mirror)}

    try:
        subprocess.run(cmd, check=True, env=env)
    except subprocess.CalledProcessError as e:
        # The CLI ran but returned a non-zero exit status.
        print_error(t("model_download_failed", error=str(e)))
        raise typer.Exit(1) from e
    except FileNotFoundError:
        # The huggingface-cli executable is not on PATH.
        print_error(t("model_hf_cli_not_found"))
        raise typer.Exit(1) from None

    console.print()
    print_success(t("download_complete"))
    console.print()
    console.print(f" {t('model_saved_to', path=download_path)}")
    console.print()
    console.print(f" {t('model_start_with', name=model_name)}")
    console.print()
@app.command(name="list")
def list_models(
    local_only: bool = typer.Option(False, "--local", help="Show only locally downloaded models"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed info including paths"),
) -> None:
    """List available models."""
    # With --local only the models found on disk are shown; otherwise every
    # registered model is listed together with a "local" status marker.
    # --verbose appends a column holding the local path.
    from rich.table import Table
    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    console.print()

    if local_only:
        # Show only local models
        found = registry.find_local_models()
        if not found:
            print_warning(t("model_no_local_models"))
            console.print()
            console.print(f" {t('model_download_hint')} [cyan]kt model download <model-name>[/cyan]")
            console.print()
            return

        table = Table(title=t("model_local_models_title"), show_header=True, header_style="bold")
        table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
        if verbose:
            table.add_column(t("model_column_local_path"), style="dim")

        for info, location in found:
            cells = [info.name]
            if verbose:
                cells.append(str(location))
            table.add_row(*cells)

        console.print(table)
    else:
        # Show all registered models
        registered = registry.list_all()
        paths_by_name = {info.name: location for info, location in registry.find_local_models()}

        table = Table(title=t("model_available_models_title"), show_header=True, header_style="bold")
        table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
        table.add_column(t("model_column_status"), justify="center")
        if verbose:
            table.add_column(t("model_column_local_path"), style="dim")

        for entry in registered:
            if entry.name in paths_by_name:
                status = f"[green]✓ {t('model_status_local')}[/green]"
                location_text = str(paths_by_name[entry.name])
            else:
                status = "[dim]-[/dim]"
                location_text = f"[dim]{t('model_status_not_downloaded')}[/dim]"

            row = (entry.name, status, location_text) if verbose else (entry.name, status)
            table.add_row(*row)

        console.print(table)

    console.print()
@app.command(name="path-list")
def path_list() -> None:
    """List all configured model storage paths."""
    storage_paths = get_settings().get_model_paths()

    console.print()
    console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]\n")

    # Entries are numbered from 1; the marker reflects whether the path
    # exists on disk right now.
    for position, location in enumerate(storage_paths, start=1):
        if location.exists():
            marker = "[green]✓[/green]"
        else:
            marker = "[red]✗[/red]"
        console.print(f" {marker} [{position}] {location}")

    console.print()
@app.command(name="path-add")
def path_add(
    path: str = typer.Argument(..., help="Path to add"),
) -> None:
    """Add a new model storage path."""
    # "~" is expanded before anything else so that messages and the stored
    # configuration value all use the expanded form.
    expanded = os.path.expanduser(path)
    target = Path(expanded)

    if not target.exists():
        console.print(f"[yellow]{t('model_path_not_exist', path=expanded)}[/yellow]")

        # Declining to create the directory cancels the whole command.
        if not confirm(t("model_create_directory", path=expanded), default=True):
            raise typer.Abort()

        try:
            target.mkdir(parents=True, exist_ok=True)
            console.print(f"[green]✓[/green] {t('model_created_directory', path=expanded)}")
        except OSError as exc:  # PermissionError is a subclass of OSError
            print_error(t("model_create_dir_failed", error=str(exc)))
            raise typer.Exit(1)

    # Add to configuration
    get_settings().add_model_path(expanded)
    print_success(t("model_path_added", path=expanded))
@app.command(name="path-remove")
def path_remove(
    path: str = typer.Argument(..., help="Path to remove"),
) -> None:
    """Remove a model storage path from configuration."""
    # Expand "~" before handing the path to the settings object.
    expanded = os.path.expanduser(path)

    removed = get_settings().remove_model_path(expanded)
    if not removed:
        # A falsy result from the settings object is reported as
        # "path not found" and exits with status 1.
        print_error(t("model_path_not_found", path=expanded))
        raise typer.Exit(1)

    print_success(t("model_path_removed", path=expanded))
@app.command(name="search")
def search(
    query: str = typer.Argument(..., help="Search query (model name or keyword)"),
) -> None:
    """Search for models in the registry."""
    from rich.table import Table
    from kt_kernel.cli.utils.model_registry import get_registry

    hits = get_registry().search(query)

    console.print()

    if not hits:
        print_warning(t("model_search_no_results", query=query))
        console.print()
        return

    results = Table(title=t("model_search_results_title", query=query), show_header=True)
    results.add_column(t("model_column_name"), style="cyan")
    results.add_column(t("model_column_hf_repo"), style="dim")
    results.add_column(t("model_column_aliases"), style="yellow")

    # At most this many aliases are spelled out per model; the remainder is
    # summarised as a "+N more" suffix.
    shown_limit = 3
    for entry in hits:
        alias_text = ", ".join(entry.aliases[:shown_limit])
        hidden = len(entry.aliases) - shown_limit
        if hidden > 0:
            alias_text = f"{alias_text} +{hidden} more"
        results.add_row(entry.name, entry.hf_repo, alias_text)

    console.print(results)
    console.print()