mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-01 21:21:12 +00:00
update kt-kernel
This commit is contained in:
parent
1a925769d9
commit
f854d03bd7
119 changed files with 4459 additions and 6368 deletions
529
kt-kernel/scripts/compare_weights.py
Normal file
529
kt-kernel/scripts/compare_weights.py
Normal file
|
|
@ -0,0 +1,529 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compare two sets of quantized weights generated by convert_cpu_weights.py
|
||||
|
||||
This script supports comparing:
|
||||
- Two safetensor format weights (merged)
|
||||
- Two .kt format weights (layer folder structure)
|
||||
- One safetensor and one .kt format (cross-format comparison)
|
||||
|
||||
Usage:
|
||||
python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2
|
||||
python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2 --tolerance 1e-5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from typing import Dict, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def unpack_awq_int32_to_int8(packed: np.ndarray, bits: int = 4) -> np.ndarray:
    """Unpack AWQ int32-packed values to an int8 array.

    AWQ uses INT4 quantization: 32 // bits values (8 for the default
    4-bit case) are packed into each 32-bit word, with value i of a word
    occupying bits [i*bits, (i+1)*bits).

    Args:
        packed: Packed array. Non-int32 inputs are byte-reinterpreted
            as int32 first.
        bits: Number of bits per element (default: 4).

    Returns:
        Flat int8 array of ``packed.size * (32 // bits)`` elements in
        packing order. Values are the raw unsigned fields
        (0 .. 2**bits - 1); any zero-point handling is the caller's job.
    """
    if packed.dtype != np.int32:
        # Try to reinterpret raw bytes as int32
        packed = packed.view(np.int32)

    # BUG FIX: the strided assignment below needs a 1-D source; a 2-D
    # `packed` (the usual qweight layout) would fail to broadcast into
    # `unpacked[i::pack_num]`. Flattening preserves element order.
    packed = packed.reshape(-1)

    pack_num = 32 // bits  # 8 for INT4
    mask = (1 << bits) - 1  # 0x0F for 4-bit
    unpacked = np.empty(packed.size * pack_num, dtype=np.int8)

    for i in range(pack_num):
        shift = i * bits
        unpacked[i::pack_num] = ((packed >> shift) & mask).astype(np.int8)

    return unpacked
def normalize_tensor_dtype(tensor: np.ndarray, tensor_name: str, is_awq: bool = False) -> np.ndarray:
    """Coerce a loaded tensor to the dtype its name implies.

    Scale tensors become float32; weight/qzeros tensors become int8
    (unpacking AWQ int32 words when requested). Tensors whose name
    matches neither category are returned untouched.

    Args:
        tensor: Input tensor.
        tensor_name: Tensor name; the substrings "scale", "weight" and
            "qzeros" select which normalization rule applies.
        is_awq: True when int32 weights are AWQ-packed and must be
            unpacked instead of byte-reinterpreted.

    Returns:
        Tensor with normalized dtype (possibly a view of the input).
    """
    if "scale" in tensor_name:
        # Scales are float32 on disk; reinterpret raw bytes if needed.
        return tensor if tensor.dtype == np.float32 else tensor.view(np.float32)

    if "weight" not in tensor_name and "qzeros" not in tensor_name:
        # Unknown tensor kind — leave as-is.
        return tensor

    if is_awq and tensor.dtype == np.int32:
        # AWQ packs eight 4-bit values into each int32 word.
        return unpack_awq_int32_to_int8(tensor)

    if tensor.dtype == np.float32:
        # float32 "weights" come in two flavors:
        #   (a) genuine int8 values stored as floats  -> convert values
        #   (b) int8 bytes misread as float32         -> reinterpret bytes
        if len(tensor) == 0:
            return tensor.astype(np.int8)

        sample = tensor.flat[: min(100, len(tensor))]
        in_range = np.all((sample >= -128) & (sample <= 127))
        integral = np.all(sample == np.round(sample))
        if in_range and integral:
            return tensor.astype(np.int8)  # case (a): value conversion
        return tensor.view(np.int8)  # case (b): 4 bytes -> 4 int8s

    if tensor.dtype == np.int32:
        # Byte-reinterpret: each int32 word yields four int8 elements.
        return tensor.view(np.int8)

    if tensor.dtype == np.int8:
        return tensor

    # Any other dtype: best-effort value conversion.
    return tensor.astype(np.int8)
def load_kt_binary(file_path: str) -> np.ndarray:
    """Load a .kt format binary tensor file.

    The dtype is inferred from the path: files whose name contains
    "scale" hold float32 data; everything else holds int8 data.

    Args:
        file_path: Path to the .kt binary file.

    Returns:
        1-D numpy array over the file's raw bytes.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # File name encodes the payload type: "scale" files are float32,
    # quantized payloads are int8.
    dtype = np.float32 if "scale" in file_path else np.int8

    with open(file_path, "rb") as f:
        return np.frombuffer(f.read(), dtype=dtype)
def detect_weight_format(path: str) -> str:
    """Detect whether a weight directory is safetensor or .kt format.

    Args:
        path: Path to the weight directory.

    Returns:
        "safetensor" if *.safetensors files are present, "kt" if
        _layer_* folders are present, otherwise "unknown".

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Path not found: {path}")

    # Merged safetensor export: *.safetensors files at the top level.
    if glob.glob(os.path.join(path, "*.safetensors")):
        return "safetensor"

    # .kt export: one _layer_<idx> folder per layer.
    if glob.glob(os.path.join(path, "_layer_*")):
        return "kt"

    return "unknown"
def detect_awq_format(weights_sample: Dict[str, np.ndarray]) -> bool:
    """Detect whether a set of weights is in AWQ format.

    AWQ format characteristics:
    - 'qzeros' tensors are present
    - weight tensors are int32 dtype (packed 4-bit values)

    Args:
        weights_sample: Sample of loaded weights keyed by tensor name.

    Returns:
        True if AWQ format is detected.
    """
    # Both markers must be present: qzeros tensors AND int32 weights.
    if not any("qzeros" in name for name in weights_sample):
        return False
    return any(
        "weight" in name and arr.dtype == np.int32
        for name, arr in weights_sample.items()
    )
def load_safetensor_weights(path: str) -> Dict[str, np.ndarray]:
    """Load MoE expert weights from safetensor files.

    Only tensors whose key contains ".ffn_" and "_exps." (the MoE
    expert weights) are loaded; dtypes are then normalized through
    normalize_tensor_dtype.

    Args:
        path: Directory containing *.safetensors files.

    Returns:
        Dict mapping tensor names to dtype-normalized numpy arrays.

    Raises:
        FileNotFoundError: If no safetensor files are found.
    """
    files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    if not files:
        raise FileNotFoundError(f"No safetensor files found in {path}")

    print(f"Loading safetensor files from {path}")

    # Pass 1: pull every MoE expert tensor into memory.
    weights: Dict[str, np.ndarray] = {}
    for file in files:
        with safe_open(file, framework="pt") as f:
            for key in f.keys():
                # Only load MoE expert weights for comparison
                if ".ffn_" in key and "_exps." in key:
                    weights[key] = f.get_tensor(key).cpu().numpy()

    # AWQ detection drives how int32 weights are treated below.
    is_awq = detect_awq_format(weights)
    print(f" Format detected: {'AWQ' if is_awq else 'INT4/INT8'}")

    # Pass 2: normalize dtypes, reporting any tensor that changed.
    print(f" Normalizing dtypes...")
    for key in list(weights.keys()):
        before_dtype = weights[key].dtype
        before_shape = weights[key].shape
        weights[key] = normalize_tensor_dtype(weights[key], key, is_awq=is_awq)
        after = weights[key]
        if after.shape != before_shape or after.dtype != before_dtype:
            print(f" {key}: {before_dtype}{before_shape} -> {after.dtype}{after.shape}")

    print(f" Loaded {len(weights)} tensors from safetensor format")
    return weights
def load_kt_weights(path: str) -> Dict[str, np.ndarray]:
    """Load all weights from .kt format (layer folder structure).

    Expected directory layout: <path>/_layer_<L>/_numa_<N>/*.kt, where
    each file name encodes {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt
    (type is "quant" or "scale").

    Args:
        path: Directory containing _layer_* folders.

    Returns:
        Dict mapping safetensor-style keys
        ("blk.<L>.<proj>.<expert>.numa.<N>.<weight|scale>") to arrays.

    Raises:
        FileNotFoundError: If no _layer_* folders are found.
    """
    # Map short projection names in file names to safetensor naming.
    proj_map = {"down": "ffn_down_exps", "gate": "ffn_gate_exps", "up": "ffn_up_exps"}

    layer_folders = sorted(glob.glob(os.path.join(path, "_layer_*")))
    if not layer_folders:
        raise FileNotFoundError(f"No _layer_* folders found in {path}")

    print(f"Loading .kt files from {path}")

    weights: Dict[str, np.ndarray] = {}
    for layer_folder in layer_folders:
        layer_idx = int(os.path.basename(layer_folder).split("_")[-1])

        for numa_folder in sorted(glob.glob(os.path.join(layer_folder, "_numa_*"))):
            numa_idx = int(os.path.basename(numa_folder).split("_")[-1])

            for kt_file in glob.glob(os.path.join(numa_folder, "*.kt")):
                # Parse metadata out of the file name:
                # {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt
                parts = os.path.basename(kt_file).replace(".kt", "").split("_")
                if len(parts) < 5:
                    # Unrecognized file name layout — skip it.
                    continue

                proj_key = proj_map.get(parts[1], parts[1])
                expert = parts[2]
                # parts[4] is "quant" (weights) or "scale".
                suffix = "weight" if parts[4] == "quant" else "scale"
                key = f"blk.{layer_idx}.{proj_key}.{expert}.numa.{numa_idx}.{suffix}"
                weights[key] = load_kt_binary(kt_file)

    # Normalize dtypes (.kt format is never AWQ-packed).
    print(f" Normalizing dtypes...")
    for key in list(weights.keys()):
        before_dtype = weights[key].dtype
        before_shape = weights[key].shape
        weights[key] = normalize_tensor_dtype(weights[key], key, is_awq=False)
        after = weights[key]
        if after.shape != before_shape or after.dtype != before_dtype:
            print(f" {key}: {before_dtype}{before_shape} -> {after.dtype}{after.shape}")

    print(f" Loaded {len(weights)} tensors from .kt format")
    return weights
def normalize_key(key: str) -> Tuple[int, str, int, str]:
    """Split a tensor key into (layer, projection, expert, type).

    Handles both the merged safetensor naming
    ("blk.0.ffn_up_exps.5.weight") and the per-NUMA .kt naming
    ("blk.0.ffn_up_exps.5.numa.0.weight").

    Args:
        key: Tensor key string.

    Returns:
        Tuple of (layer_idx, proj_name, expert_idx, tensor_type).
    """
    fields = key.split(".")
    # With the per-NUMA layout the type sits after ".numa.<idx>";
    # otherwise it immediately follows the expert index.
    type_pos = 6 if "numa" in key else 4
    return (int(fields[1]), fields[2], int(fields[3]), fields[type_pos])
def compare_weights(
    weights1: Dict[str, np.ndarray], weights2: Dict[str, np.ndarray], tolerance: float = 1e-6
) -> Tuple[bool, Dict[str, Dict]]:
    """Compare two sets of weights tensor-by-tensor.

    Keys are grouped by their normalized (layer, proj, expert, type)
    form so that per-NUMA shards from the .kt layout are concatenated
    and compared against the corresponding merged tensor.

    Args:
        weights1: First set of weights.
        weights2: Second set of weights.
        tolerance: Numerical tolerance (used as both atol and rtol in
            np.allclose).

    Returns:
        Tuple of (all_match, differences) where differences maps each
        mismatching base key to a dict describing the mismatch.
    """
    print("\n" + "=" * 80)
    print("WEIGHT COMPARISON")
    print("=" * 80)

    # Group keys by normalized form (ignoring numa index)
    def group_by_base_key(weights):
        groups = defaultdict(list)
        for key in weights.keys():
            try:
                layer, proj, expert, ttype = normalize_key(key)
                base_key = f"blk.{layer}.{proj}.{expert}.{ttype}"
                groups[base_key].append(key)
            # BUG FIX: this was a bare `except:`, which also swallowed
            # SystemExit/KeyboardInterrupt. Only parsing failures from
            # unexpected key layouts should be skipped.
            except (IndexError, ValueError):
                pass
        return groups

    groups1 = group_by_base_key(weights1)
    groups2 = group_by_base_key(weights2)

    all_base_keys = sorted(set(groups1.keys()) | set(groups2.keys()))

    all_match = True
    differences = {}

    total_comparisons = 0
    matching_comparisons = 0

    for base_key in all_base_keys:
        keys1 = groups1.get(base_key, [])
        keys2 = groups2.get(base_key, [])

        if not keys1:
            print(f"❌ Missing in weights1: {base_key}")
            differences[base_key] = {"status": "missing_in_weights1"}
            all_match = False
            continue

        if not keys2:
            print(f"❌ Missing in weights2: {base_key}")
            differences[base_key] = {"status": "missing_in_weights2"}
            all_match = False
            continue

        # For kt format, we may have multiple keys (one per NUMA node);
        # concatenate the shards (in sorted key order) for comparison.
        if len(keys1) > 1 or len(keys2) > 1:
            tensor1 = np.concatenate([weights1[k] for k in sorted(keys1)])
            tensor2 = np.concatenate([weights2[k] for k in sorted(keys2)])
        else:
            tensor1 = weights1[keys1[0]]
            tensor2 = weights2[keys2[0]]

        total_comparisons += 1

        # Debug: dtypes should already agree after load-time normalization.
        if tensor1.dtype != tensor2.dtype:
            print(f"⚠️ Dtype mismatch for {base_key}: {tensor1.dtype} vs {tensor2.dtype}")
            print(f" This should have been normalized. Shape: {tensor1.shape} vs {tensor2.shape}")

        # Compare shapes
        if tensor1.shape != tensor2.shape:
            print(f"❌ Shape mismatch for {base_key}:")
            print(f" Shape1: {tensor1.shape} (dtype: {tensor1.dtype})")
            print(f" Shape2: {tensor2.shape} (dtype: {tensor2.dtype})")
            differences[base_key] = {
                "status": "shape_mismatch",
                "shape1": tensor1.shape,
                "shape2": tensor2.shape,
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare dtypes (should be consistent after normalization)
        if tensor1.dtype != tensor2.dtype:
            print(f"❌ Dtype mismatch for {base_key} after normalization:")
            print(f" Dtype1: {tensor1.dtype}")
            print(f" Dtype2: {tensor2.dtype}")
            differences[base_key] = {
                "status": "dtype_mismatch",
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare values
        if np.allclose(tensor1, tensor2, atol=tolerance, rtol=tolerance):
            matching_comparisons += 1
        else:
            # BUG FIX: subtracting int8 arrays wraps on overflow
            # (e.g. 127 - (-128)), corrupting the reported statistics.
            # Widen to float64 before computing the diff.
            abs_diff = np.abs(tensor1.astype(np.float64) - tensor2.astype(np.float64))
            max_diff = np.max(abs_diff)
            mean_diff = np.mean(abs_diff)

            print(f"❌ Value mismatch for {base_key}:")
            print(f" Max difference: {max_diff:.2e}")
            print(f" Mean difference: {mean_diff:.2e}")
            print(f" Tolerance: {tolerance:.2e}")

            differences[base_key] = {
                "status": "value_mismatch",
                "max_diff": float(max_diff),
                "mean_diff": float(mean_diff),
                "tolerance": tolerance,
            }
            all_match = False

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"Total comparisons: {total_comparisons}")
    print(f"Matching: {matching_comparisons}")
    print(f"Mismatching: {total_comparisons - matching_comparisons}")
    # differences holds missing keys plus all compared-but-mismatching
    # keys, so subtracting the latter leaves the missing-tensor count.
    print(f"Missing tensors: {len(differences) - (total_comparisons - matching_comparisons)}")

    if all_match:
        print("\n✅ All weights match!")
    else:
        print(f"\n❌ Found {len(differences)} differences")

    return all_match, differences
def main():
    """CLI entry point: compare the weights under --path1 and --path2.

    Returns:
        0 when all weights match, 1 on error or any mismatch.
    """
    parser = argparse.ArgumentParser(description="Compare two sets of quantized weights")
    parser.add_argument("--path1", type=str, required=True, help="Path to first weight directory")
    parser.add_argument("--path2", type=str, required=True, help="Path to second weight directory")
    parser.add_argument(
        "--tolerance", type=float, default=1e-6, help="Numerical tolerance for comparison (default: 1e-6)"
    )
    args = parser.parse_args()

    # Fail fast with a clear message when either directory is missing.
    for label, p in (("Path1", args.path1), ("Path2", args.path2)):
        if not os.path.exists(p):
            print(f"Error: {label} does not exist: {p}")
            return 1

    # Detect formats
    print("Detecting weight formats...")
    format1 = detect_weight_format(args.path1)
    format2 = detect_weight_format(args.path2)
    print(f"Path1 format: {format1}")
    print(f"Path2 format: {format2}")

    for p, fmt in ((args.path1, format1), (args.path2, format2)):
        if fmt == "unknown":
            print(f"Error: Unable to detect weight format in {p}")
            return 1

    # Load each side with the loader matching its detected format.
    print("\nLoading weights...")
    loaders = {"safetensor": load_safetensor_weights, "kt": load_kt_weights}
    weights1 = loaders[format1](args.path1)
    weights2 = loaders[format2](args.path2)

    # Compare weights
    all_match, differences = compare_weights(weights1, weights2, args.tolerance)
    return 0 if all_match else 1
if __name__ == "__main__":
    # Propagate main()'s status code as the process exit code.
    # BUG FIX: use `raise SystemExit` instead of the interactive-helper
    # `exit()` builtin, which is injected by the `site` module and may be
    # absent (e.g. when Python runs with -S or in embedded interpreters).
    raise SystemExit(main())
Loading…
Add table
Add a link
Reference in a new issue