mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-28 11:49:51 +00:00
kt-cli enhancement (#1834)
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
Some checks failed
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Book-CI / test-2 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
* [feat]: redesign kt run interactive configuration with i18n support - Redesign kt run with 8-step interactive flow (model selection, inference method, NUMA/CPU, GPU experts, KV cache, GPU/TP selection, parsers, host/port) - Add configuration save/load system (~/.ktransformers/run_configs.yaml) - Add i18n support for kt chat (en/zh translations) - Add universal input validators with auto-retry and Chinese comma support - Add port availability checker with auto-suggestion - Add parser configuration (--tool-call-parser, --reasoning-parser) - Remove tuna command and clean up redundant files - Fix: variable reference bug in run.py, filter to show only MoE models * [feat]: unify model selection UI and enable shared experts fusion by default - Unify kt run model selection table with kt model list display * Add Total size, MoE Size, Repo, and SHA256 status columns * Use consistent formatting and styling * Improve user decision-making with more information - Enable --disable-shared-experts-fusion by default * Change default value from False to True * Users can still override with --enable-shared-experts-fusion * [feat]: improve kt chat with performance metrics and better CJK support - Add performance metrics display after each response * Total time, TTFT (Time To First Token), TPOT (Time Per Output Token) * Accurate input/output token counts using model tokenizer * Fallback to estimation if tokenizer unavailable * Metrics shown in dim style (not prominent) - Fix Chinese character input issues * Replace Prompt.ask() with console.input() for better CJK support * Fixes backspace deletion showing half-characters - Suppress NumPy subnormal warnings * Filter "The value of the smallest subnormal" warnings * Cleaner CLI output on certain hardware environments * [fix]: correct TTFT measurement in kt chat - Move start_time initialization before API call - Previously start_time was set when receiving first chunk, causing TTFT ≈ 0ms - Now correctly measures time from request sent to first token received * [docs]: 添加 Clawdbot 集成指南 - KTransformers 企业级 AI 助手部署方案 * [docs]: 强调推荐使用 Kimi K2.5 作为核心模型,突出企业级推理能力 * [docs]: 添加 Clawdbot 飞书接入教程链接 * [feat]: improve CLI table display, model verification, and chat experience - Add sequence number (#) column to all model tables by default - Filter kt edit to show only MoE GPU models (exclude AMX) - Extend kt model verify to check *.json and *.py files in addition to weights - Fix re-verification bug where repaired files caused false failures - Suppress tokenizer debug output in kt chat token counting * [fix]: fix cpu cores. --------- Co-authored-by: skqliao <skqliao@gmail.com>
This commit is contained in:
parent
4f64665758
commit
56cbd69ac4
23 changed files with 10327 additions and 781 deletions
|
|
@ -96,9 +96,9 @@ def chat(
|
|||
kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters
|
||||
"""
|
||||
if not HAS_OPENAI:
|
||||
print_error("OpenAI Python SDK is required for chat functionality.")
|
||||
print_error(t("chat_openai_required"))
|
||||
console.print()
|
||||
console.print("Install it with:")
|
||||
console.print(t("chat_install_hint"))
|
||||
console.print(" pip install openai")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
|
@ -114,10 +114,10 @@ def chat(
|
|||
console.print()
|
||||
console.print(
|
||||
Panel.fit(
|
||||
f"[bold cyan]KTransformers Chat[/bold cyan]\n\n"
|
||||
f"Server: [yellow]{final_host}:{final_port}[/yellow]\n"
|
||||
f"Temperature: [cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n"
|
||||
f"[dim]Type '/help' for commands, '/quit' to exit[/dim]",
|
||||
f"[bold cyan]{t('chat_title')}[/bold cyan]\n\n"
|
||||
f"{t('chat_server')}: [yellow]{final_host}:{final_port}[/yellow]\n"
|
||||
f"{t('chat_temperature')}: [cyan]{temperature}[/cyan] | {t('chat_max_tokens')}: [cyan]{max_tokens}[/cyan]\n\n"
|
||||
f"[dim]{t('chat_help_hint')}[/dim]",
|
||||
border_style="cyan",
|
||||
)
|
||||
)
|
||||
|
|
@ -152,31 +152,44 @@ def chat(
|
|||
)
|
||||
|
||||
# Test connection
|
||||
print_info("Connecting to server...")
|
||||
print_info(t("chat_connecting"))
|
||||
models = client.models.list()
|
||||
available_models = [m.id for m in models.data]
|
||||
|
||||
if not available_models:
|
||||
print_error("No models available on server")
|
||||
print_error(t("chat_no_models"))
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Select model
|
||||
if model:
|
||||
if model not in available_models:
|
||||
print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}")
|
||||
print_warning(t("chat_model_not_found", model=model, available=", ".join(available_models)))
|
||||
selected_model = available_models[0]
|
||||
else:
|
||||
selected_model = model
|
||||
else:
|
||||
selected_model = available_models[0]
|
||||
|
||||
print_success(f"Connected to model: {selected_model}")
|
||||
print_success(t("chat_connected", model=selected_model))
|
||||
console.print()
|
||||
|
||||
# Load tokenizer for accurate token counting
|
||||
tokenizer = None
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# selected_model is the model path
|
||||
tokenizer = AutoTokenizer.from_pretrained(selected_model, trust_remote_code=True)
|
||||
console.print(f"[dim]Loaded tokenizer from {selected_model}[/dim]")
|
||||
console.print()
|
||||
except Exception as e:
|
||||
console.print(f"[dim yellow]Warning: Could not load tokenizer, token counts will be estimated[/dim]")
|
||||
console.print()
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Failed to connect to server: {e}")
|
||||
print_error(t("chat_connect_failed", error=str(e)))
|
||||
console.print()
|
||||
console.print("Make sure the model server is running:")
|
||||
console.print(t("chat_server_not_running"))
|
||||
console.print(" kt run <model>")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
|
@ -201,12 +214,12 @@ def chat(
|
|||
# Main chat loop
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
# Get user input - use console.input() for better CJK character support
|
||||
try:
|
||||
user_input = Prompt.ask("[bold green]You[/bold green]")
|
||||
user_input = console.input(f"[bold green]{t('chat_user_prompt')}[/bold green]: ")
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
console.print()
|
||||
print_info("Goodbye!")
|
||||
print_info(t("chat_goodbye"))
|
||||
break
|
||||
|
||||
if not user_input.strip():
|
||||
|
|
@ -224,15 +237,19 @@ def chat(
|
|||
|
||||
# Generate response
|
||||
console.print()
|
||||
console.print("[bold cyan]Assistant[/bold cyan]")
|
||||
console.print(f"[bold cyan]{t('chat_assistant_prompt')}[/bold cyan]")
|
||||
|
||||
try:
|
||||
if stream:
|
||||
# Streaming response
|
||||
response_content = _stream_response(client, selected_model, messages, temperature, max_tokens)
|
||||
response_content = _stream_response(
|
||||
client, selected_model, messages, temperature, max_tokens, tokenizer
|
||||
)
|
||||
else:
|
||||
# Non-streaming response
|
||||
response_content = _generate_response(client, selected_model, messages, temperature, max_tokens)
|
||||
response_content = _generate_response(
|
||||
client, selected_model, messages, temperature, max_tokens, tokenizer
|
||||
)
|
||||
|
||||
# Add assistant response to history
|
||||
messages.append({"role": "assistant", "content": response_content})
|
||||
|
|
@ -240,7 +257,7 @@ def chat(
|
|||
console.print()
|
||||
|
||||
except Exception as e:
|
||||
print_error(f"Error generating response: {e}")
|
||||
print_error(t("chat_generation_error", error=str(e)))
|
||||
# Remove the user message that caused the error
|
||||
messages.pop()
|
||||
continue
|
||||
|
|
@ -252,12 +269,12 @@ def chat(
|
|||
except KeyboardInterrupt:
|
||||
console.print()
|
||||
console.print()
|
||||
print_info("Chat interrupted. Goodbye!")
|
||||
print_info(t("chat_interrupted"))
|
||||
|
||||
# Final history save
|
||||
if save_history and messages:
|
||||
_save_history(history_file, messages, selected_model)
|
||||
console.print(f"[dim]History saved to: {history_file}[/dim]")
|
||||
console.print(f"[dim]{t('chat_history_saved', path=str(history_file))}[/dim]")
|
||||
console.print()
|
||||
|
||||
|
||||
|
|
@ -267,12 +284,22 @@ def _stream_response(
|
|||
messages: list,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tokenizer=None,
|
||||
) -> str:
|
||||
"""Generate streaming response and display in real-time."""
|
||||
import time
|
||||
|
||||
response_content = ""
|
||||
reasoning_content = ""
|
||||
|
||||
# Performance tracking
|
||||
first_token_time = None
|
||||
chunk_count = 0
|
||||
|
||||
try:
|
||||
# Start timing before sending request
|
||||
start_time = time.time()
|
||||
|
||||
stream = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
@ -282,33 +309,120 @@ def _stream_response(
|
|||
)
|
||||
|
||||
for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
reasoning_delta = getattr(delta, "reasoning_content", None)
|
||||
if reasoning_delta:
|
||||
reasoning_content += reasoning_delta
|
||||
console.print(reasoning_delta, end="", style="dim")
|
||||
if delta.content:
|
||||
content = delta.content
|
||||
response_content += content
|
||||
console.print(content, end="")
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if delta:
|
||||
reasoning_delta = getattr(delta, "reasoning_content", None)
|
||||
if reasoning_delta:
|
||||
if first_token_time is None:
|
||||
first_token_time = time.time()
|
||||
reasoning_content += reasoning_delta
|
||||
console.print(reasoning_delta, end="", style="dim")
|
||||
chunk_count += 1
|
||||
|
||||
if delta.content:
|
||||
if first_token_time is None:
|
||||
first_token_time = time.time()
|
||||
content = delta.content
|
||||
response_content += content
|
||||
console.print(content, end="")
|
||||
chunk_count += 1
|
||||
|
||||
console.print() # Newline after streaming
|
||||
|
||||
# Display performance metrics
|
||||
end_time = time.time()
|
||||
if first_token_time and chunk_count > 0:
|
||||
ttft = first_token_time - start_time
|
||||
total_time = end_time - start_time
|
||||
|
||||
# Calculate TPOT based on chunks
|
||||
if chunk_count > 1:
|
||||
generation_time = total_time - ttft
|
||||
tpot = generation_time / (chunk_count - 1)
|
||||
else:
|
||||
tpot = 0
|
||||
|
||||
# Calculate accurate token counts using tokenizer
|
||||
if tokenizer:
|
||||
input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)
|
||||
output_tokens = _count_tokens_with_tokenizer(
|
||||
[{"role": "assistant", "content": response_content}], tokenizer
|
||||
)
|
||||
token_prefix = ""
|
||||
else:
|
||||
# Fallback to estimation
|
||||
input_tokens = _estimate_tokens(messages)
|
||||
output_tokens = _estimate_tokens([{"role": "assistant", "content": response_content}])
|
||||
token_prefix = "~"
|
||||
|
||||
# Build metrics display
|
||||
metrics = f"[dim]Total: {total_time*1000:.0f}ms | TTFT: {ttft*1000:.0f}ms"
|
||||
if tpot > 0:
|
||||
metrics += f" | TPOT: {tpot*1000:.1f}ms"
|
||||
metrics += f" | In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}"
|
||||
metrics += "[/dim]"
|
||||
|
||||
console.print(metrics)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Streaming error: {e}")
|
||||
|
||||
return response_content
|
||||
|
||||
|
||||
def _count_tokens_with_tokenizer(messages: list, tokenizer) -> int:
|
||||
"""Count tokens accurately using the model's tokenizer."""
|
||||
try:
|
||||
# Concatenate all message content
|
||||
text = ""
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
# Simple format: role + content
|
||||
text += f"{role}: {content}\n"
|
||||
|
||||
# Encode and count tokens - suppress any debug output from custom tokenizers
|
||||
import os
|
||||
import sys
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
|
||||
with open(os.devnull, "w") as devnull:
|
||||
with redirect_stdout(devnull), redirect_stderr(devnull):
|
||||
tokens = tokenizer.encode(text, add_special_tokens=True)
|
||||
return len(tokens)
|
||||
except Exception:
|
||||
# Fallback to estimation if tokenizer fails
|
||||
return _estimate_tokens(messages)
|
||||
|
||||
|
||||
def _estimate_tokens(messages: list) -> int:
|
||||
"""Estimate token count for messages (rough approximation)."""
|
||||
total_chars = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
total_chars += len(content)
|
||||
|
||||
# Rough estimation:
|
||||
# - English: ~4 chars per token
|
||||
# - Chinese: ~1.5 chars per token
|
||||
# Use 2.5 as average
|
||||
return max(1, int(total_chars / 2.5))
|
||||
|
||||
|
||||
def _generate_response(
|
||||
client: "OpenAI",
|
||||
model: str,
|
||||
messages: list,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tokenizer=None,
|
||||
) -> str:
|
||||
"""Generate non-streaming response."""
|
||||
import time
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
@ -317,12 +431,36 @@ def _generate_response(
|
|||
stream=False,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# Display as markdown
|
||||
md = Markdown(content)
|
||||
console.print(md)
|
||||
|
||||
# Calculate accurate token counts using tokenizer
|
||||
if tokenizer:
|
||||
input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)
|
||||
output_tokens = _count_tokens_with_tokenizer([{"role": "assistant", "content": content}], tokenizer)
|
||||
token_prefix = ""
|
||||
else:
|
||||
# Fallback to API usage or estimation
|
||||
input_tokens = response.usage.prompt_tokens if response.usage else _estimate_tokens(messages)
|
||||
output_tokens = (
|
||||
response.usage.completion_tokens
|
||||
if response.usage
|
||||
else _estimate_tokens([{"role": "assistant", "content": content}])
|
||||
)
|
||||
token_prefix = "" if response.usage else "~"
|
||||
|
||||
# Display performance metrics
|
||||
console.print(
|
||||
f"[dim]Time: {total_time*1000:.0f}ms | "
|
||||
f"In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}[/dim]"
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -335,20 +473,14 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
|
|||
|
||||
if cmd in ["/quit", "/exit", "/q"]:
|
||||
console.print()
|
||||
print_info("Goodbye!")
|
||||
print_info(t("chat_goodbye"))
|
||||
return False
|
||||
|
||||
elif cmd in ["/help", "/h"]:
|
||||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold]Available Commands:[/bold]\n\n"
|
||||
"/help, /h - Show this help message\n"
|
||||
"/quit, /exit, /q - Exit chat\n"
|
||||
"/clear, /c - Clear conversation history\n"
|
||||
"/history, /hist - Show conversation history\n"
|
||||
"/info, /i - Show current settings\n"
|
||||
"/retry, /r - Regenerate last response",
|
||||
f"[bold]{t('chat_help_title')}[/bold]\n\n{t('chat_help_content')}",
|
||||
title="Help",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
|
@ -359,19 +491,19 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
|
|||
elif cmd in ["/clear", "/c"]:
|
||||
messages.clear()
|
||||
console.print()
|
||||
print_success("Conversation history cleared")
|
||||
print_success(t("chat_history_cleared"))
|
||||
console.print()
|
||||
return True
|
||||
|
||||
elif cmd in ["/history", "/hist"]:
|
||||
console.print()
|
||||
if not messages:
|
||||
print_info("No conversation history")
|
||||
print_info(t("chat_no_history"))
|
||||
else:
|
||||
console.print(
|
||||
Panel(
|
||||
_format_history(messages),
|
||||
title=f"History ({len(messages)} messages)",
|
||||
title=t("chat_history_title", count=len(messages)),
|
||||
border_style="cyan",
|
||||
)
|
||||
)
|
||||
|
|
@ -382,10 +514,7 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
|
|||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold]Current Settings:[/bold]\n\n"
|
||||
f"Temperature: [cyan]{temperature}[/cyan]\n"
|
||||
f"Max tokens: [cyan]{max_tokens}[/cyan]\n"
|
||||
f"Messages: [cyan]{len(messages)}[/cyan]",
|
||||
f"[bold]{t('chat_info_title')}[/bold]\n\n{t('chat_info_content', temperature=temperature, max_tokens=max_tokens, messages=len(messages))}",
|
||||
title="Info",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
|
@ -397,16 +526,16 @@ def _handle_command(command: str, messages: list, temperature: float, max_tokens
|
|||
if len(messages) >= 2 and messages[-1]["role"] == "assistant":
|
||||
# Remove last assistant response
|
||||
messages.pop()
|
||||
print_info("Retrying last response...")
|
||||
print_info(t("chat_retrying"))
|
||||
console.print()
|
||||
else:
|
||||
print_warning("No previous response to retry")
|
||||
print_warning(t("chat_no_retry"))
|
||||
console.print()
|
||||
return True
|
||||
|
||||
else:
|
||||
print_warning(f"Unknown command: {command}")
|
||||
console.print("[dim]Type /help for available commands[/dim]")
|
||||
print_warning(t("chat_unknown_command", command=command))
|
||||
console.print(f"[dim]{t('chat_unknown_hint')}[/dim]")
|
||||
console.print()
|
||||
return True
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -35,12 +35,12 @@ class QuantMethod(str, Enum):
|
|||
|
||||
|
||||
def quant(
|
||||
model: str = typer.Argument(
|
||||
...,
|
||||
model: Optional[str] = typer.Argument(
|
||||
None,
|
||||
help="Model name or path to quantize",
|
||||
),
|
||||
method: QuantMethod = typer.Option(
|
||||
QuantMethod.INT4,
|
||||
method: Optional[QuantMethod] = typer.Option(
|
||||
None,
|
||||
"--method",
|
||||
"-m",
|
||||
help="Quantization method",
|
||||
|
|
@ -51,8 +51,8 @@ def quant(
|
|||
"-o",
|
||||
help="Output path for quantized weights",
|
||||
),
|
||||
input_type: str = typer.Option(
|
||||
"fp8",
|
||||
input_type: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--input-type",
|
||||
"-i",
|
||||
help="Input weight type (fp8, fp16, bf16)",
|
||||
|
|
@ -72,6 +72,11 @@ def quant(
|
|||
"--no-merge",
|
||||
help="Don't merge safetensor files",
|
||||
),
|
||||
gpu: bool = typer.Option(
|
||||
False,
|
||||
"--gpu",
|
||||
help="Use GPU for conversion (faster)",
|
||||
),
|
||||
yes: bool = typer.Option(
|
||||
False,
|
||||
"--yes",
|
||||
|
|
@ -79,54 +84,231 @@ def quant(
|
|||
help="Skip confirmation prompts",
|
||||
),
|
||||
) -> None:
|
||||
"""Quantize model weights for CPU inference."""
|
||||
settings = get_settings()
|
||||
console.print()
|
||||
"""Quantize model weights for CPU inference.
|
||||
|
||||
# Resolve input path
|
||||
input_path = _resolve_input_path(model, settings)
|
||||
if input_path is None:
|
||||
print_error(t("quant_input_not_found", path=model))
|
||||
If no model is specified, interactive mode will be activated.
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
# Check if we should use interactive mode
|
||||
# Interactive mode triggers when: no model, or missing critical parameters
|
||||
needs_interactive = model is None or method is None or cpu_threads is None or numa_nodes is None
|
||||
is_interactive = False
|
||||
|
||||
if needs_interactive and sys.stdin.isatty():
|
||||
# Use interactive configuration (includes verification in Step 1.5)
|
||||
from kt_kernel.cli.utils.quant_interactive import interactive_quant_config
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold cyan]═══ {t('quant_interactive_title')} ═══[/bold cyan]")
|
||||
console.print()
|
||||
console.print(f"[yellow]{t('quant_new_model_notice')}[/yellow]")
|
||||
console.print()
|
||||
|
||||
config = interactive_quant_config()
|
||||
if config is None:
|
||||
# User cancelled
|
||||
raise typer.Exit(0)
|
||||
|
||||
# Extract configuration
|
||||
model_obj = config["model"]
|
||||
model = model_obj.id
|
||||
input_path = Path(model_obj.path)
|
||||
method = QuantMethod(config["method"])
|
||||
input_type = config["input_type"]
|
||||
cpu_threads = config["cpu_threads"]
|
||||
numa_nodes = config["numa_nodes"]
|
||||
output = config["output_path"]
|
||||
gpu = config["use_gpu"]
|
||||
is_interactive = True
|
||||
|
||||
console.print()
|
||||
print_success(t("quant_config_complete"))
|
||||
console.print()
|
||||
else:
|
||||
# Non-interactive mode - require model parameter
|
||||
if model is None:
|
||||
print_error("Model argument is required in non-interactive mode")
|
||||
console.print()
|
||||
console.print("Usage: kt quant <model>")
|
||||
console.print(" Or: kt quant (for interactive mode)")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Set defaults for optional parameters
|
||||
method = method or QuantMethod.INT4
|
||||
input_type = input_type or "fp8"
|
||||
|
||||
console.print()
|
||||
|
||||
# Resolve input path
|
||||
input_path = _resolve_input_path(model, settings)
|
||||
if input_path is None:
|
||||
print_error(t("quant_input_not_found", path=model))
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Pre-quantization verification (only in non-interactive mode)
|
||||
# Interactive mode already did verification in interactive_quant_config()
|
||||
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
|
||||
from kt_kernel.cli.utils.model_verifier import pre_operation_verification
|
||||
|
||||
user_registry = UserModelRegistry()
|
||||
user_model_obj = user_registry.find_by_path(str(input_path))
|
||||
|
||||
if user_model_obj and user_model_obj.format == "safetensors":
|
||||
pre_operation_verification(user_model_obj, user_registry, operation_name="quantizing")
|
||||
|
||||
# Get user model info for both modes (needed later for registering quantized model)
|
||||
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
|
||||
|
||||
user_registry = UserModelRegistry()
|
||||
user_model_obj = user_registry.find_by_path(str(input_path))
|
||||
|
||||
# Validate that it's a MoE model (not AMX or GGUF)
|
||||
from kt_kernel.cli.commands.model import is_amx_weights
|
||||
|
||||
# Check if it's AMX (already quantized)
|
||||
is_amx, _ = is_amx_weights(str(input_path))
|
||||
if is_amx:
|
||||
print_error("Cannot quantize AMX models (already quantized)")
|
||||
console.print()
|
||||
console.print(f" The model at {input_path} is already in AMX format.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
print_info(t("quant_input_path", path=str(input_path)))
|
||||
# Check if it's a MoE model
|
||||
from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model
|
||||
|
||||
# Resolve output path
|
||||
if output is None:
|
||||
output = input_path.parent / f"{input_path.name}-{method.value.upper()}"
|
||||
|
||||
print_info(t("quant_output_path", path=str(output)))
|
||||
print_info(t("quant_method", method=method.value.upper()))
|
||||
|
||||
# Detect CPU configuration
|
||||
cpu = detect_cpu_info()
|
||||
final_cpu_threads = cpu_threads or cpu.cores
|
||||
final_numa_nodes = numa_nodes or cpu.numa_nodes
|
||||
|
||||
print_info(f"CPU threads: {final_cpu_threads}")
|
||||
print_info(f"NUMA nodes: {final_numa_nodes}")
|
||||
|
||||
# Check if output exists
|
||||
if output.exists():
|
||||
print_warning(f"Output path already exists: {output}")
|
||||
moe_result = None # Store for later use when registering quantized model
|
||||
try:
|
||||
moe_result = analyze_moe_model(str(input_path), use_cache=True)
|
||||
if not moe_result or not moe_result.get("is_moe"):
|
||||
print_error("Only MoE models can be quantized to AMX format")
|
||||
console.print()
|
||||
console.print(f" The model at {input_path} is not a MoE model.")
|
||||
console.print(" AMX quantization is designed for MoE models (e.g., DeepSeek-V3).")
|
||||
raise typer.Exit(1)
|
||||
except Exception as e:
|
||||
print_warning(f"Could not detect MoE information: {e}")
|
||||
console.print()
|
||||
if not yes:
|
||||
if not confirm("Overwrite?", default=False):
|
||||
if not confirm("Continue quantization anyway?", default=False):
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Detect CPU configuration and resolve output path (only needed in non-interactive mode)
|
||||
if not is_interactive:
|
||||
print_info(t("quant_input_path", path=str(input_path)))
|
||||
|
||||
# Detect CPU configuration (needed for output path)
|
||||
cpu = detect_cpu_info()
|
||||
final_cpu_threads = cpu_threads or cpu.cores
|
||||
final_numa_nodes = numa_nodes or cpu.numa_nodes
|
||||
|
||||
# Resolve output path
|
||||
if output is None:
|
||||
# Priority: paths.weights > paths.models[0] > model's parent directory
|
||||
weights_dir = settings.weights_dir
|
||||
|
||||
if weights_dir and weights_dir.exists():
|
||||
# Use configured weights directory (highest priority)
|
||||
output = weights_dir / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
|
||||
else:
|
||||
# Use first model storage path
|
||||
model_paths = settings.get_model_paths()
|
||||
if model_paths and model_paths[0].exists():
|
||||
output = model_paths[0] / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
|
||||
else:
|
||||
# Fallback to model's parent directory
|
||||
output = input_path.parent / f"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}"
|
||||
|
||||
print_info(t("quant_output_path", path=str(output)))
|
||||
print_info(t("quant_method", method=method.value.upper()))
|
||||
print_info(t("quant_cpu_threads", threads=final_cpu_threads))
|
||||
print_info(t("quant_numa_nodes", nodes=final_numa_nodes))
|
||||
|
||||
# Calculate space requirements
|
||||
console.print()
|
||||
console.print(f"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]")
|
||||
console.print()
|
||||
|
||||
# Calculate source model size
|
||||
try:
|
||||
total_bytes = sum(f.stat().st_size for f in input_path.glob("*.safetensors") if f.is_file())
|
||||
source_size_gb = total_bytes / (1024**3)
|
||||
except Exception:
|
||||
source_size_gb = 0.0
|
||||
|
||||
# Estimate quantized size
|
||||
input_bits = {"fp8": 8, "fp16": 16, "bf16": 16}
|
||||
quant_bits = {"int4": 4, "int8": 8}
|
||||
input_bit = input_bits.get(input_type, 16)
|
||||
quant_bit = quant_bits.get(method.value, 4)
|
||||
ratio = quant_bit / input_bit
|
||||
estimated_size_gb = source_size_gb * ratio
|
||||
|
||||
# Check available space
|
||||
import shutil
|
||||
|
||||
try:
|
||||
check_path = output.parent if not output.exists() else output
|
||||
while not check_path.exists() and check_path != check_path.parent:
|
||||
check_path = check_path.parent
|
||||
stat = shutil.disk_usage(check_path)
|
||||
available_gb = stat.free / (1024**3)
|
||||
except Exception:
|
||||
available_gb = 0.0
|
||||
|
||||
is_sufficient = available_gb >= (estimated_size_gb * 1.2)
|
||||
|
||||
console.print(f" {t('quant_source_size'):<26} {source_size_gb:.2f} GB")
|
||||
console.print(f" {t('quant_estimated_size'):<26} {estimated_size_gb:.2f} GB")
|
||||
console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB")
|
||||
console.print()
|
||||
|
||||
if not is_sufficient:
|
||||
required_with_buffer = estimated_size_gb * 1.2
|
||||
print_warning(t("quant_insufficient_space"))
|
||||
console.print()
|
||||
console.print(f" {t('quant_required_space'):<26} {required_with_buffer:.2f} GB")
|
||||
console.print(f" {t('quant_available_space'):<26} {available_gb:.2f} GB")
|
||||
console.print(f" {t('quant_shortage'):<26} {required_with_buffer - available_gb:.2f} GB")
|
||||
console.print()
|
||||
console.print(f" {t('quant_may_fail')}")
|
||||
console.print()
|
||||
|
||||
if not yes:
|
||||
if not confirm(t("quant_continue_anyway"), default=False):
|
||||
raise typer.Abort()
|
||||
console.print()
|
||||
|
||||
# Check if output exists and generate unique name
|
||||
if output.exists():
|
||||
print_warning(t("quant_output_exists", path=str(output)))
|
||||
console.print()
|
||||
|
||||
# Generate unique name by adding suffix
|
||||
original_name = output.name
|
||||
parent_dir = output.parent
|
||||
counter = 2
|
||||
|
||||
while output.exists():
|
||||
new_name = f"{original_name}-{counter}"
|
||||
output = parent_dir / new_name
|
||||
counter += 1
|
||||
|
||||
print_success(t("quant_using_unique", path=str(output)))
|
||||
console.print()
|
||||
|
||||
# Confirm (only show if not using --yes flag)
|
||||
if not yes:
|
||||
console.print()
|
||||
print_warning(t("quant_time_warning"))
|
||||
console.print()
|
||||
|
||||
if not confirm(t("prompt_continue")):
|
||||
raise typer.Abort()
|
||||
|
||||
# Confirm
|
||||
if not yes:
|
||||
console.print()
|
||||
console.print("[bold]Quantization Settings:[/bold]")
|
||||
console.print(f" Input: {input_path}")
|
||||
console.print(f" Output: {output}")
|
||||
console.print(f" Method: {method.value.upper()}")
|
||||
console.print(f" Input type: {input_type}")
|
||||
console.print()
|
||||
print_warning("Quantization may take 30-60 minutes depending on model size.")
|
||||
console.print()
|
||||
|
||||
if not confirm(t("prompt_continue")):
|
||||
raise typer.Abort()
|
||||
else:
|
||||
# Interactive mode: cpu_threads and numa_nodes already set
|
||||
final_cpu_threads = cpu_threads
|
||||
final_numa_nodes = numa_nodes
|
||||
|
||||
# Find conversion script
|
||||
kt_kernel_path = _find_kt_kernel_path()
|
||||
|
|
@ -141,37 +323,145 @@ def quant(
|
|||
|
||||
# Build command
|
||||
cmd = [
|
||||
sys.executable, str(script_path),
|
||||
"--input-path", str(input_path),
|
||||
"--input-type", input_type,
|
||||
"--output", str(output),
|
||||
"--quant-method", method.value,
|
||||
"--cpuinfer-threads", str(final_cpu_threads),
|
||||
"--threadpool-count", str(final_numa_nodes),
|
||||
sys.executable,
|
||||
str(script_path),
|
||||
"--input-path",
|
||||
str(input_path),
|
||||
"--input-type",
|
||||
input_type,
|
||||
"--output",
|
||||
str(output),
|
||||
"--quant-method",
|
||||
method.value,
|
||||
"--cpuinfer-threads",
|
||||
str(final_cpu_threads),
|
||||
"--threadpool-count",
|
||||
str(final_numa_nodes),
|
||||
]
|
||||
|
||||
if no_merge:
|
||||
cmd.append("--no-merge-safetensor")
|
||||
|
||||
if gpu:
|
||||
cmd.append("--gpu")
|
||||
|
||||
# Run quantization
|
||||
console.print()
|
||||
print_step(t("quant_starting"))
|
||||
console.print()
|
||||
console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
|
||||
console.print()
|
||||
console.print("[dim]" + "=" * 80 + "[/dim]")
|
||||
console.print()
|
||||
|
||||
try:
|
||||
process = subprocess.run(cmd)
|
||||
# Run with real-time stdout/stderr output
|
||||
import os
|
||||
import time
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONUNBUFFERED"] = "1" # Disable Python output buffering
|
||||
|
||||
# Record start time
|
||||
start_time = time.time()
|
||||
|
||||
process = subprocess.run(
|
||||
cmd,
|
||||
stdout=None, # Inherit parent's stdout (real-time output)
|
||||
stderr=None, # Inherit parent's stderr (real-time output)
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Calculate elapsed time
|
||||
elapsed_time = time.time() - start_time
|
||||
hours = int(elapsed_time // 3600)
|
||||
minutes = int((elapsed_time % 3600) // 60)
|
||||
seconds = int(elapsed_time % 60)
|
||||
|
||||
console.print()
|
||||
console.print("[dim]" + "=" * 80 + "[/dim]")
|
||||
console.print()
|
||||
|
||||
if process.returncode == 0:
|
||||
console.print()
|
||||
print_success(t("quant_complete"))
|
||||
console.print()
|
||||
|
||||
# Display elapsed time
|
||||
if hours > 0:
|
||||
time_str = f"{hours}h {minutes}m {seconds}s"
|
||||
elif minutes > 0:
|
||||
time_str = f"{minutes}m {seconds}s"
|
||||
else:
|
||||
time_str = f"{seconds}s"
|
||||
console.print(f" [cyan]{t('quant_time_elapsed')} {time_str}[/cyan]")
|
||||
console.print()
|
||||
console.print(f" Quantized weights saved to: {output}")
|
||||
console.print()
|
||||
console.print(" Use with:")
|
||||
console.print(f" kt run {model} --weights-path {output}")
|
||||
console.print()
|
||||
|
||||
# Auto-register the quantized model
|
||||
try:
|
||||
from kt_kernel.cli.utils.user_model_registry import UserModel
|
||||
|
||||
# Generate model name from output path
|
||||
base_name = output.name
|
||||
suggested_name = user_registry.suggest_name(base_name)
|
||||
|
||||
# Determine MoE information and source model name
|
||||
if user_model_obj:
|
||||
is_moe_val = user_model_obj.is_moe
|
||||
num_experts = user_model_obj.moe_num_experts
|
||||
num_active = user_model_obj.moe_num_experts_per_tok
|
||||
repo_type_val = user_model_obj.repo_type
|
||||
repo_id_val = user_model_obj.repo_id
|
||||
source_model_name = user_model_obj.name # Store source model name
|
||||
elif moe_result:
|
||||
is_moe_val = moe_result.get("is_moe", True)
|
||||
num_experts = moe_result.get("num_experts")
|
||||
num_active = moe_result.get("num_experts_per_tok")
|
||||
repo_type_val = None
|
||||
repo_id_val = None
|
||||
source_model_name = input_path.name # Use folder name as fallback
|
||||
else:
|
||||
is_moe_val = None
|
||||
num_experts = None
|
||||
num_active = None
|
||||
repo_type_val = None
|
||||
repo_id_val = None
|
||||
source_model_name = input_path.name # Use folder name as fallback
|
||||
|
||||
# Create new model entry (AMX format uses "safetensors" format, detected by is_amx_weights())
|
||||
new_model = UserModel(
|
||||
name=suggested_name,
|
||||
path=str(output),
|
||||
format="safetensors", # AMX files are safetensors format
|
||||
repo_type=repo_type_val,
|
||||
repo_id=repo_id_val,
|
||||
sha256_status="not_checked", # AMX weights don't need verification
|
||||
# Inherit MoE information from source model
|
||||
is_moe=is_moe_val,
|
||||
moe_num_experts=num_experts,
|
||||
moe_num_experts_per_tok=num_active,
|
||||
# AMX quantization metadata
|
||||
amx_source_model=source_model_name,
|
||||
amx_quant_method=method.value, # "int4" or "int8"
|
||||
amx_numa_nodes=final_numa_nodes,
|
||||
)
|
||||
|
||||
user_registry.add_model(new_model)
|
||||
console.print()
|
||||
print_success(t("quant_registered", name=suggested_name))
|
||||
console.print()
|
||||
console.print(f" {t('quant_view_with')} [cyan]kt model list[/cyan]")
|
||||
console.print(f" {t('quant_use_with')} [cyan]kt run {suggested_name}[/cyan]")
|
||||
console.print()
|
||||
except Exception as e:
|
||||
# Non-fatal error - quantization succeeded but registration failed
|
||||
console.print()
|
||||
print_warning(t("quant_register_failed", error=str(e)))
|
||||
console.print()
|
||||
console.print(f" {t('quant_use_with')}")
|
||||
console.print(f" kt run {model} --weights-path {output}")
|
||||
console.print()
|
||||
else:
|
||||
print_error(f"Quantization failed with exit code {process.returncode}")
|
||||
raise typer.Exit(process.returncode)
|
||||
|
|
@ -221,6 +511,7 @@ def _find_kt_kernel_path() -> Optional[Path]:
|
|||
"""Find the kt-kernel installation path."""
|
||||
try:
|
||||
import kt_kernel
|
||||
|
||||
return Path(kt_kernel.__file__).parent.parent
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from kt_kernel.cli.utils.console import (
|
|||
prompt_choice,
|
||||
)
|
||||
from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb
|
||||
from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry
|
||||
from kt_kernel.cli.utils.user_model_registry import UserModelRegistry
|
||||
|
||||
|
||||
@click.command(
|
||||
|
|
@ -120,8 +120,6 @@ def run(
|
|||
# Handle disable/enable shared experts fusion flags
|
||||
if enable_shared_experts_fusion:
|
||||
disable_shared_experts_fusion = False
|
||||
elif disable_shared_experts_fusion is None:
|
||||
disable_shared_experts_fusion = None
|
||||
|
||||
# Convert Path objects from click
|
||||
model_path_obj = Path(model_path) if model_path else None
|
||||
|
|
@ -214,266 +212,250 @@ def _run_impl(
|
|||
raise typer.Exit(1)
|
||||
|
||||
settings = get_settings()
|
||||
registry = get_registry()
|
||||
user_registry = UserModelRegistry()
|
||||
|
||||
console.print()
|
||||
# Check if we should use interactive mode
|
||||
# Interactive mode triggers when:
|
||||
# 1. No model specified, OR
|
||||
# 2. Model specified but missing critical parameters (gpu_experts, tensor_parallel_size, etc.)
|
||||
use_interactive = False
|
||||
|
||||
# If no model specified, show interactive selection
|
||||
if model is None:
|
||||
model = _interactive_model_selection(registry, settings)
|
||||
if model is None:
|
||||
use_interactive = True
|
||||
elif (
|
||||
gpu_experts is None
|
||||
or tensor_parallel_size is None
|
||||
or cpu_threads is None
|
||||
or numa_nodes is None
|
||||
or max_total_tokens is None
|
||||
):
|
||||
# Model specified but some parameters missing - use interactive
|
||||
use_interactive = True
|
||||
|
||||
if use_interactive and sys.stdin.isatty():
|
||||
# Use new interactive configuration flow
|
||||
from kt_kernel.cli.utils.run_interactive import interactive_run_config
|
||||
|
||||
console.print()
|
||||
console.print("[bold cyan]═══ Interactive Run Configuration ═══[/bold cyan]")
|
||||
console.print()
|
||||
|
||||
config = interactive_run_config()
|
||||
if config is None:
|
||||
# User cancelled
|
||||
raise typer.Exit(0)
|
||||
|
||||
# Step 1: Detect hardware
|
||||
print_step(t("run_detecting_hardware"))
|
||||
gpus = detect_gpus()
|
||||
cpu = detect_cpu_info()
|
||||
ram = detect_ram_gb()
|
||||
# Extract configuration from new format
|
||||
user_model_obj = config["model"]
|
||||
model = user_model_obj.id
|
||||
resolved_model_path = Path(config["model_path"])
|
||||
resolved_weights_path = Path(config["weights_path"])
|
||||
|
||||
if gpus:
|
||||
gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
|
||||
if len(gpus) > 1:
|
||||
gpu_info += f" + {len(gpus) - 1} more"
|
||||
print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
|
||||
# Extract parameters
|
||||
gpu_experts = config["gpu_experts"]
|
||||
cpu_threads = config["cpu_threads"]
|
||||
numa_nodes = config["numa_nodes"]
|
||||
tensor_parallel_size = config["tp_size"]
|
||||
|
||||
# Get kt-method and other method-specific settings
|
||||
kt_method = config["kt_method"]
|
||||
|
||||
# KV cache settings (may be None for non-raw methods)
|
||||
max_total_tokens = config.get("kv_cache", 32768)
|
||||
chunked_prefill_size = config.get("chunk_prefill", 32768)
|
||||
kt_gpu_prefill_threshold = config.get("gpu_prefill_threshold", 500)
|
||||
|
||||
# Memory settings
|
||||
mem_fraction_static = config["mem_fraction_static"]
|
||||
|
||||
# Parser settings (optional)
|
||||
tool_call_parser = config.get("tool_call_parser")
|
||||
reasoning_parser = config.get("reasoning_parser")
|
||||
|
||||
# Server settings
|
||||
host = config.get("host", "0.0.0.0")
|
||||
port = config.get("port", 30000)
|
||||
|
||||
# Set CUDA_VISIBLE_DEVICES for selected GPUs
|
||||
selected_gpus = config["selected_gpus"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in selected_gpus)
|
||||
|
||||
# Detect hardware for parameter resolution (needed for resolve() function later)
|
||||
gpus = detect_gpus()
|
||||
cpu = detect_cpu_info()
|
||||
|
||||
console.print()
|
||||
print_info(f"[green]✓[/green] Configuration complete")
|
||||
console.print()
|
||||
else:
|
||||
print_warning(t("doctor_gpu_not_found"))
|
||||
gpu_info = "None"
|
||||
# Non-interactive mode - use traditional flow
|
||||
console.print()
|
||||
|
||||
print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
|
||||
print_info(t("run_ram_info", total=int(ram)))
|
||||
# Initialize variables that may have been set by interactive mode
|
||||
# These will be None in non-interactive mode and will use defaults via resolve()
|
||||
|
||||
# Step 2: Resolve model
|
||||
console.print()
|
||||
print_step(t("run_checking_model"))
|
||||
# If no model specified, show old interactive selection
|
||||
if model is None:
|
||||
model = _interactive_model_selection(user_registry, settings)
|
||||
if model is None:
|
||||
raise typer.Exit(0)
|
||||
|
||||
model_info = None
|
||||
resolved_model_path = model_path
|
||||
# Detect hardware (needed for defaults)
|
||||
gpus = detect_gpus()
|
||||
cpu = detect_cpu_info()
|
||||
ram = detect_ram_gb()
|
||||
|
||||
# Check if model is a path
|
||||
if Path(model).exists():
|
||||
resolved_model_path = Path(model)
|
||||
print_info(t("run_model_path", path=str(resolved_model_path)))
|
||||
|
||||
# Try to infer model type from path to use default configurations
|
||||
# Check directory name against known models
|
||||
dir_name = resolved_model_path.name.lower()
|
||||
for registered_model in registry.list_all():
|
||||
# Check if directory name matches model name or aliases
|
||||
if dir_name == registered_model.name.lower():
|
||||
model_info = registered_model
|
||||
print_info(f"Detected model type: {registered_model.name}")
|
||||
break
|
||||
for alias in registered_model.aliases:
|
||||
if dir_name == alias.lower() or alias.lower() in dir_name:
|
||||
model_info = registered_model
|
||||
print_info(f"Detected model type: {registered_model.name}")
|
||||
break
|
||||
if model_info:
|
||||
break
|
||||
|
||||
# Also check HuggingFace repo format (org--model)
|
||||
if not model_info:
|
||||
for registered_model in registry.list_all():
|
||||
repo_slug = registered_model.hf_repo.replace("/", "--").lower()
|
||||
if repo_slug in dir_name or dir_name in repo_slug:
|
||||
model_info = registered_model
|
||||
print_info(f"Detected model type: {registered_model.name}")
|
||||
break
|
||||
|
||||
if not model_info:
|
||||
print_warning("Could not detect model type from path. Using default parameters.")
|
||||
console.print(" [dim]Tip: Use model name (e.g., 'kt run m2') to apply optimized configurations[/dim]")
|
||||
else:
|
||||
# Search in registry
|
||||
matches = registry.search(model)
|
||||
|
||||
if not matches:
|
||||
print_error(t("run_model_not_found", name=model))
|
||||
console.print()
|
||||
console.print("Available models:")
|
||||
for m in registry.list_all()[:5]:
|
||||
console.print(f" - {m.name} ({', '.join(m.aliases[:2])})")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if len(matches) == 1:
|
||||
model_info = matches[0]
|
||||
if gpus:
|
||||
gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
|
||||
if len(gpus) > 1:
|
||||
gpu_info += f" + {len(gpus) - 1} more"
|
||||
print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
|
||||
else:
|
||||
# Multiple matches - prompt user
|
||||
console.print()
|
||||
print_info(t("run_multiple_matches"))
|
||||
choices = [f"{m.name} ({m.hf_repo})" for m in matches]
|
||||
selected = prompt_choice(t("run_select_model"), choices)
|
||||
idx = choices.index(selected)
|
||||
model_info = matches[idx]
|
||||
print_warning(t("doctor_gpu_not_found"))
|
||||
gpu_info = "None"
|
||||
|
||||
# Find model path
|
||||
if model_path is None:
|
||||
resolved_model_path = _find_model_path(model_info, settings)
|
||||
if resolved_model_path is None:
|
||||
print_error(t("run_model_not_found", name=model_info.name))
|
||||
print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
|
||||
print_info(t("run_ram_info", total=int(ram)))
|
||||
|
||||
# Step 2: Resolve model
|
||||
console.print()
|
||||
print_step(t("run_checking_model"))
|
||||
|
||||
user_model_obj = None
|
||||
resolved_model_path = model_path
|
||||
|
||||
# Check if model is a path
|
||||
if Path(model).exists():
|
||||
resolved_model_path = Path(model)
|
||||
print_info(t("run_model_path", path=str(resolved_model_path)))
|
||||
|
||||
# Try to find in user registry by path
|
||||
user_model_obj = user_registry.find_by_path(str(resolved_model_path))
|
||||
if user_model_obj:
|
||||
print_info(f"Using registered model: {user_model_obj.name}")
|
||||
else:
|
||||
print_warning("Using unregistered model path. Consider adding it with 'kt model add'")
|
||||
else:
|
||||
# Search in user registry by name
|
||||
user_model_obj = user_registry.get_model(model)
|
||||
|
||||
if not user_model_obj:
|
||||
print_error(t("run_model_not_found", name=model))
|
||||
console.print()
|
||||
console.print(
|
||||
f" Download with: kt download {model_info.aliases[0] if model_info.aliases else model_info.name}"
|
||||
)
|
||||
|
||||
# Show available models
|
||||
all_models = user_registry.list_models()
|
||||
if all_models:
|
||||
console.print("Available registered models:")
|
||||
for m in all_models[:5]:
|
||||
console.print(f" - {m.name}")
|
||||
if len(all_models) > 5:
|
||||
console.print(f" ... and {len(all_models) - 5} more")
|
||||
else:
|
||||
console.print("No models registered yet.")
|
||||
|
||||
console.print()
|
||||
console.print(f"Add your model with: [cyan]kt model add /path/to/model[/cyan]")
|
||||
console.print(f"Or scan for models: [cyan]kt model scan[/cyan]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
print_info(t("run_model_path", path=str(resolved_model_path)))
|
||||
# Use model path from registry
|
||||
resolved_model_path = Path(user_model_obj.path)
|
||||
|
||||
# Step 3: Check quantized weights (only if explicitly requested)
|
||||
resolved_weights_path = None
|
||||
# Verify path exists
|
||||
if not resolved_model_path.exists():
|
||||
print_error(f"Model path does not exist: {resolved_model_path}")
|
||||
console.print()
|
||||
console.print(f"Run 'kt model refresh' to check all models")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Only use quantized weights if explicitly specified by user
|
||||
if weights_path is not None:
|
||||
# User explicitly specified weights path
|
||||
resolved_weights_path = weights_path
|
||||
if not resolved_weights_path.exists():
|
||||
print_error(t("run_weights_not_found"))
|
||||
console.print(f" Path: {resolved_weights_path}")
|
||||
print_info(t("run_model_path", path=str(resolved_model_path)))
|
||||
|
||||
# Step 2.5: Pre-run verification (optional integrity check)
|
||||
if user_model_obj and user_model_obj.format == "safetensors":
|
||||
from kt_kernel.cli.utils.model_verifier import pre_operation_verification
|
||||
|
||||
pre_operation_verification(user_model_obj, user_registry, operation_name="running")
|
||||
|
||||
# Step 3: Check quantized weights (only if explicitly requested)
|
||||
resolved_weights_path = None
|
||||
|
||||
# Only use quantized weights if explicitly specified by user
|
||||
if weights_path is not None:
|
||||
# User explicitly specified weights path
|
||||
resolved_weights_path = weights_path
|
||||
if not resolved_weights_path.exists():
|
||||
print_error(t("run_weights_not_found"))
|
||||
console.print(f" Path: {resolved_weights_path}")
|
||||
raise typer.Exit(1)
|
||||
print_info(f"Using quantized weights: {resolved_weights_path}")
|
||||
elif quantize:
|
||||
# User requested quantization
|
||||
console.print()
|
||||
print_step(t("run_quantizing"))
|
||||
# TODO: Implement quantization
|
||||
print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
|
||||
raise typer.Exit(1)
|
||||
print_info(f"Using quantized weights: {resolved_weights_path}")
|
||||
elif quantize:
|
||||
# User requested quantization
|
||||
console.print()
|
||||
print_step(t("run_quantizing"))
|
||||
# TODO: Implement quantization
|
||||
print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
# Default: use original precision model without quantization
|
||||
console.print()
|
||||
print_info("Using original precision model (no quantization)")
|
||||
else:
|
||||
# Default: use original precision model without quantization
|
||||
console.print()
|
||||
print_info("Using original precision model (no quantization)")
|
||||
|
||||
# Step 4: Build command
|
||||
# Resolve all parameters (CLI > model defaults > config > auto-detect)
|
||||
final_host = host or settings.get("server.host", "0.0.0.0")
|
||||
final_port = port or settings.get("server.port", 30000)
|
||||
# Helper to resolve parameter with fallback chain: CLI > config > default
|
||||
def resolve(cli_val, config_key, default):
|
||||
if cli_val is not None:
|
||||
return cli_val
|
||||
config_val = settings.get(config_key)
|
||||
return config_val if config_val is not None else default
|
||||
|
||||
# Get defaults from model info if available
|
||||
model_defaults = model_info.default_params if model_info else {}
|
||||
# Server configuration
|
||||
final_host = resolve(host, "server.host", "0.0.0.0")
|
||||
final_port = resolve(port, "server.port", 30000)
|
||||
|
||||
# Determine tensor parallel size first (needed for GPU expert calculation)
|
||||
# Priority: CLI > model defaults > config > auto-detect (with model constraints)
|
||||
|
||||
# Check if explicitly specified by user or configuration
|
||||
explicitly_specified = (
|
||||
tensor_parallel_size # CLI argument (highest priority)
|
||||
or model_defaults.get("tensor-parallel-size") # Model defaults
|
||||
or settings.get("inference.tensor_parallel_size") # Config file
|
||||
# Tensor parallel size: CLI > config > auto-detect from GPUs
|
||||
final_tensor_parallel_size = resolve(
|
||||
tensor_parallel_size, "inference.tensor_parallel_size", len(gpus) if gpus else 1
|
||||
)
|
||||
|
||||
if explicitly_specified:
|
||||
# Use explicitly specified value
|
||||
requested_tensor_parallel_size = explicitly_specified
|
||||
else:
|
||||
# Auto-detect from GPUs, considering model's max constraint
|
||||
detected_gpu_count = len(gpus) if gpus else 1
|
||||
if model_info and model_info.max_tensor_parallel_size is not None:
|
||||
# Automatically limit to model's maximum to use as many GPUs as possible
|
||||
requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size)
|
||||
else:
|
||||
requested_tensor_parallel_size = detected_gpu_count
|
||||
|
||||
# Apply model's max_tensor_parallel_size constraint if explicitly specified value exceeds it
|
||||
final_tensor_parallel_size = requested_tensor_parallel_size
|
||||
if model_info and model_info.max_tensor_parallel_size is not None:
|
||||
if requested_tensor_parallel_size > model_info.max_tensor_parallel_size:
|
||||
console.print()
|
||||
print_warning(
|
||||
f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way "
|
||||
f"tensor parallelism, but {requested_tensor_parallel_size} was requested. "
|
||||
f"Reducing to {model_info.max_tensor_parallel_size}."
|
||||
)
|
||||
final_tensor_parallel_size = model_info.max_tensor_parallel_size
|
||||
|
||||
# CPU/GPU configuration with smart defaults
|
||||
# kt-cpuinfer: default to 80% of total CPU threads (cores * NUMA nodes)
|
||||
total_threads = cpu.cores * cpu.numa_nodes
|
||||
final_cpu_threads = (
|
||||
cpu_threads
|
||||
or model_defaults.get("kt-cpuinfer")
|
||||
or settings.get("inference.cpu_threads")
|
||||
or int(total_threads * 0.8)
|
||||
)
|
||||
|
||||
# kt-threadpool-count: default to NUMA node count
|
||||
final_numa_nodes = (
|
||||
numa_nodes
|
||||
or model_defaults.get("kt-threadpool-count")
|
||||
or settings.get("inference.numa_nodes")
|
||||
or cpu.numa_nodes
|
||||
)
|
||||
|
||||
# kt-num-gpu-experts: use model-specific computation if available and not explicitly set
|
||||
if gpu_experts is not None:
|
||||
# User explicitly set it
|
||||
final_gpu_experts = gpu_experts
|
||||
elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus:
|
||||
# Use model-specific computation function (only if GPUs detected)
|
||||
vram_per_gpu = gpus[0].vram_gb
|
||||
compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name]
|
||||
final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu)
|
||||
console.print()
|
||||
print_info(
|
||||
f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} (TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)"
|
||||
)
|
||||
else:
|
||||
# Fall back to defaults
|
||||
final_gpu_experts = model_defaults.get("kt-num-gpu-experts") or settings.get("inference.gpu_experts", 1)
|
||||
total_threads = cpu.threads # Use logical threads instead of physical cores
|
||||
final_cpu_threads = resolve(cpu_threads, "inference.cpu_threads", int(total_threads * 0.8))
|
||||
final_numa_nodes = resolve(numa_nodes, "inference.numa_nodes", cpu.numa_nodes)
|
||||
final_gpu_experts = resolve(gpu_experts, "inference.gpu_experts", 1)
|
||||
|
||||
# KT-kernel options
|
||||
final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4")
|
||||
final_kt_gpu_prefill_threshold = (
|
||||
kt_gpu_prefill_threshold
|
||||
or model_defaults.get("kt-gpu-prefill-token-threshold")
|
||||
or settings.get("inference.kt_gpu_prefill_token_threshold", 4096)
|
||||
)
|
||||
final_kt_method = resolve(kt_method, "inference.kt_method", "AMXINT4")
|
||||
final_kt_gpu_prefill_threshold = resolve(kt_gpu_prefill_threshold, "inference.kt_gpu_prefill_token_threshold", 4096)
|
||||
|
||||
# SGLang options
|
||||
final_attention_backend = (
|
||||
attention_backend
|
||||
or model_defaults.get("attention-backend")
|
||||
or settings.get("inference.attention_backend", "triton")
|
||||
)
|
||||
final_max_total_tokens = (
|
||||
max_total_tokens or model_defaults.get("max-total-tokens") or settings.get("inference.max_total_tokens", 40000)
|
||||
)
|
||||
final_max_running_requests = (
|
||||
max_running_requests
|
||||
or model_defaults.get("max-running-requests")
|
||||
or settings.get("inference.max_running_requests", 32)
|
||||
)
|
||||
final_chunked_prefill_size = (
|
||||
chunked_prefill_size
|
||||
or model_defaults.get("chunked-prefill-size")
|
||||
or settings.get("inference.chunked_prefill_size", 4096)
|
||||
)
|
||||
final_mem_fraction_static = (
|
||||
mem_fraction_static
|
||||
or model_defaults.get("mem-fraction-static")
|
||||
or settings.get("inference.mem_fraction_static", 0.98)
|
||||
)
|
||||
final_watchdog_timeout = (
|
||||
watchdog_timeout or model_defaults.get("watchdog-timeout") or settings.get("inference.watchdog_timeout", 3000)
|
||||
)
|
||||
final_served_model_name = (
|
||||
served_model_name or model_defaults.get("served-model-name") or settings.get("inference.served_model_name", "")
|
||||
)
|
||||
final_attention_backend = resolve(attention_backend, "inference.attention_backend", "flashinfer")
|
||||
final_max_total_tokens = resolve(max_total_tokens, "inference.max_total_tokens", 40000)
|
||||
final_max_running_requests = resolve(max_running_requests, "inference.max_running_requests", 32)
|
||||
final_chunked_prefill_size = resolve(chunked_prefill_size, "inference.chunked_prefill_size", 4096)
|
||||
final_mem_fraction_static = resolve(mem_fraction_static, "inference.mem_fraction_static", 0.98)
|
||||
final_watchdog_timeout = resolve(watchdog_timeout, "inference.watchdog_timeout", 3000)
|
||||
final_served_model_name = resolve(served_model_name, "inference.served_model_name", "")
|
||||
|
||||
# Performance flags
|
||||
if disable_shared_experts_fusion is not None:
|
||||
final_disable_shared_experts_fusion = disable_shared_experts_fusion
|
||||
elif "disable-shared-experts-fusion" in model_defaults:
|
||||
final_disable_shared_experts_fusion = model_defaults["disable-shared-experts-fusion"]
|
||||
else:
|
||||
final_disable_shared_experts_fusion = settings.get("inference.disable_shared_experts_fusion", False)
|
||||
final_disable_shared_experts_fusion = resolve(
|
||||
disable_shared_experts_fusion, "inference.disable_shared_experts_fusion", True
|
||||
)
|
||||
|
||||
# Pass all model default params to handle any extra parameters
|
||||
extra_params = model_defaults if model_info else {}
|
||||
# Pass extra CLI parameters
|
||||
extra_params = {}
|
||||
|
||||
# Parser parameters (from interactive mode or None in non-interactive mode)
|
||||
final_tool_call_parser = None
|
||||
final_reasoning_parser = None
|
||||
if "tool_call_parser" in locals() and tool_call_parser:
|
||||
final_tool_call_parser = tool_call_parser
|
||||
if "reasoning_parser" in locals() and reasoning_parser:
|
||||
final_reasoning_parser = reasoning_parser
|
||||
|
||||
cmd = _build_sglang_command(
|
||||
model_path=resolved_model_path,
|
||||
weights_path=resolved_weights_path,
|
||||
model_info=model_info,
|
||||
host=final_host,
|
||||
port=final_port,
|
||||
gpu_experts=final_gpu_experts,
|
||||
|
|
@ -490,6 +472,8 @@ def _run_impl(
|
|||
watchdog_timeout=final_watchdog_timeout,
|
||||
served_model_name=final_served_model_name,
|
||||
disable_shared_experts_fusion=final_disable_shared_experts_fusion,
|
||||
tool_call_parser=final_tool_call_parser,
|
||||
reasoning_parser=final_reasoning_parser,
|
||||
settings=settings,
|
||||
extra_model_params=extra_params,
|
||||
extra_cli_args=extra_cli_args,
|
||||
|
|
@ -508,11 +492,9 @@ def _run_impl(
|
|||
console.print()
|
||||
print_step("Configuration")
|
||||
|
||||
# Model info
|
||||
if model_info:
|
||||
console.print(f" Model: [bold]{model_info.name}[/bold]")
|
||||
else:
|
||||
console.print(f" Model: [bold]{resolved_model_path.name}[/bold]")
|
||||
# Display model name
|
||||
model_display_name = user_model_obj.name if user_model_obj else resolved_model_path.name
|
||||
console.print(f" Model: [bold]{model_display_name}[/bold]")
|
||||
|
||||
console.print(f" Path: [dim]{resolved_model_path}[/dim]")
|
||||
|
||||
|
|
@ -572,88 +554,13 @@ def _run_impl(
|
|||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _find_model_path(model_info: ModelInfo, settings, max_depth: int = 3) -> Optional[Path]:
|
||||
"""Find the model path on disk by searching all configured model paths.
|
||||
|
||||
Args:
|
||||
model_info: Model information to search for
|
||||
settings: Settings instance
|
||||
max_depth: Maximum depth to search within each model path (default: 3)
|
||||
|
||||
Returns:
|
||||
Path to the model directory, or None if not found
|
||||
"""
|
||||
model_paths = settings.get_model_paths()
|
||||
|
||||
# Generate possible names to search for
|
||||
possible_names = [
|
||||
model_info.name,
|
||||
model_info.name.lower(),
|
||||
model_info.name.replace(" ", "-"),
|
||||
model_info.hf_repo.split("/")[-1],
|
||||
model_info.hf_repo.replace("/", "--"),
|
||||
]
|
||||
|
||||
# Add alias-based names
|
||||
for alias in model_info.aliases:
|
||||
possible_names.append(alias)
|
||||
possible_names.append(alias.lower())
|
||||
|
||||
# Search in all configured model directories
|
||||
for models_dir in model_paths:
|
||||
if not models_dir.exists():
|
||||
continue
|
||||
|
||||
# Search recursively up to max_depth
|
||||
for depth in range(max_depth):
|
||||
for name in possible_names:
|
||||
if depth == 0:
|
||||
# Direct children: models_dir / name
|
||||
search_paths = [models_dir / name]
|
||||
else:
|
||||
# Nested: use rglob to find directories matching the name
|
||||
search_paths = list(models_dir.rglob(name))
|
||||
|
||||
for path in search_paths:
|
||||
if path.exists() and (path / "config.json").exists():
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_weights_path(model_info: ModelInfo, settings) -> Optional[Path]:
|
||||
"""Find the quantized weights path on disk by searching all configured paths."""
|
||||
model_paths = settings.get_model_paths()
|
||||
weights_dir = settings.weights_dir
|
||||
|
||||
# Check common patterns
|
||||
base_names = [
|
||||
model_info.name,
|
||||
model_info.name.lower(),
|
||||
model_info.hf_repo.split("/")[-1],
|
||||
]
|
||||
|
||||
suffixes = ["-INT4", "-int4", "_INT4", "_int4", "-quant", "-quantized"]
|
||||
|
||||
# Prepare search directories
|
||||
search_dirs = [weights_dir] if weights_dir else []
|
||||
search_dirs.extend(model_paths)
|
||||
|
||||
for base in base_names:
|
||||
for suffix in suffixes:
|
||||
for dir_path in search_dirs:
|
||||
if dir_path:
|
||||
path = dir_path / f"{base}{suffix}"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
return None
|
||||
# Dead code removed: _find_model_path() and _find_weights_path()
|
||||
# These functions were part of the old builtin model system
|
||||
|
||||
|
||||
def _build_sglang_command(
|
||||
model_path: Path,
|
||||
weights_path: Optional[Path],
|
||||
model_info: Optional[ModelInfo],
|
||||
host: str,
|
||||
port: int,
|
||||
gpu_experts: int,
|
||||
|
|
@ -670,6 +577,8 @@ def _build_sglang_command(
|
|||
watchdog_timeout: int,
|
||||
served_model_name: str,
|
||||
disable_shared_experts_fusion: bool,
|
||||
tool_call_parser: Optional[str],
|
||||
reasoning_parser: Optional[str],
|
||||
settings,
|
||||
extra_model_params: Optional[dict] = None, # New parameter for additional params
|
||||
extra_cli_args: Optional[list[str]] = None, # Extra args from CLI to pass to sglang
|
||||
|
|
@ -700,9 +609,6 @@ def _build_sglang_command(
|
|||
elif cpu_threads > 0 or gpu_experts > 1:
|
||||
# CPU offloading configured - use kt-kernel
|
||||
use_kt_kernel = True
|
||||
elif model_info and model_info.type == "moe":
|
||||
# MoE model - likely needs kt-kernel for expert offloading
|
||||
use_kt_kernel = True
|
||||
|
||||
if use_kt_kernel:
|
||||
# Add kt-weight-path: use quantized weights if available, otherwise use model path
|
||||
|
|
@ -723,6 +629,7 @@ def _build_sglang_command(
|
|||
kt_method,
|
||||
"--kt-gpu-prefill-token-threshold",
|
||||
str(kt_gpu_prefill_threshold),
|
||||
"--kt-enable-dynamic-expert-update", # Enable dynamic expert updates
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -757,6 +664,16 @@ def _build_sglang_command(
|
|||
if disable_shared_experts_fusion:
|
||||
cmd.append("--disable-shared-experts-fusion")
|
||||
|
||||
# Add FP8 backend if using FP8 method
|
||||
if "FP8" in kt_method.upper():
|
||||
cmd.extend(["--fp8-gemm-backend", "triton"])
|
||||
|
||||
# Add parsers if specified
|
||||
if tool_call_parser:
|
||||
cmd.extend(["--tool-call-parser", tool_call_parser])
|
||||
if reasoning_parser:
|
||||
cmd.extend(["--reasoning-parser", reasoning_parser])
|
||||
|
||||
# Add any extra parameters from model defaults that weren't explicitly handled
|
||||
if extra_model_params:
|
||||
# List of parameters already handled above
|
||||
|
|
@ -801,30 +718,31 @@ def _build_sglang_command(
|
|||
return cmd
|
||||
|
||||
|
||||
def _interactive_model_selection(registry, settings) -> Optional[str]:
|
||||
def _interactive_model_selection(user_registry, settings) -> Optional[str]:
|
||||
"""Show interactive model selection interface.
|
||||
|
||||
Returns:
|
||||
Selected model name or None if cancelled.
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.prompt import Prompt
|
||||
|
||||
from kt_kernel.cli.i18n import get_lang
|
||||
# Get all user models
|
||||
all_models = user_registry.list_models()
|
||||
|
||||
lang = get_lang()
|
||||
|
||||
# Find local models first
|
||||
local_models = registry.find_local_models()
|
||||
|
||||
# Get all registered models
|
||||
all_models = registry.list_all()
|
||||
if not all_models:
|
||||
console.print()
|
||||
print_warning("No models registered.")
|
||||
console.print()
|
||||
console.print(f" Add models with: [cyan]kt model scan[/cyan]")
|
||||
console.print(f" Or manually: [cyan]kt model add /path/to/model[/cyan]")
|
||||
console.print()
|
||||
return None
|
||||
|
||||
console.print()
|
||||
console.print(
|
||||
Panel.fit(
|
||||
t("run_select_model_title"),
|
||||
"Select a model to run",
|
||||
border_style="cyan",
|
||||
)
|
||||
)
|
||||
|
|
@ -834,54 +752,30 @@ def _interactive_model_selection(registry, settings) -> Optional[str]:
|
|||
choices = []
|
||||
choice_map = {} # index -> model name
|
||||
|
||||
# Section 1: Local models (downloaded)
|
||||
if local_models:
|
||||
console.print(f"[bold green]{t('run_local_models')}[/bold green]")
|
||||
console.print()
|
||||
|
||||
for i, (model_info, path) in enumerate(local_models, 1):
|
||||
desc = model_info.description_zh if lang == "zh" else model_info.description
|
||||
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
|
||||
console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
|
||||
console.print(f" [dim]{short_desc}[/dim]")
|
||||
console.print(f" [dim]{path}[/dim]")
|
||||
choices.append(str(i))
|
||||
choice_map[str(i)] = model_info.name
|
||||
|
||||
console.print()
|
||||
|
||||
# Section 2: All registered models (for reference)
|
||||
start_idx = len(local_models) + 1
|
||||
console.print(f"[bold yellow]{t('run_registered_models')}[/bold yellow]")
|
||||
# Show all user models
|
||||
console.print(f"[bold green]Available Models:[/bold green]")
|
||||
console.print()
|
||||
|
||||
# Filter out already shown local models
|
||||
local_model_names = {m.name for m, _ in local_models}
|
||||
|
||||
for i, model_info in enumerate(all_models, start_idx):
|
||||
if model_info.name in local_model_names:
|
||||
continue
|
||||
|
||||
desc = model_info.description_zh if lang == "zh" else model_info.description
|
||||
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
|
||||
console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
|
||||
console.print(f" [dim]{short_desc}[/dim]")
|
||||
console.print(f" [dim]{model_info.hf_repo}[/dim]")
|
||||
for i, model in enumerate(all_models, 1):
|
||||
# Check if path exists
|
||||
path_status = "✓" if model.path_exists() else "✗ Missing"
|
||||
console.print(f" [cyan][{i}][/cyan] [bold]{model.name}[/bold] [{path_status}]")
|
||||
console.print(f" [dim]{model.format} - {model.path}[/dim]")
|
||||
choices.append(str(i))
|
||||
choice_map[str(i)] = model_info.name
|
||||
choice_map[str(i)] = model.name
|
||||
|
||||
console.print()
|
||||
|
||||
# Add cancel option
|
||||
cancel_idx = str(len(choices) + 1)
|
||||
console.print(f" [cyan][{cancel_idx}][/cyan] [dim]{t('cancel')}[/dim]")
|
||||
console.print(f" [cyan][{cancel_idx}][/cyan] [dim]Cancel[/dim]")
|
||||
choices.append(cancel_idx)
|
||||
console.print()
|
||||
|
||||
# Prompt for selection
|
||||
try:
|
||||
selection = Prompt.ask(
|
||||
t("run_select_model_prompt"),
|
||||
"Select model",
|
||||
choices=choices,
|
||||
default="1" if choices else cancel_idx,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue