kvcache-ai-ktransformers/kt-kernel/python/sft/weights.py
mrhaoxx 9544a8960d
feat(sft): AMX MoE SFT backend with LoRA support (#1936)
* feat(sft): AMX MoE SFT backend with LoRA support

Complete SFT (Supervised Fine-Tuning) backend for MoE models using Intel AMX:

Core C++ implementation:
- sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines)
- moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA
- amx/moe-sft-tp.hpp: AMX-specific TP implementation
- avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM
- amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization
- worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure
- ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants)

Python sft/ submodule (kt_kernel.sft):
- base.py: BaseSFTMoEWrapper with buffer management (template method pattern)
- amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction)
- autograd.py: KTMoEFunction (torch.autograd.Function for distributed training)
- layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers)
- arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection)
- weights.py: Expert weight extraction and checkpoint loading
- lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter)
- wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map
- config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough)
- dist_utils.py: Distributed gather/scatter, checkpoint-phase detection

Design decisions:
- Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights
- DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework
  interaction fields), all logic in kt_kernel.sft
- Inference isolation: importing kt_kernel does not load sft/ submodule
- Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically

Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.

* refactor(sft): unify KTConfig field names with kt_ prefix, add share_cache_pool, remove dead code

- KTConfig fields all use kt_ prefix matching dict keys — eliminates
  _OLD_TO_NEW mapping and prefix-stripping in wrapper.py
- Add kt_share_cache_pool field, auto-enabled when gradient_checkpointing
  is on (via training_args.py), flows through to C++ cache allocation
- Remove dead checkpoint detection code: in_ckpt_recompute,
  in_ckpt_first_forward vars (assigned but never read), fallback
  _is_in_checkpoint_first_forward() function, unused inspect import
- Remove redundant env var fallbacks in wrapper.py for share_backward_bb
  and share_cache_pool (KTConfig.__post_init__ already handles env vars)
- Simplify layer.py checkpoint logic to single _checkpoint_hook_mode() check

Verified: Qwen3-235B 3-step training on sap4, loss matches baseline
(1.2886 / 1.9824 / 1.377 vs 1.2886 / 1.9766 / 1.3809)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(sft): share_backward_bb default True, share_cache_pool auto-derived

- kt_share_backward_bb defaults to True (always saves memory)
- kt_share_cache_pool no longer reads from env var; defaults False,
  auto-set to True by trainer_config_process when gradient checkpointing
  is enabled

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: add missing gpu_experts_mask=None to KTMoEWrapper call in SFT wrapper

KTMoEWrapper.__new__() requires gpu_experts_mask as a positional argument,
but the SFT wrapper omitted it, causing MoE layer wrapping to fail silently
and FSDP2 to attempt broadcasting all expert weights (OOM/NCCL crash).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): support transformers v5 fused expert format

Fused experts (e.g. Qwen3MoeExperts) store weights as 3D Parameters
(gate_up_proj [E,2I,H], down_proj [E,H,I]) instead of per-expert
nn.Linear modules. PEFT cannot attach LoRA to these, so we create
KT-managed LoRA buffers with kaiming init, nn.Parameter wrappers
for the optimizer, and pre-assigned .grad for C++ backward.

- arch.py: detect_fused_experts() detection
- weights.py: fused format extraction and weight clearing
- wrapper.py: detect fused at wrap time, store _fused_experts/_lora_rank
- lora.py: _create_fused_expert_lora_buffers, save/load fused LoRA,
  get_kt_lora_params collects fused params, deduplicate wrapper finding
- layer.py: handle v5 TopKRouter tuple output, remove dead code
- autograd.py: sync_forward_sft/submit_forward_sft API rename

Verified: v5 loss/expert-LoRA values match v4 baseline, v4 backward compat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* feat(sft): add Qwen3.5 MoE support + fused checkpoint loading

- arch.py: add Qwen3_5Moe arch match, read config from text_config,
  _get_layers_prefix returns model.language_model.layers for Qwen3.5,
  _get_model_container_and_layers searches language_model attr
- weights.py: load_experts_from_checkpoint_files detects fused format
  (gate_up_proj in weight_map) and splits into gate/up/down
- wrapper.py: hidden_size fallback to text_config

Verified: Qwen3.5-35B-A3B (256 experts, fused format) E2E pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.

* align sft branch with main: revert worker_pool, strip sft_timer, fix inference defaults

- Revert worker_pool.cpp/.h to main (remove RDTSC timer, Chrome Trace,
  sft_timer namespace, ITT API, extended do_work_stealing_job API)
- Strip all sft_timer instrumentation from sft-only files (sft_moe.hpp,
  moe-sft-tp.hpp, avx_kernels.hpp)
- Restore pin_memory=True in KExpertsCPUBuffer (inference path)
- Restore fused tensor transpose logic in convert_cpu_weights.py (main layout)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* revert CMakeLists.txt to main: remove debug flags and cpptrace dep

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean up dev artifacts: remove SFT design docs, debug examples, bench scripts

Remove files not needed in the merge:
- docs/SFT+KTWrapper/ (6 Chinese design docs)
- docs/sft_moe_amx/ (21 dev/debug docs)
- 12 debug/test example scripts
- 6 SFT-specific bench scripts and report

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* remove dev version stamps from ext_bindings, sft_moe, moe-sft-tp

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: JimmyPeilinLi <lipeilin@mail.nwpu.edu.cn>
2026-04-22 11:27:01 +08:00


# Weight extraction and loading utilities for SFT
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import json
import logging
import os
import time
from contextlib import nullcontext
from dataclasses import dataclass

import torch
import torch.nn as nn

from .arch import MOEArchConfig
from .dist_utils import _maybe_zero3_gathered_parameters

logger = logging.getLogger(__name__)

try:
    from safetensors import safe_open

    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    safe_open = None


# =============================================================================
# Weight Extraction
# =============================================================================


def extract_moe_weights(
    moe_module: nn.Module, moe_config: MOEArchConfig
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Extract MoE expert weights from the module.

    Returns (gate_proj, up_proj, down_proj) with shape
    [expert_num, out_features, in_features].

    Supports two formats:
    - ModuleList of Linear experts (transformers v4 style)
    - Fused Parameters (transformers v5 style): single module with
      ``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors.
    """
    from .arch import detect_fused_experts

    experts = getattr(moe_module, moe_config.experts_attr)

    # Fused format (transformers v5): a single nn.Module with gate_up_proj/down_proj tensors
    if detect_fused_experts(experts):
        gate_up = getattr(experts, "gate_up_proj").data
        down_fused = getattr(experts, "down_proj").data
        # gate_up_proj is [E, 2*I, H], split into gate [E, I, H] and up [E, I, H]
        intermediate = gate_up.shape[1] // 2
        gate_proj = gate_up[:, :intermediate, :].contiguous()
        up_proj = gate_up[:, intermediate:, :].contiguous()
        # down_proj is already [E, H, I]
        down_proj = down_fused.contiguous()
        return gate_proj, up_proj, down_proj

    gate_name, up_name, down_name = moe_config.weight_names

    # Collect the weight Parameters first so ZeRO-3 shards can be gathered in one pass.
    gather_params: list[torch.nn.Parameter] = []
    for expert in experts:
        for weight_name in (gate_name, up_name, down_name):
            proj = getattr(expert, weight_name, None)
            if proj is not None and hasattr(proj, "weight"):
                # Handle PEFT LoRA wrapped modules
                weight = proj.weight
                if isinstance(weight, torch.Tensor):
                    gather_params.append(weight)
                elif hasattr(weight, "data"):
                    gather_params.append(weight.data)

    with _maybe_zero3_gathered_parameters(gather_params):
        # Get weight tensors, handling both regular Linear and PEFT LoRA wrapped modules
        def get_weight_tensor(mod):
            weight = mod.weight
            if isinstance(weight, torch.Tensor):
                return weight.data
            elif hasattr(weight, "data"):
                return weight.data
            else:
                raise ValueError(f"Cannot extract weight from {type(mod)}, weight type={type(weight)}")

        gate_weights = []
        up_weights = []
        down_weights = []
        for expert in experts:
            gate_proj_mod = getattr(expert, gate_name)
            up_proj_mod = getattr(expert, up_name)
            down_proj_mod = getattr(expert, down_name)
            gate_weights.append(get_weight_tensor(gate_proj_mod))
            up_weights.append(get_weight_tensor(up_proj_mod))
            down_weights.append(get_weight_tensor(down_proj_mod))

        gate_proj = torch.stack(gate_weights, dim=0)
        up_proj = torch.stack(up_weights, dim=0)
        down_proj = torch.stack(down_weights, dim=0)

    return gate_proj, up_proj, down_proj
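

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal example of how extract_moe_weights() might be called when wrapping a
# HF MoE block. The `model.model.layers[i].mlp` layout and this helper's name are
# assumptions for illustration; the real call sites live in wrapper.py / amx.py.
def _example_extract_layer_weights(model: nn.Module, moe_config: MOEArchConfig, layer_idx: int = 0):
    # Assumes a v4-style decoder layout; Qwen3.5 models nest layers under language_model.
    moe_module = model.model.layers[layer_idx].mlp
    gate, up, down = extract_moe_weights(moe_module, moe_config)
    # Each tensor is stacked per expert: [expert_num, out_features, in_features].
    assert gate.shape[0] == up.shape[0] == down.shape[0]
    return gate, up, down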


def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchConfig) -> None:
    """
    Clear original expert weights to free memory after KT weights are loaded.
    """
    from .arch import detect_fused_experts

    experts = getattr(moe_module, moe_config.experts_attr, None)
    if experts is None:
        return

    # Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders
    if detect_fused_experts(experts):
        for name in ("gate_up_proj", "down_proj"):
            param = getattr(experts, name, None)
            if not isinstance(param, torch.nn.Parameter):
                continue
            original_dtype = param.dtype
            tiny_storage = torch.UntypedStorage(1, device="cpu")
            fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
                tiny_storage,
                storage_offset=0,
                size=param.shape,
                stride=[0] * len(param.shape),
            )
            experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False)
        return

    def _iter_weight_params():
        for expert in experts:
            for weight_name in moe_config.weight_names:
                proj = getattr(expert, weight_name, None)
                if proj is None or not hasattr(proj, "weight"):
                    continue
                parametrizations = getattr(proj, "parametrizations", None)
                parametrized_weight = (
                    getattr(parametrizations, "weight", None) if parametrizations is not None else None
                )
                if parametrized_weight is not None:
                    original = getattr(parametrized_weight, "original", None)
                    if isinstance(original, torch.nn.Parameter):
                        yield proj, parametrized_weight, "original", original
                    continue
                direct_weight = getattr(proj, "_parameters", {}).get("weight")
                if isinstance(direct_weight, torch.nn.Parameter):
                    yield proj, proj, "weight", direct_weight
                    continue
                # Fallback: `weight` can be a non-settable property (e.g. parametrizations) or a non-Parameter.
                weight_attr = getattr(proj, "weight", None)
                if isinstance(weight_attr, torch.nn.Parameter):
                    yield proj, proj, "weight", weight_attr

    gather_params: list[torch.nn.Parameter] = []
    for _, _, _, weight_param in _iter_weight_params():
        gather_params.append(weight_param)

    replaced_count = 0
    with _maybe_zero3_gathered_parameters(gather_params):
        for proj, container, param_name, weight_param in _iter_weight_params():
            original_dtype = weight_param.dtype
            # Create a CPU tensor with the correct shape but NO physical memory.
            # torch.empty(shape, device="cpu") unfortunately touches pages via the
            # allocator, consuming real RSS. Instead, allocate a 1-byte storage and
            # use set_ to give it the original shape with zero strides. The tensor
            # is "valid" (correct dtype, device, shape) so PEFT can discover
            # in/out features, but its storage is essentially zero-cost.
            # NOTE: reading element values from this tensor is undefined -- it is
            # only used for shape/dtype discovery by PEFT.
            tiny_storage = torch.UntypedStorage(1, device="cpu")
            fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_(
                tiny_storage,
                storage_offset=0,
                size=weight_param.shape,
                stride=[0] * len(weight_param.shape),
            )
            new_param = nn.Parameter(fake_tensor, requires_grad=False)
            replaced_count += 1
            # Avoid `KeyError: attribute 'weight' already exists` for parametrized modules
            # where `weight` is a property and the real parameter lives elsewhere.
            container_params = getattr(container, "_parameters", {})
            if isinstance(container_params, dict) and param_name in container_params:
                container_params[param_name] = new_param
                continue
            if hasattr(container, param_name):
                logger.debug(
                    f"Skipping clearing expert weight {type(proj).__name__}.{param_name}: "
                    "attribute exists but is not a registered parameter."
                )
                continue
            try:
                setattr(container, param_name, new_param)
            except Exception as exc:
                logger.warning(
                    f"Failed to clear expert weight {type(proj).__name__}.{param_name}: {exc}"
                )

    logger.info(f"Replaced {replaced_count} expert weight params")


# =============================================================================
# kt_weight_path Loading Functions
# =============================================================================


@dataclass
class INT8ExpertWeights:
    """Container for INT8 expert weights with scales.

    gate_proj/up_proj are stacked [num_experts, intermediate, hidden] tensors and
    down_proj is [num_experts, hidden, intermediate]; the *_scale fields hold the
    matching per-expert quantization scales (NUMA partitions concatenated).
    """

    gate_proj: torch.Tensor
    gate_scale: torch.Tensor
    up_proj: torch.Tensor
    up_scale: torch.Tensor
    down_proj: torch.Tensor
    down_scale: torch.Tensor


def _find_safetensor_files(kt_weight_path: str) -> list[str]:
    if not os.path.isdir(kt_weight_path):
        raise FileNotFoundError(f"kt_weight_path directory not found: {kt_weight_path}")
    safetensor_files = []
    for file in sorted(os.listdir(kt_weight_path)):
        if file.endswith(".safetensors"):
            safetensor_files.append(os.path.join(kt_weight_path, file))
    if not safetensor_files:
        raise FileNotFoundError(f"No safetensors files found in {kt_weight_path}")
    return safetensor_files


def _load_kt_weight_index(kt_weight_path: str) -> dict[str, str]:
    """Build an index mapping each tensor key to the safetensors shard that contains it."""
    if not SAFETENSORS_AVAILABLE:
        raise ImportError("safetensors is required for loading kt_weight_path")
    index = {}
    safetensor_files = _find_safetensor_files(kt_weight_path)
    for file_path in safetensor_files:
        with safe_open(file_path, framework="pt") as f:
            for key in f.keys():
                index[key] = file_path
    logger.info(f"Indexed {len(index)} tensors from {len(safetensor_files)} safetensors files")
    return index


def _dequant_fp8_experts(
    weights: list[torch.Tensor],
    scales: list[torch.Tensor | None],
    block_size: tuple[int, int],
) -> torch.Tensor:
    """Dequantize a list of FP8 expert weights and stack them (batched, vectorized).

    Args:
        weights: list of [out, in] float8_e4m3fn tensors (one per expert)
        scales: list of [out//bs_m, in//bs_n] scale_inv tensors (one per expert, may be None)
        block_size: (bs_m, bs_n)

    Returns:
        Stacked BF16 tensor of shape [num_experts, out, in]
    """
    has_scales = scales[0] is not None
    if not has_scales:
        return torch.stack(weights, dim=0).to(torch.bfloat16).cpu().contiguous()

    bs_m, bs_n = block_size
    n = len(weights)
    out_features, in_features = weights[0].shape

    # Stack all experts: [N, out, in] fp8 -> reshape to blocks -> bf16
    w = torch.stack(weights, dim=0)  # [N, out, in] fp8
    w = w.reshape(n, out_features // bs_m, bs_m, in_features // bs_n, bs_n)
    w = w.to(torch.bfloat16)

    # Stack all scales: [N, out//bs_m, in//bs_n] -> bf16, broadcast multiply
    s = torch.stack(scales, dim=0).to(torch.bfloat16)  # [N, out//bs_m, in//bs_n]
    w = w * s[:, :, None, :, None]
    return w.reshape(n, out_features, in_features).contiguous()
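

# --- Illustrative sketch (not part of the original module) -------------------
# Block-wise dequantization spelled out on a tiny synthetic example. With
# block_size=(bs_m, bs_n), scale_inv[i, j] multiplies the (bs_m x bs_n) block of
# the weight starting at row i*bs_m, column j*bs_n. This hypothetical check
# compares the vectorized path against a naive per-block loop.
def _example_fp8_block_dequant_check(bs_m: int = 4, bs_n: int = 4) -> bool:
    out_features, in_features = 8, 12  # two row blocks, three column blocks
    w_fp8 = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
    scale = torch.rand(out_features // bs_m, in_features // bs_n) + 0.5

    vectorized = _dequant_fp8_experts([w_fp8], [scale], (bs_m, bs_n))[0]

    naive = torch.empty(out_features, in_features, dtype=torch.bfloat16)
    for bi in range(out_features // bs_m):
        for bj in range(in_features // bs_n):
            block = w_fp8[bi * bs_m:(bi + 1) * bs_m, bj * bs_n:(bj + 1) * bs_n].to(torch.bfloat16)
            naive[bi * bs_m:(bi + 1) * bs_m, bj * bs_n:(bj + 1) * bs_n] = block * scale[bi, bj].to(torch.bfloat16)
    return torch.equal(vectorized, naive)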


def load_experts_from_checkpoint_files(
    checkpoint_files: list[str],
    sharded_metadata: dict | None,
    layers_prefix: str,
    moe_config: MOEArchConfig,
    layer_idx: int,
    block_size: tuple[int, int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Load and stack expert weights for one MoE layer directly from HF checkpoint safetensors.

    Handles both per-expert Linear checkpoints and the fused transformers v5 format
    (``gate_up_proj``/``down_proj``), and dequantizes FP8 block-quantized weights to BF16.
    Returns (gate_proj, up_proj, down_proj), each stacked as [expert_num, out, in].
    """
    if not SAFETENSORS_AVAILABLE:
        raise ImportError("safetensors is required for loading experts from checkpoint files")
    if not checkpoint_files:
        raise FileNotFoundError("checkpoint_files is empty")

    t0 = time.time()
    weight_map = None
    base_dir = os.path.dirname(checkpoint_files[0])
    if sharded_metadata is not None:
        weight_map = sharded_metadata.get("weight_map", None)

    gate_name, up_name, down_name = moe_config.weight_names
    experts_prefix = f"{layers_prefix}.{layer_idx}.{moe_config.moe_layer_attr}.{moe_config.experts_attr}"

    # Fused format (transformers v5): one gate_up_proj/down_proj tensor per layer
    fused_gate_up_key = f"{experts_prefix}.gate_up_proj"
    fused_down_key = f"{experts_prefix}.down_proj"
    is_fused = weight_map is not None and fused_gate_up_key in weight_map

    if is_fused:
        keys = [fused_gate_up_key, fused_down_key]
    else:
        keys = []
        for expert_idx in range(moe_config.expert_num):
            base = f"{experts_prefix}.{expert_idx}"
            keys.append(f"{base}.{gate_name}.weight")
            keys.append(f"{base}.{gate_name}.weight_scale_inv")
            keys.append(f"{base}.{up_name}.weight")
            keys.append(f"{base}.{up_name}.weight_scale_inv")
            keys.append(f"{base}.{down_name}.weight")
            keys.append(f"{base}.{down_name}.weight_scale_inv")

    keys_by_file: dict[str, list[str]] = {}
    mapped_count = 0
    unmapped_count = 0
    for key in keys:
        if weight_map is not None:
            filename = weight_map.get(key)
            if filename is None:
                unmapped_count += 1
                continue
            mapped_count += 1
            file_path = os.path.join(base_dir, filename)
        else:
            file_path = checkpoint_files[0]
        keys_by_file.setdefault(file_path, []).append(key)

    print(
        f"[kt_moe] Layer {layer_idx}: key mapping done in {time.time()-t0:.1f}s — "
        f"total_keys={len(keys)}, mapped={mapped_count}, unmapped={unmapped_count}, "
        f"files_to_open={len(keys_by_file)}",
        flush=True,
    )

    t1 = time.time()
    tensor_map: dict[str, torch.Tensor] = {}
    for file_idx, (file_path, file_keys) in enumerate(keys_by_file.items()):
        with safe_open(file_path, framework="pt") as f:
            available_keys = set(f.keys())
            for key in file_keys:
                if key in available_keys:
                    tensor_map[key] = f.get_tensor(key)
        if file_idx == 0:
            print(
                f"[kt_moe] Layer {layer_idx}: first file loaded ({os.path.basename(file_path)}, "
                f"{len(file_keys)} keys) in {time.time()-t1:.1f}s",
                flush=True,
            )
    print(
        f"[kt_moe] Layer {layer_idx}: all files loaded in {time.time()-t1:.1f}s — "
        f"tensor_map has {len(tensor_map)} tensors",
        flush=True,
    )

    t2 = time.time()
    if is_fused:
        gate_up = tensor_map.get(fused_gate_up_key)
        down = tensor_map.get(fused_down_key)
        if gate_up is None or down is None:
            raise FileNotFoundError(f"Missing fused expert weights for layer {layer_idx}")
        gate_up = gate_up.cpu().to(torch.bfloat16).contiguous()
        # gate_up_proj is [E, 2*I, H]: split into gate [E, I, H] and up [E, I, H]
        I = gate_up.shape[1] // 2
        gate_proj = gate_up[:, :I, :].contiguous()
        up_proj = gate_up[:, I:, :].contiguous()
        down_proj = down.cpu().to(torch.bfloat16).contiguous()
        del gate_up
        print(
            f"[kt_moe] Layer {layer_idx}: fused expert format — "
            f"split gate_up_proj [{gate_proj.shape}] + down [{down_proj.shape}]",
            flush=True,
        )
        print(
            f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, "
            f"shape={gate_proj.shape}, dequant=0.0s, total={time.time()-t0:.1f}s",
            flush=True,
        )
        return gate_proj, up_proj, down_proj

    gate_weights = []
    up_weights = []
    down_weights = []
    gate_scales = []
    up_scales = []
    down_scales = []
    for expert_idx in range(moe_config.expert_num):
        base = f"{experts_prefix}.{expert_idx}"
        gate_key = f"{base}.{gate_name}.weight"
        up_key = f"{base}.{up_name}.weight"
        down_key = f"{base}.{down_name}.weight"
        if gate_key not in tensor_map or up_key not in tensor_map or down_key not in tensor_map:
            raise FileNotFoundError(f"Missing expert weights for layer {layer_idx}, expert {expert_idx}")
        gate_weights.append(tensor_map[gate_key])
        up_weights.append(tensor_map[up_key])
        down_weights.append(tensor_map[down_key])
        gate_scales.append(tensor_map.get(f"{base}.{gate_name}.weight_scale_inv"))
        up_scales.append(tensor_map.get(f"{base}.{up_name}.weight_scale_inv"))
        down_scales.append(tensor_map.get(f"{base}.{down_name}.weight_scale_inv"))

    # Check if weights are FP8 and need dequantization
    t2 = time.time()
    is_fp8 = gate_weights[0].dtype == torch.float8_e4m3fn
    if is_fp8:
        if block_size is None:
            block_size = (128, 128)
        print(
            f"[kt_moe] Layer {layer_idx}: FP8 expert weights detected, "
            f"dequantizing with block_size={block_size} "
            f"(has_scales={gate_scales[0] is not None})",
            flush=True,
        )
        gate_proj = _dequant_fp8_experts(gate_weights, gate_scales, block_size)
        up_proj = _dequant_fp8_experts(up_weights, up_scales, block_size)
        down_proj = _dequant_fp8_experts(down_weights, down_scales, block_size)
    else:
        gate_proj = torch.stack(gate_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
        up_proj = torch.stack(up_weights, dim=0).cpu().to(torch.bfloat16).contiguous()
        down_proj = torch.stack(down_weights, dim=0).cpu().to(torch.bfloat16).contiguous()

    print(
        f"[kt_moe] Layer {layer_idx}: done — dtype={gate_proj.dtype}, shape={gate_proj.shape}, "
        f"dequant={time.time()-t2:.1f}s, total={time.time()-t0:.1f}s",
        flush=True,
    )
    return gate_proj, up_proj, down_proj
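

# --- Illustrative sketch (not part of the original module) -------------------
# One hypothetical way to drive load_experts_from_checkpoint_files() straight
# from a HuggingFace model directory: read model.safetensors.index.json for the
# weight_map and pass the shard paths. The directory layout, helper name, and
# layers_prefix choice are assumptions; the real caller lives in wrapper.py.
def _example_load_layer_from_hf_dir(model_dir: str, moe_config: MOEArchConfig, layer_idx: int):
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    with open(index_path) as f:
        sharded_metadata = json.load(f)  # contains the "weight_map" consulted above
    checkpoint_files = sorted(
        os.path.join(model_dir, name) for name in os.listdir(model_dir) if name.endswith(".safetensors")
    )
    return load_experts_from_checkpoint_files(
        checkpoint_files=checkpoint_files,
        sharded_metadata=sharded_metadata,
        layers_prefix="model.layers",  # Qwen3.5 fused models use "model.language_model.layers"
        moe_config=moe_config,
        layer_idx=layer_idx,
    )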


def load_experts_from_kt_weight_path(
    kt_weight_path: str,
    layer_idx: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
) -> INT8ExpertWeights:
    """Load INT8 preprocessed expert weights from kt_weight_path for a specific layer."""
    if not SAFETENSORS_AVAILABLE:
        raise ImportError("safetensors is required for loading kt_weight_path")

    index = _load_kt_weight_index(kt_weight_path)

    # Detect how many NUMA partitions each expert was split into at conversion time.
    numa_count = 0
    test_key_prefix = f"blk.{layer_idx}.ffn_gate_exps.0.numa."
    for key in index.keys():
        if key.startswith(test_key_prefix) and key.endswith(".weight"):
            numa_idx = int(key.split("numa.")[1].split(".")[0])
            numa_count = max(numa_count, numa_idx + 1)
    if numa_count == 0:
        raise FileNotFoundError(
            f"No weights found for layer {layer_idx} in {kt_weight_path}. "
            f"Expected keys like 'blk.{layer_idx}.ffn_gate_exps.0.numa.0.weight'"
        )
    logger.info(
        f"Loading INT8 weights for layer {layer_idx}: {num_experts} experts, {numa_count} NUMA partitions"
    )

    gate_weights_list = []
    gate_scales_list = []
    up_weights_list = []
    up_scales_list = []
    down_weights_list = []
    down_scales_list = []
    for expert_idx in range(num_experts):
        gate_w_parts = []
        gate_s_parts = []
        for numa_idx in range(numa_count):
            w_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.weight"
            s_key = f"blk.{layer_idx}.ffn_gate_exps.{expert_idx}.numa.{numa_idx}.scale"
            if w_key not in index:
                raise FileNotFoundError(f"Weight key not found: {w_key}")
            with safe_open(index[w_key], framework="pt") as f:
                gate_w_parts.append(f.get_tensor(w_key))
                gate_s_parts.append(f.get_tensor(s_key))
        gate_w = torch.cat(gate_w_parts, dim=0)
        gate_s = torch.cat(gate_s_parts, dim=0)
        gate_w = gate_w.view(intermediate_size, hidden_size)
        gate_weights_list.append(gate_w)
        gate_scales_list.append(gate_s)

        up_w_parts = []
        up_s_parts = []
        for numa_idx in range(numa_count):
            w_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.weight"
            s_key = f"blk.{layer_idx}.ffn_up_exps.{expert_idx}.numa.{numa_idx}.scale"
            if w_key not in index:
                raise FileNotFoundError(f"Weight key not found: {w_key}")
            with safe_open(index[w_key], framework="pt") as f:
                up_w_parts.append(f.get_tensor(w_key))
                up_s_parts.append(f.get_tensor(s_key))
        up_w = torch.cat(up_w_parts, dim=0)
        up_s = torch.cat(up_s_parts, dim=0)
        up_w = up_w.view(intermediate_size, hidden_size)
        up_weights_list.append(up_w)
        up_scales_list.append(up_s)

        down_w_parts = []
        down_s_parts = []
        for numa_idx in range(numa_count):
            w_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.weight"
            s_key = f"blk.{layer_idx}.ffn_down_exps.{expert_idx}.numa.{numa_idx}.scale"
            if w_key not in index:
                raise FileNotFoundError(f"Weight key not found: {w_key}")
            with safe_open(index[w_key], framework="pt") as f:
                down_w_parts.append(f.get_tensor(w_key))
                down_s_parts.append(f.get_tensor(s_key))
        down_w = torch.cat(down_w_parts, dim=0)
        down_s = torch.cat(down_s_parts, dim=0)
        down_w = down_w.view(hidden_size, intermediate_size)
        down_weights_list.append(down_w)
        down_scales_list.append(down_s)

    gate_proj = torch.stack(gate_weights_list, dim=0)
    gate_scale = torch.stack(gate_scales_list, dim=0)
    up_proj = torch.stack(up_weights_list, dim=0)
    up_scale = torch.stack(up_scales_list, dim=0)
    down_proj = torch.stack(down_weights_list, dim=0)
    down_scale = torch.stack(down_scales_list, dim=0)

    return INT8ExpertWeights(
        gate_proj=gate_proj,
        gate_scale=gate_scale,
        up_proj=up_proj,
        up_scale=up_scale,
        down_proj=down_proj,
        down_scale=down_scale,
    )
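

# --- Illustrative sketch (not part of the original module) -------------------
# Loading the INT8 pre-quantized experts for one layer from a converted
# kt_weight_path directory. The helper name and sizes below are placeholders;
# in the real flow they come from the model config, and the result is handed to
# the AMX C++ wrapper rather than consumed directly in Python.
def _example_load_int8_layer(kt_weight_path: str, layer_idx: int = 0) -> INT8ExpertWeights:
    weights = load_experts_from_kt_weight_path(
        kt_weight_path=kt_weight_path,
        layer_idx=layer_idx,
        num_experts=128,         # placeholder: model-dependent
        hidden_size=4096,        # placeholder: model-dependent
        intermediate_size=1536,  # placeholder: model-dependent
    )
    # gate/up are stacked [num_experts, intermediate, hidden]; down is [num_experts, hidden, intermediate]
    assert weights.gate_proj.shape[0] == 128
    return weights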