#!/usr/bin/env python3
"""
Compare two sets of quantized weights generated by convert_cpu_weights.py

This script supports comparing:
- Two safetensor format weights (merged)
- Two .kt format weights (layer folder structure)
- One safetensor and one .kt format (cross-format comparison)

Usage:
    python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2
    python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2 --tolerance 1e-5
"""
import argparse
import os
import glob
import numpy as np
import torch
from typing import Dict, Tuple
from collections import defaultdict


def unpack_awq_int32_to_int8(packed: np.ndarray, bits: int = 4) -> np.ndarray:
    """Unpack AWQ int32 packed format to int8

    AWQ uses INT4 quantization: 8 x 4-bit values packed into 1 x 32-bit integer

    Args:
        packed: Packed int32 array (any shape; flattened before unpacking)
        bits: Number of bits per element (default: 4)

    Returns:
        Unpacked 1-D int8 array with ``32 // bits`` elements per packed int32
    """
    if packed.dtype != np.int32:
        # Try to reinterpret the raw bytes as int32
        packed = packed.view(np.int32)

    pack_num = 32 // bits  # 8 for INT4
    # Flatten so the strided assignment below is valid for any input shape,
    # not only 1-D arrays (multi-dim input previously raised on assignment).
    flat = packed.ravel()
    unpacked = np.empty(flat.size * pack_num, dtype=np.int8)

    mask = (1 << bits) - 1  # 0x0F for 4-bit; loop-invariant, hoisted
    for i in range(pack_num):
        shift = i * bits
        unpacked[i::pack_num] = ((flat >> shift) & mask).astype(np.int8)

    return unpacked


def normalize_tensor_dtype(tensor: np.ndarray, tensor_name: str, is_awq: bool = False) -> np.ndarray:
    """Normalize tensor to consistent dtype based on tensor type

    Args:
        tensor: Input tensor
        tensor_name: Name of the tensor (used to determine type)
        is_awq: Whether this is AWQ format (requires unpacking)

    Returns:
        Normalized tensor with consistent dtype (float32 for scales,
        int8 for weights/qzeros; unknown tensor types pass through as-is)
    """
    # Determine tensor type from name
    is_scale = "scale" in tensor_name
    is_weight = "weight" in tensor_name
    is_qzeros = "qzeros" in tensor_name

    if is_scale:
        # Scale should be float32
        if tensor.dtype != np.float32:
            # Try to reinterpret bytes as float32
            tensor = tensor.view(np.float32)
        return tensor
    elif is_weight or is_qzeros:
        # Weight/qzeros should be int8
        if is_awq and tensor.dtype == np.int32:
            # AWQ format: unpack int32 to int8
            tensor = unpack_awq_int32_to_int8(tensor)
        elif tensor.dtype == np.float32:
            # Two cases for float32:
            # Case 1: Values look like int8 values (e.g., [37., 73., -70.])
            #         -> use astype to convert values
            # Case 2: Values are large scientific notation (e.g., [2.6e34, ...])
            #         -> use view to reinterpret bytes
            # Check if values are in int8 range (-128 to 127).
            # Use .size (not len()) so 0-d arrays don't raise TypeError.
            if tensor.size > 0:
                sample_size = min(100, tensor.size)
                sample_values = tensor.flat[:sample_size]
                # If most values are in int8 range and have no decimal parts
                in_int8_range = np.all((sample_values >= -128) & (sample_values <= 127))
                is_integer_valued = np.all(sample_values == np.round(sample_values))
                if in_int8_range and is_integer_valued:
                    # Case 1: Direct value conversion
                    tensor = tensor.astype(np.int8)
                else:
                    # Case 2: Byte reinterpretation (4 bytes -> 4 int8s)
                    tensor = tensor.view(np.int8)
            else:
                tensor = tensor.astype(np.int8)
        elif tensor.dtype == np.int32:
            # Reinterpret int32 as int8 (4x more elements)
            tensor = tensor.view(np.int8)
        elif tensor.dtype != np.int8:
            # Other types: try to convert
            tensor = tensor.astype(np.int8)
        return tensor
    else:
        # Unknown type, return as-is
        return tensor


def load_kt_binary(file_path: str) -> np.ndarray:
    """Load .kt format binary tensor file

    Args:
        file_path: Path to .kt binary file

    Returns:
        numpy array with the loaded tensor (float32 for scale files,
        int8 otherwise)

    Raises:
        FileNotFoundError: if file_path does not exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as f:
        binary_data = f.read()

    # Determine dtype from the file NAME only: matching against the full
    # path would misclassify every file under a directory whose name
    # happens to contain "scale".
    if "scale" in os.path.basename(file_path):
        dtype = np.float32
    else:
        dtype = np.int8

    return np.frombuffer(binary_data, dtype=dtype)


def detect_weight_format(path: str) -> str:
    """Detect if weights are in safetensor or .kt format

    Args:
        path: Path to weight directory

    Returns:
        'safetensor' or 'kt' or 'unknown'

    Raises:
        FileNotFoundError: if path does not exist
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Path not found: {path}")

    # Check for safetensor files
    safetensor_files = glob.glob(os.path.join(path, "*.safetensors"))
    if safetensor_files:
        return "safetensor"

    # Check for layer folder structure
    layer_folders = glob.glob(os.path.join(path, "_layer_*"))
    if layer_folders:
        return "kt"

    return "unknown"


def detect_awq_format(weights_sample: Dict[str, np.ndarray]) -> bool:
    """Detect if weights are in AWQ format

    AWQ format characteristics:
    - Has 'qzeros' tensors
    - Weight tensors are int32 dtype (packed)

    Args:
        weights_sample: Sample of loaded weights

    Returns:
        True if AWQ format detected
    """
    has_qzeros = any("qzeros" in key for key in weights_sample.keys())
    if not has_qzeros:
        return False

    # Check if weight tensors are int32
    for key, tensor in weights_sample.items():
        if "weight" in key and tensor.dtype == np.int32:
            return True

    return False


def load_safetensor_weights(path: str) -> Dict[str, np.ndarray]:
    """Load all weights from safetensor format

    Args:
        path: Path to directory containing safetensor files

    Returns:
        Dictionary mapping tensor names to numpy arrays (dtype normalized)

    Raises:
        FileNotFoundError: if no .safetensors files exist under path
    """
    # Lazy import: only this code path needs safetensors, so .kt-only
    # comparisons work without the package installed.
    from safetensors import safe_open

    weights = {}

    safetensor_files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    if not safetensor_files:
        raise FileNotFoundError(f"No safetensor files found in {path}")

    print(f"Loading safetensor files from {path}")

    # First pass: load all tensors
    for file in safetensor_files:
        with safe_open(file, framework="pt") as f:
            for key in f.keys():
                # Only load MoE expert weights for comparison
                if ".ffn_" in key and "_exps." in key:
                    tensor = f.get_tensor(key)
                    weights[key] = tensor.cpu().numpy()

    # Detect AWQ format
    is_awq = detect_awq_format(weights)
    print(f" Format detected: {'AWQ' if is_awq else 'INT4/INT8'}")

    # Second pass: normalize dtypes
    print(f" Normalizing dtypes...")
    for key in list(weights.keys()):
        original_dtype = weights[key].dtype
        original_shape = weights[key].shape
        weights[key] = normalize_tensor_dtype(weights[key], key, is_awq=is_awq)
        if weights[key].shape != original_shape or weights[key].dtype != original_dtype:
            print(f" {key}: {original_dtype}{original_shape} -> {weights[key].dtype}{weights[key].shape}")

    print(f" Loaded {len(weights)} tensors from safetensor format")
    return weights


def load_kt_weights(path: str) -> Dict[str, np.ndarray]:
    """Load all weights from .kt format (layer folder structure)

    Expected layout: path/_layer_<L>/_numa_<N>/<files>.kt where each
    filename is {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt

    Args:
        path: Path to directory containing _layer_* folders

    Returns:
        Dictionary mapping tensor names to numpy arrays

    Raises:
        FileNotFoundError: if no _layer_* folders exist under path
    """
    weights = {}

    layer_folders = sorted(glob.glob(os.path.join(path, "_layer_*")))
    if not layer_folders:
        raise FileNotFoundError(f"No _layer_* folders found in {path}")

    print(f"Loading .kt files from {path}")

    for layer_folder in layer_folders:
        # Extract layer index from folder name
        layer_idx = int(os.path.basename(layer_folder).split("_")[-1])

        # Find all NUMA folders
        numa_folders = sorted(glob.glob(os.path.join(layer_folder, "_numa_*")))

        for numa_folder in numa_folders:
            # Extract NUMA index
            numa_idx = int(os.path.basename(numa_folder).split("_")[-1])

            # Find all .kt files
            kt_files = glob.glob(os.path.join(numa_folder, "*.kt"))

            for kt_file in kt_files:
                filename = os.path.basename(kt_file)

                # Parse filename to extract metadata
                # Format: {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt
                parts = filename.replace(".kt", "").split("_")
                if len(parts) >= 5:
                    method = parts[0]  # INT4, INT8, etc.
                    proj = parts[1]  # down, gate, up
                    expert = parts[2]  # expert ID
                    tensor_type = parts[4]  # quant or scale

                    # Map proj names
                    proj_map = {"down": "ffn_down_exps", "gate": "ffn_gate_exps", "up": "ffn_up_exps"}
                    proj_key = proj_map.get(proj, proj)

                    # Build key matching safetensor format
                    if tensor_type == "quant":
                        key = f"blk.{layer_idx}.{proj_key}.{expert}.numa.{numa_idx}.weight"
                    else:  # scale
                        key = f"blk.{layer_idx}.{proj_key}.{expert}.numa.{numa_idx}.scale"

                    # Load tensor
                    weights[key] = load_kt_binary(kt_file)

    # Normalize dtypes (.kt format is never AWQ)
    print(f" Normalizing dtypes...")
    for key in list(weights.keys()):
        original_dtype = weights[key].dtype
        original_shape = weights[key].shape
        weights[key] = normalize_tensor_dtype(weights[key], key, is_awq=False)
        if weights[key].shape != original_shape or weights[key].dtype != original_dtype:
            print(f" {key}: {original_dtype}{original_shape} -> {weights[key].dtype}{weights[key].shape}")

    print(f" Loaded {len(weights)} tensors from .kt format")
    return weights


def normalize_key(key: str) -> Tuple[int, str, int, str]:
    """Normalize tensor key to extract layer, projection, expert, and type

    Args:
        key: Tensor key like "blk.0.ffn_up_exps.5.weight" or
             "blk.0.ffn_up_exps.5.numa.0.weight"

    Returns:
        Tuple of (layer_idx, proj_name, expert_idx, tensor_type)

    Raises:
        ValueError/IndexError: if key does not match either expected format
    """
    parts = key.split(".")
    layer_idx = int(parts[1])
    proj_name = parts[2]
    expert_idx = int(parts[3])

    # Handle both formats: with and without numa
    if "numa" in key:
        tensor_type = parts[6]  # weight or scale
    else:
        tensor_type = parts[4]  # weight, scale, or qzeros

    return (layer_idx, proj_name, expert_idx, tensor_type)


def compare_weights(
    weights1: Dict[str, np.ndarray],
    weights2: Dict[str, np.ndarray],
    tolerance: float = 1e-6,
) -> Tuple[bool, Dict[str, Dict]]:
    """Compare two sets of weights

    Tensors are grouped by normalized key (ignoring the NUMA index); per-NUMA
    shards are concatenated before comparison so .kt and safetensor layouts
    are comparable.

    Args:
        weights1: First set of weights
        weights2: Second set of weights
        tolerance: Numerical tolerance for comparison

    Returns:
        Tuple of (all_match, differences_dict)
    """
    print("\n" + "=" * 80)
    print("WEIGHT COMPARISON")
    print("=" * 80)

    # Group keys by normalized form (ignoring numa index)
    def group_by_base_key(weights):
        groups = defaultdict(list)
        for key in weights.keys():
            try:
                layer, proj, expert, ttype = normalize_key(key)
                base_key = f"blk.{layer}.{proj}.{expert}.{ttype}"
                groups[base_key].append(key)
            except (ValueError, IndexError):
                # Skip keys that don't match expected format.
                # (Narrowed from a bare except, which also swallowed
                # KeyboardInterrupt/SystemExit.)
                pass
        return groups

    groups1 = group_by_base_key(weights1)
    groups2 = group_by_base_key(weights2)

    all_base_keys = sorted(set(groups1.keys()) | set(groups2.keys()))

    all_match = True
    differences = {}
    total_comparisons = 0
    matching_comparisons = 0

    for base_key in all_base_keys:
        keys1 = groups1.get(base_key, [])
        keys2 = groups2.get(base_key, [])

        if not keys1:
            print(f"❌ Missing in weights1: {base_key}")
            differences[base_key] = {"status": "missing_in_weights1"}
            all_match = False
            continue

        if not keys2:
            print(f"❌ Missing in weights2: {base_key}")
            differences[base_key] = {"status": "missing_in_weights2"}
            all_match = False
            continue

        # For kt format, we may have multiple keys (one per NUMA node)
        # We need to concatenate them for comparison
        if len(keys1) > 1 or len(keys2) > 1:
            # Concatenate tensors from all NUMA nodes
            tensor1 = np.concatenate([weights1[k] for k in sorted(keys1)])
            tensor2 = np.concatenate([weights2[k] for k in sorted(keys2)])
        else:
            tensor1 = weights1[keys1[0]]
            tensor2 = weights2[keys2[0]]

        total_comparisons += 1

        # Debug: print dtype and shape info
        if tensor1.dtype != tensor2.dtype:
            print(f"⚠️ Dtype mismatch for {base_key}: {tensor1.dtype} vs {tensor2.dtype}")
            print(f" This should have been normalized. Shape: {tensor1.shape} vs {tensor2.shape}")

        # Compare shapes
        if tensor1.shape != tensor2.shape:
            print(f"❌ Shape mismatch for {base_key}:")
            print(f" Shape1: {tensor1.shape} (dtype: {tensor1.dtype})")
            print(f" Shape2: {tensor2.shape} (dtype: {tensor2.dtype})")
            differences[base_key] = {
                "status": "shape_mismatch",
                "shape1": tensor1.shape,
                "shape2": tensor2.shape,
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare dtypes (should be consistent after normalization)
        if tensor1.dtype != tensor2.dtype:
            print(f"❌ Dtype mismatch for {base_key} after normalization:")
            print(f" Dtype1: {tensor1.dtype}")
            print(f" Dtype2: {tensor2.dtype}")
            differences[base_key] = {
                "status": "dtype_mismatch",
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare values
        if np.allclose(tensor1, tensor2, atol=tolerance, rtol=tolerance):
            matching_comparisons += 1
        else:
            # Compute the difference in float64: subtracting int8 arrays
            # directly wraps around (e.g. 127 - (-128) -> -1), which made
            # the reported max/mean diff wrong for quantized weights.
            diff = np.abs(tensor1.astype(np.float64) - tensor2.astype(np.float64))
            max_diff = np.max(diff)
            mean_diff = np.mean(diff)
            print(f"❌ Value mismatch for {base_key}:")
            print(f" Max difference: {max_diff:.2e}")
            print(f" Mean difference: {mean_diff:.2e}")
            print(f" Tolerance: {tolerance:.2e}")
            differences[base_key] = {
                "status": "value_mismatch",
                "max_diff": float(max_diff),
                "mean_diff": float(mean_diff),
                "tolerance": tolerance,
            }
            all_match = False

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"Total comparisons: {total_comparisons}")
    print(f"Matching: {matching_comparisons}")
    print(f"Mismatching: {total_comparisons - matching_comparisons}")
    print(f"Missing tensors: {len(differences) - (total_comparisons - matching_comparisons)}")

    if all_match:
        print("\n✅ All weights match!")
    else:
        print(f"\n❌ Found {len(differences)} differences")

    return all_match, differences


def main():
    """CLI entry point: parse args, load both weight sets, compare.

    Returns:
        0 if all weights match, 1 otherwise (used as process exit code).
    """
    parser = argparse.ArgumentParser(description="Compare two sets of quantized weights")
    parser.add_argument("--path1", type=str, required=True, help="Path to first weight directory")
    parser.add_argument("--path2", type=str, required=True, help="Path to second weight directory")
    parser.add_argument(
        "--tolerance", type=float, default=1e-6, help="Numerical tolerance for comparison (default: 1e-6)"
    )
    args = parser.parse_args()

    # Validate paths
    if not os.path.exists(args.path1):
        print(f"Error: Path1 does not exist: {args.path1}")
        return 1
    if not os.path.exists(args.path2):
        print(f"Error: Path2 does not exist: {args.path2}")
        return 1

    # Detect formats
    print("Detecting weight formats...")
    format1 = detect_weight_format(args.path1)
    format2 = detect_weight_format(args.path2)
    print(f"Path1 format: {format1}")
    print(f"Path2 format: {format2}")

    if format1 == "unknown":
        print(f"Error: Unable to detect weight format in {args.path1}")
        return 1
    if format2 == "unknown":
        print(f"Error: Unable to detect weight format in {args.path2}")
        return 1

    # Load weights based on format
    print("\nLoading weights...")
    if format1 == "safetensor":
        weights1 = load_safetensor_weights(args.path1)
    else:
        weights1 = load_kt_weights(args.path1)

    if format2 == "safetensor":
        weights2 = load_safetensor_weights(args.path2)
    else:
        weights2 = load_kt_weights(args.path2)

    # Compare weights
    all_match, differences = compare_weights(weights1, weights2, args.tolerance)

    return 0 if all_match else 1


if __name__ == "__main__":
    exit(main())