diff --git a/kt-kernel/bench/bench_k2_moe_amx.py b/kt-kernel/bench/bench_k2_moe_amx.py new file mode 100644 index 00000000..50f5837e --- /dev/null +++ b/kt-kernel/bench/bench_k2_moe_amx.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python +# coding=utf-8 +""" +Benchmark AMX_K2_MOE_TP int4 path with packed weights and BF16 scales. +""" +import json +import math +import os +import platform +import subprocess +import sys +import time + +from tqdm import tqdm + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) + +import kt_kernel_ext +import torch + +# Benchmark parameters (single MoE, no layer loop) +expert_num = 384 +hidden_size = 7168 +intermediate_size = 2048 +max_len = 25600 +num_experts_per_tok = 8 +qlen = 1 +warm_up_iter = 1000 +test_iter = 5000 +k_group_size = 32 + +physical_to_logical_map = ( + torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() +) + +worker_config = kt_kernel_ext.WorkerPoolConfig() +worker_config.subpool_count = 2 +worker_config.subpool_numa_map = [0, 1] +worker_config.subpool_thread_count = [40, 40] +CPUInfer = kt_kernel_ext.CPUInfer(worker_config) + + +def get_git_commit(): + result = {} + try: + commit = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]) + .decode("utf-8") + .strip() + ) + commit_msg = ( + subprocess.check_output(["git", "log", "-1", "--pretty=%B"]) + .decode("utf-8") + .strip() + ) + result["commit"] = commit + result["commit_message"] = commit_msg + + dirty_output = ( + subprocess.check_output(["git", "status", "--porcelain"]) + .decode("utf-8") + .strip() + ) + if dirty_output: + result["dirty"] = True + result["dirty_files"] = dirty_output.splitlines() + else: + result["dirty"] = False + except Exception as e: + result["commit"] = None + result["commit_message"] = None + result["dirty"] = None + result["error"] = str(e) + return result + + +def get_system_info(): + info = {} + uname = platform.uname() + info["system_name"] = uname.system + info["node_name"] = uname.node + + cpu_model = None + if os.path.exists("/proc/cpuinfo"): + try: + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "model name" in line: + cpu_model = line.split(":", 1)[1].strip() + break + except Exception as e: + cpu_model = f"Error: {e}" + info["cpu_model"] = cpu_model + + mem_total_gb = None + if os.path.exists("/proc/meminfo"): + try: + with open("/proc/meminfo", "r") as f: + for line in f: + if "MemTotal" in line: + mem_kb = float(line.split(":", 1)[1].split()[0]) + mem_total_gb = round(mem_kb / (1024 * 1024), 2) + break + except Exception as e: + mem_total_gb = f"Error: {e}" + info["memory_size_GB"] = mem_total_gb + + info["cpu_core_count"] = os.cpu_count() + + sockets = set() + if os.path.exists("/proc/cpuinfo"): + try: + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "physical id" in line: + sockets.add(line.split(":", 1)[1].strip()) + except Exception: + sockets = set() + info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1 + + return info + + +script_path = os.path.abspath(__file__) +script_dir = os.path.dirname(script_path) +script_name = os.path.splitext(os.path.basename(script_path))[0] +json_path = os.path.join(script_dir, script_name + ".jsonl") + + +def record_results(result, filename=json_path): + with open(filename, "a") as f: + f.write(json.dumps(result) + "\n") + + +def pack_to_int32( + value: torch.Tensor, num_bits: int, packed_dim: int = 1 +) -> torch.Tensor: + if value.dtype is not torch.int8: + raise ValueError("Tensor must be torch.int8 before packing") + if 
not (1 <= num_bits <= 8): + raise ValueError(f"num_bits must be in [1, 8], got {num_bits}") + + offset = 1 << (num_bits - 1) + value = (value + offset).to(torch.uint8) + device = value.device + + pack_factor = 32 // num_bits + + if packed_dim == 0: + value = value.transpose(0, 1) + + rows, cols = value.shape + padded_cols = math.ceil(cols / pack_factor) * pack_factor + pad_len = padded_cols - cols + + if pad_len > 0: + value = torch.nn.functional.pad(value, (0, pad_len)) + + num_groups = padded_cols // pack_factor + reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32) + bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits + packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32) + + if packed_dim == 0: + packed = packed.transpose(0, 1) + + return packed + + +def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor: + e, rows, cols = q.shape + flat = q.view(e * rows, cols) + packed = pack_to_int32(flat, num_bits) + return packed.view(e, rows, -1).contiguous() + + +def quantize_k2_tensor(weights: torch.Tensor, group_size: int): + """ + K2 int4 quantization producing int32-packed weights (8 int4s each) and BF16 scales. + """ + weights_f32 = weights.to(torch.float32) + e, rows, cols = weights_f32.shape + if cols % group_size != 0 or cols % 2 != 0: + raise ValueError( + f"cols ({cols}) must be divisible by group_size ({group_size}) and 2" + ) + + reshaped = weights_f32.view(e, rows, cols // group_size, group_size) + max_abs = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + scales = (max_abs / 7.0).squeeze(-1) + q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8) + q = q.view(e, rows, cols) + packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous() + scales = scales.to(torch.bfloat16).contiguous().view( + e, rows, cols // group_size + ).contiguous() + return packed, scales + + +def build_quantized_layer_weights(): + gate_proj = torch.randn( + (expert_num, intermediate_size, hidden_size), + dtype=torch.float32, + device="cpu", + ).contiguous() + up_proj = torch.randn( + (expert_num, intermediate_size, hidden_size), + dtype=torch.float32, + device="cpu", + ).contiguous() + down_proj = torch.randn( + (expert_num, hidden_size, intermediate_size), + dtype=torch.float32, + device="cpu", + ).contiguous() + + gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size) + up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size) + down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size) + + return { + "gate_qweight": gate_q, + "up_qweight": up_q, + "down_qweight": down_q, + "gate_scales": gate_scales, + "up_scales": up_scales, + "down_scales": down_scales, + } + + +def bench_k2_moe(): + with torch.inference_mode(): + bytes_per_elem = 0.5 + 2.0 / k_group_size + + quant_data = build_quantized_layer_weights() + config = kt_kernel_ext.moe.MOEConfig( + expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0 + ) + config.max_len = max_len + config.quant_config.bits = 4 + config.quant_config.group_size = k_group_size + config.quant_config.zero_point = False + + config.gate_proj = quant_data["gate_qweight"].data_ptr() + config.up_proj = quant_data["up_qweight"].data_ptr() + config.down_proj = quant_data["down_qweight"].data_ptr() + + config.gate_scale = quant_data["gate_scales"].data_ptr() + config.up_scale = quant_data["up_scales"].data_ptr() + config.down_scale = quant_data["down_scales"].data_ptr() + config.pool = CPUInfer.backend_ + + moe = 
kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + + gen_iter = 3000 + expert_ids = ( + torch.rand(gen_iter * qlen, expert_num, device="cpu") + .argsort(dim=-1)[:, :num_experts_per_tok] + .reshape(gen_iter, qlen * num_experts_per_tok) + .contiguous() + ) + weights = torch.rand( + (gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu" + ).contiguous() + input_tensor = torch.randn( + (qlen, hidden_size), dtype=torch.bfloat16, device="cpu" + ).contiguous() + output_tensor = torch.empty_like(input_tensor) + bsz_tensor = torch.tensor([qlen], device="cpu") + + for i in tqdm(range(warm_up_iter), desc="Warm-up"): + CPUInfer.submit( + moe.forward_task( + bsz_tensor.data_ptr(), + num_experts_per_tok, + expert_ids[i % gen_iter].data_ptr(), + weights[i % gen_iter].data_ptr(), + input_tensor.data_ptr(), + output_tensor.data_ptr(), + False, + ) + ) + CPUInfer.sync() + + start = time.perf_counter() + for i in tqdm(range(test_iter), desc="Testing"): + CPUInfer.submit( + moe.forward_task( + bsz_tensor.data_ptr(), + num_experts_per_tok, + expert_ids[i % gen_iter].data_ptr(), + weights[i % gen_iter].data_ptr(), + input_tensor.data_ptr(), + output_tensor.data_ptr(), + False, + ) + ) + CPUInfer.sync() + end = time.perf_counter() + total_time = end - start + + time_per_iter_us = total_time / test_iter * 1e6 + bandwidth = ( + hidden_size + * intermediate_size + * 3 + * num_experts_per_tok + * (1 / 8 * 256 * (1 - (31 / 32) ** qlen)) + * bytes_per_elem + * test_iter + / total_time + / 1e9 + ) + flops = ( + hidden_size + * intermediate_size + * qlen + * 3 + * num_experts_per_tok + * 2 + * test_iter + / total_time + / 1e12 + ) + + print("Quant mode: int4_k2") + print("Time(s): ", total_time) + print("Iteration: ", test_iter) + print("Time(us) per iteration: ", time_per_iter_us) + print("Bandwidth: ", bandwidth, "GB/s") + print("Flops: ", flops, "TFLOPS") + print("") + + result = { + "quant_mode": "int4_k2", + "total_time_seconds": total_time, + "iterations": test_iter, + "time_per_iteration_us": time_per_iter_us, + "bandwidth_GBs": bandwidth, + "flops_TFLOPS": flops, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + "test_parameters": { + "expert_num": expert_num, + "hidden_size": hidden_size, + "intermediate_size": intermediate_size, + "max_len": max_len, + "num_experts_per_tok": num_experts_per_tok, + "qlen": qlen, + "warm_up_iter": warm_up_iter, + "test_iter": test_iter, + "k_group_size": k_group_size, + "bytes_per_elem": bytes_per_elem, + }, + } + result.update(get_git_commit()) + result.update(get_system_info()) + record_results(result) + + +if __name__ == "__main__": + bench_k2_moe() diff --git a/kt-kernel/bench/bench_k2_write_buffer.py b/kt-kernel/bench/bench_k2_write_buffer.py new file mode 100644 index 00000000..940a0247 --- /dev/null +++ b/kt-kernel/bench/bench_k2_write_buffer.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python +# coding=utf-8 +""" +Benchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales). 
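+Each call copies the packed int4 weights and bf16 scales of all GPU experts into per-TP staging buffers; the reported bandwidth uses the same byte accounting as examples/test_k2_write_buffer.py.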
+""" +import json +import os +import platform +import subprocess +import sys +import time + +from tqdm import tqdm + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) + +import kt_kernel_ext +import torch + +# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py) +expert_num = 384 +num_experts_per_tok = expert_num +gpu_tp_count = 4 + +warm_up_iter = 3 +test_iter = 7 + +gpu_experts_num = expert_num + +hidden_size = 7168 +intermediate_size = 2048 +group_size = 32 +max_len = 1 + +physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous() +CPUInfer = kt_kernel_ext.CPUInfer(96) + + +def get_git_commit(): + result = {} + try: + commit = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() + ) + commit_msg = ( + subprocess.check_output(["git", "log", "-1", "--pretty=%B"]) + .decode("utf-8") + .strip() + ) + result["commit"] = commit + result["commit_message"] = commit_msg + + dirty_output = ( + subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip() + ) + if dirty_output: + result["dirty"] = True + result["dirty_files"] = dirty_output.splitlines() + else: + result["dirty"] = False + except Exception as e: + result["commit"] = None + result["commit_message"] = None + result["dirty"] = None + result["error"] = str(e) + return result + + +def get_system_info(): + info = {} + uname = platform.uname() + info["system_name"] = uname.system + info["node_name"] = uname.node + + cpu_model = None + if os.path.exists("/proc/cpuinfo"): + try: + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "model name" in line: + cpu_model = line.split(":", 1)[1].strip() + break + except Exception as e: + cpu_model = f"Error: {e}" + info["cpu_model"] = cpu_model + + mem_total_gb = None + if os.path.exists("/proc/meminfo"): + try: + with open("/proc/meminfo", "r") as f: + for line in f: + if "MemTotal" in line: + mem_kb = float(line.split(":", 1)[1].split()[0]) + mem_total_gb = round(mem_kb / (1024 * 1024), 2) + break + except Exception as e: + mem_total_gb = f"Error: {e}" + info["memory_size_GB"] = mem_total_gb + + info["cpu_core_count"] = os.cpu_count() + + sockets = set() + if os.path.exists("/proc/cpuinfo"): + try: + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "physical id" in line: + sockets.add(line.split(":", 1)[1].strip()) + except Exception: + sockets = set() + info["cpu_socket_count"] = len(sockets) if len(sockets) > 0 else 1 + + return info + + +script_path = os.path.abspath(__file__) +script_dir = os.path.dirname(script_path) +script_name = os.path.splitext(os.path.basename(script_path))[0] +json_path = os.path.join(script_dir, script_name + ".jsonl") + + +def record_results(result, filename=json_path): + with open(filename, "a") as f: + f.write(json.dumps(result) + "\n") + + +def allocate_weights(): + per_mat_weight_bytes = (hidden_size * intermediate_size) // 2 + per_mat_scale_elems = (hidden_size * intermediate_size) // group_size + + gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + + gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16) + up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16) + down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16) + + 
return ( + gate_q.contiguous(), + up_q.contiguous(), + down_q.contiguous(), + gate_scale.contiguous(), + up_scale.contiguous(), + down_scale.contiguous(), + per_mat_weight_bytes, + per_mat_scale_elems, + ) + + +def build_moe(): + ( + gate_q, + up_q, + down_q, + gate_scale, + up_scale, + down_scale, + per_mat_weight_bytes, + per_mat_scale_elems, + ) = allocate_weights() + + config = kt_kernel_ext.moe.MOEConfig( + expert_num, num_experts_per_tok, hidden_size, intermediate_size + ) + config.max_len = max_len + config.quant_config.bits = 4 + config.quant_config.group_size = group_size + config.quant_config.zero_point = False + config.pool = CPUInfer.backend_ + + config.gate_proj = gate_q.data_ptr() + config.up_proj = up_q.data_ptr() + config.down_proj = down_q.data_ptr() + config.gate_scale = gate_scale.data_ptr() + config.up_scale = up_scale.data_ptr() + config.down_scale = down_scale.data_ptr() + + moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + + # Buffer sizing per TP + weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count + scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count + total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp + total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp + + w13_weight_bufs = [ + torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count) + ] + w13_scale_bufs = [ + torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count) + ] + w2_weight_bufs = [ + torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count) + ] + w2_scale_bufs = [ + torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count) + ] + + buffer_ptrs = { + "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs], + "w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs], + "w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs], + "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs], + } + + buffer_shapes = { + "per_mat_weight_bytes": per_mat_weight_bytes, + "per_mat_scale_elems": per_mat_scale_elems, + "weight_bytes_per_expert_per_tp": weight_bytes_per_expert_per_tp, + "scale_elems_per_expert_per_tp": scale_elems_per_expert_per_tp, + "total_weight_bytes_per_tp": total_weight_bytes_per_tp, + "total_scale_elems_per_tp": total_scale_elems_per_tp, + } + + keep_tensors = { + "gate_q": gate_q, + "up_q": up_q, + "down_q": down_q, + "gate_scale": gate_scale, + "up_scale": up_scale, + "down_scale": down_scale, + "w13_weight_bufs": w13_weight_bufs, + "w13_scale_bufs": w13_scale_bufs, + "w2_weight_bufs": w2_weight_bufs, + "w2_scale_bufs": w2_scale_bufs, + } + + return moe, buffer_ptrs, buffer_shapes, keep_tensors + + +def bench_write_buffer(): + moe, buffer_ptrs, buffer_shapes, keep_tensors = build_moe() + + total_weights = hidden_size * intermediate_size * expert_num * 3 + # Throughput accounting consistent with examples/test_k2_write_buffer.py + bytes_per_call = total_weights // group_size + total_weights // 2 + + # Warm-up + for _ in tqdm(range(warm_up_iter), desc="Warm-up"): + CPUInfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + gpu_experts_num=gpu_experts_num, + **buffer_ptrs, + ) + ) + CPUInfer.sync() + + total_time = 0 + for _ in tqdm(range(test_iter), desc="Testing"): + start = time.perf_counter() + CPUInfer.submit( + 
moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + gpu_experts_num=gpu_experts_num, + **buffer_ptrs, + ) + ) + CPUInfer.sync() + end = time.perf_counter() + total_time += end - start + time.sleep(0.6) + print(end - start) + + + + time_per_iter_us = total_time / test_iter * 1e6 + bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9 + + print("write_weight_scale_to_buffer benchmark") + print("Time(s): ", total_time) + print("Iteration: ", test_iter) + print("Time(us) per iteration: ", time_per_iter_us) + print("Bandwidth: ", bandwidth_gbs, "GB/s") + print("") + + result = { + "op": "write_weight_scale_to_buffer", + "total_time_seconds": total_time, + "iterations": test_iter, + "time_per_iteration_us": time_per_iter_us, + "bandwidth_GBs": bandwidth_gbs, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + "test_parameters": { + "expert_num": expert_num, + "hidden_size": hidden_size, + "intermediate_size": intermediate_size, + "group_size": group_size, + "max_len": max_len, + "num_experts_per_tok": num_experts_per_tok, + "gpu_tp_count": gpu_tp_count, + "gpu_experts_num": gpu_experts_num, + "warm_up_iter": warm_up_iter, + "test_iter": test_iter, + "bytes_per_call": bytes_per_call, + }, + "buffer_shapes": buffer_shapes, + "keep_tensors_alive": list(keep_tensors.keys()), + } + result.update(get_git_commit()) + result.update(get_system_info()) + record_results(result) + + +if __name__ == "__main__": + bench_write_buffer() diff --git a/kt-kernel/examples/test_k2_moe_amx.py b/kt-kernel/examples/test_k2_moe_amx.py new file mode 100644 index 00000000..903f9896 --- /dev/null +++ b/kt-kernel/examples/test_k2_moe_amx.py @@ -0,0 +1,319 @@ +import math +import os +import sys +from typing import Dict, Literal + +sys.path.insert(0, os.path.dirname(__file__) + "/../build") + +import torch +import kt_kernel_ext + +torch.manual_seed(42) + +hidden_size = 7168 +intermediate_size = 2048 +max_len = 25600 + +expert_num = 16 +num_experts_per_tok = 8 + +qlen = 1 +layer_num = 1 +CPUInfer = kt_kernel_ext.CPUInfer(40) +validation_iter = 10 +k_group_size = 32 +debug_print_count = 16 + +physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() + + +def _pattern_uniform(groups: int) -> torch.Tensor: + return torch.full((groups,), 0.02, dtype=torch.float32) + + +def _pattern_alternating(groups: int) -> torch.Tensor: + vals = torch.full((groups,), 0.015, dtype=torch.float32) + vals[1::2] = 0.03 + return vals + + +def _pattern_ramp(groups: int) -> torch.Tensor: + return torch.linspace(0.005, 0.04, steps=groups, dtype=torch.float32) + + +WEIGHT_PATTERNS = { + "uniform_scale": ("All k-groups share the same abs max / scale", _pattern_uniform), + "alternating_scale": ("Alternate small / large abs max per k-group", _pattern_alternating), + "ramp_scale": ("Linearly increasing abs max per k-group", _pattern_ramp), + "random": ("Random bf16 weights (baseline)", None), +} + + +def act_fn(x): + return x / (1.0 + torch.exp(-x)) + + +def mlp_torch(input, gate_proj, up_proj, down_proj): + gate_buf = torch.mm(input, gate_proj.t()) + up_buf = torch.mm(input, up_proj.t()) + print(f"gate_buf: {gate_buf}") + print(f"up_buf: {up_buf}") + intermediate = act_fn(gate_buf) * up_buf + ret = torch.mm(intermediate, down_proj.t()) + print(f"intermediate: {intermediate}") + print(f"mlp output: {ret}") + return ret + + +def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): + cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num)) + 
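# one-hot mark of the experts selected for each token; used below to count tokens per expert and to group token copies by expert id +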
cnts.scatter_(1, expert_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = expert_ids.view(-1).argsort() + sorted_tokens = input[idxs // expert_ids.shape[1]] + + outputs = [] + start_idx = 0 + + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + t_output = ( + new_x.view(*expert_ids.shape, -1) + .type(weights.dtype) + .mul_(weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return t_output + + +def pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor: + if value.dtype is not torch.int8: + raise ValueError("Tensor must be torch.int8 before packing") + if not (1 <= num_bits <= 8): + raise ValueError(f"num_bits must be in [1, 8], got {num_bits}") + + offset = 1 << (num_bits - 1) + value = (value + offset).to(torch.uint8) + device = value.device + + pack_factor = 32 // num_bits + + if packed_dim == 0: + value = value.transpose(0, 1) + + rows, cols = value.shape + padded_cols = math.ceil(cols / pack_factor) * pack_factor + pad_len = padded_cols - cols + + if pad_len > 0: + value = torch.nn.functional.pad(value, (0, pad_len)) + + num_groups = padded_cols // pack_factor + + # Use int32 here + reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32) + bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits + packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32) + + if packed_dim == 0: + packed = packed.transpose(0, 1) + + return packed + +def pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor: + e, rows, cols = q.shape + flat = q.view(e * rows, cols) + packed = pack_to_int32(flat, num_bits) + return packed.view(e, rows, -1).contiguous() + + +def quantize_k2_tensor(weights: torch.Tensor, group_size: int): + """ + Symmetric max-abs/7 quantization per k-group following compressed_tensors packing. 
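+ Each k-group of group_size weights shares one scale s = max(|w|) / 7; values are rounded, clamped to [-8, 7], and packed eight per int32.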
+ Args: + weights: [expert_num, rows (N), cols (K)] + Returns: + packed_q: int32 tensor storing 8 int4s per element with shape [expert_num, rows * (cols // 8)] + scales: bfloat16 tensor with shape [expert_num, rows * (cols // group_size)] + """ + weights_f32 = weights.to(torch.float32) + e, rows, cols = weights_f32.shape + if cols % group_size != 0 or cols % 2 != 0: + raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2") + + reshaped = weights_f32.view(e, rows, cols // group_size, group_size) + max_abs = reshaped.abs().amax(dim=-1, keepdim=True) + max_abs = torch.clamp(max_abs, min=1e-8) + scales = (max_abs / 7.0).squeeze(-1) + q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8) + q = q.view(e, rows, cols) + packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous() + scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous() + + print(f"Quantized weights: {packed.shape}, scales: {scales.shape}") + print(f"Quantized tensors: \n{packed},\n {scales}") + return packed, scales + + +def build_structured_tensor(shape: torch.Size, pattern: str) -> torch.Tensor: + if pattern == "random": + torch.manual_seed(42) + return (torch.randn(shape, dtype=torch.bfloat16, device="cpu") / 100.0).contiguous() + + e, rows, cols = shape + groups = cols // k_group_size + group_builder = WEIGHT_PATTERNS[pattern][1] + group_vals = group_builder(groups).to(torch.float32) + block = group_vals.view(1, 1, groups, 1).expand(e, rows, groups, k_group_size).clone() + row_signs = torch.where( + (torch.arange(rows) % 2 == 0), + torch.ones(rows, dtype=torch.float32), + -torch.ones(rows, dtype=torch.float32), + ).view(1, rows, 1, 1) + col_offsets = torch.linspace(-0.0005, 0.0005, steps=k_group_size, dtype=torch.float32).view(1, 1, 1, k_group_size) + block = block * row_signs + col_offsets + return block.reshape(shape).to(torch.bfloat16).contiguous() + + +def prepare_k2_quantized_weights(pattern: str) -> Dict[str, torch.Tensor]: + if pattern not in WEIGHT_PATTERNS: + raise ValueError(f"Unknown weight pattern: {pattern}") + + gate_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern) + up_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern) + down_proj = build_structured_tensor((expert_num, hidden_size, intermediate_size), pattern) + + gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size) + up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size) + down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size) + + return { + "gate_qweight": gate_q.contiguous(), + "up_qweight": up_q.contiguous(), + "down_qweight": down_q.contiguous(), + "gate_scales": gate_scales.contiguous(), + "up_scales": up_scales.contiguous(), + "down_scales": down_scales.contiguous(), + "original_fp16": { + "gate_proj": gate_proj.to(torch.float16).contiguous(), + "up_proj": up_proj.to(torch.float16).contiguous(), + "down_proj": down_proj.to(torch.float16).contiguous(), + }, + } + + +def build_moes_from_quantized_data(quant_data: Dict[str, torch.Tensor]): + moes = [] + with torch.inference_mode(mode=True): + for _ in range(layer_num): + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) + config.max_len = max_len + config.quant_config.bits = 4 + config.quant_config.group_size = k_group_size + config.quant_config.zero_point = False + + config.gate_proj = quant_data["gate_qweight"].data_ptr() + config.up_proj 
= quant_data["up_qweight"].data_ptr() + config.down_proj = quant_data["down_qweight"].data_ptr() + + config.gate_scale = quant_data["gate_scales"].data_ptr() + config.up_scale = quant_data["up_scales"].data_ptr() + config.down_scale = quant_data["down_scales"].data_ptr() + config.pool = CPUInfer.backend_ + + moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + # CPUInfer.submit(moe.warm_up_task()) + # CPUInfer.sync() + moes.append(moe) + return moes + + +def run_case(pattern: str) -> Dict[str, float]: + print("\n" + "=" * 70) + desc = WEIGHT_PATTERNS[pattern][0] + print(f"Running case: {pattern} -> {desc}") + print("=" * 70) + + quant_data = prepare_k2_quantized_weights(pattern) + moes = build_moes_from_quantized_data(quant_data) + + original_weights = quant_data["original_fp16"] + gate_fp16 = original_weights["gate_proj"] + up_fp16 = original_weights["up_proj"] + down_fp16 = original_weights["down_proj"] + + diffs = [] + with torch.inference_mode(mode=True): + for i in range(validation_iter): + torch.manual_seed(100 + i) + bsz_tensor = torch.tensor([qlen], device="cpu") + expert_ids = torch.stack( + [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)] + ).contiguous() + weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() + input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100 + output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous() + + moe = moes[i % layer_num] + CPUInfer.submit( + moe.forward_task( + bsz_tensor.data_ptr(), + num_experts_per_tok, + expert_ids.data_ptr(), + weights.data_ptr(), + input_tensor.data_ptr(), + output.data_ptr(), + False, + ) + ) + CPUInfer.sync() + + input_tensor_fp16 = input_tensor.to(torch.float16) + t_output = moe_torch( + input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16 + ).to(torch.bfloat16) + + t_output = t_output.flatten() + output = output.flatten() + + diff = torch.mean(torch.abs(output - t_output)) / (torch.mean(torch.abs(t_output)) + 1e-12) + diffs.append(diff.item()) + print(f"[{pattern}] Iteration {i}: relative L1 diff = {diff:.4f}") + print(f" output {output}") + print(f" t_output {t_output}") + + mean_diff = float(sum(diffs) / len(diffs)) + max_diff = float(max(diffs)) + min_diff = float(min(diffs)) + return {"case": pattern, "description": desc, "mean": mean_diff, "max": max_diff, "min": min_diff} + + +def run_k2_moe_test(): + summary_rows = [] + for case_name in WEIGHT_PATTERNS.keys(): + results = run_case(case_name) + summary_rows.append(results) + # break + + print("\n=== Case vs. 
Relative Error Summary ===") + print(f"{'Case':<20} {'Mean':>10} {'Max':>10} {'Min':>10}") + for row in summary_rows: + print(f"{row['case']:<20} {row['mean']*100:9.2f}% {row['max']*100:9.2f}% {row['min']*100:9.2f}%") + + +if __name__ == "__main__": + run_k2_moe_test() diff --git a/kt-kernel/examples/test_k2_write_buffer.py b/kt-kernel/examples/test_k2_write_buffer.py new file mode 100644 index 00000000..210a4a3d --- /dev/null +++ b/kt-kernel/examples/test_k2_write_buffer.py @@ -0,0 +1,267 @@ +import os +import sys +import time + +import torch +import numpy as np + + +# Ensure we can import the local extension +# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) +# if REPO_ROOT not in sys.path: +# sys.path.insert(0, REPO_ROOT) + +import kt_kernel_ext +from kt_kernel_ext import CPUInfer + + +def make_cpu_infer(thread_num=80): + return CPUInfer(thread_num) + + +def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size): + cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size) + cfg.max_len = 1 + cfg.quant_config.bits = 4 + cfg.quant_config.group_size = group_size + cfg.quant_config.zero_point = False + cfg.pool = cpuinfer.backend_ + return cfg + + +def allocate_weights(expert_num, hidden_size, intermediate_size, group_size): + # packed int4 weights: 2 values per byte + per_mat_weight_bytes = (hidden_size * intermediate_size) // 2 + per_mat_scale_elems = (hidden_size * intermediate_size) // group_size + + gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + + gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16) + up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16) + down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16) + + return ( + gate_q, + up_q, + down_q, + gate_scale, + up_scale, + down_scale, + per_mat_weight_bytes, + per_mat_scale_elems, + ) + + +def main(): + torch.manual_seed(123) + + expert_num = 256 # Total experts + gpu_experts = expert_num # Number of experts on GPU + gpu_tp_count = 2 # Number of TP parts + + num_experts_per_tok = 8 + hidden_size = 7168 + intermediate_size = 2048 + group_size = 32 + + cpuinfer = make_cpu_infer() + cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size) + + ( + gate_q, + up_q, + down_q, + gate_scale, + up_scale, + down_scale, + per_mat_weight_bytes, + per_mat_scale_elems, + ) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size) + + cfg.gate_proj = gate_q.data_ptr() + cfg.up_proj = up_q.data_ptr() + cfg.down_proj = down_q.data_ptr() + cfg.gate_scale = gate_scale.data_ptr() + cfg.up_scale = up_scale.data_ptr() + cfg.down_scale = down_scale.data_ptr() + + moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg) + + physical_to_logical_map = ( + torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous() + ) + cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + cpuinfer.sync() + + # TP configuration + + # Since weights are col-major, we can directly divide the total size by tp_count + # Each matrix is divided into gpu_tp_count parts in memory order + + # Calculate sizes per TP part (direct division since col-major) + weight_bytes_per_expert_per_tp = 
per_mat_weight_bytes // gpu_tp_count + scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count + + # Total sizes for all gpu_experts + total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp + total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp + + # Create buffer lists for w13 (gate+up) and w2 (down) + w13_weight_bufs = [] + w13_scale_bufs = [] + w2_weight_bufs = [] + w2_scale_bufs = [] + + for tp_idx in range(gpu_tp_count): + # w13 combines gate and up, so needs 2x the size + w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8)) + w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16)) + w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8)) + w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16)) + + # Get data pointers for all buffers + w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs] + w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs] + w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs] + w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs] + + print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}") + print(f"GPU TP count: {gpu_tp_count}") + print(f"Original per matrix weight bytes: {per_mat_weight_bytes}") + print(f"Original per matrix scale elements: {per_mat_scale_elems}") + print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}") + print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}") + print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}") + print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}") + print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}") + print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}") + + for i in range(5): + cpuinfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + gpu_experts_num=gpu_experts, + w13_weight_ptrs=w13_weight_ptrs, + w13_scale_ptrs=w13_scale_ptrs, + w2_weight_ptrs=w2_weight_ptrs, + w2_scale_ptrs=w2_scale_ptrs, + ) + ) + cpuinfer.sync() + + begin_time = time.perf_counter_ns() + cpuinfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + gpu_experts_num=gpu_experts, + w13_weight_ptrs=w13_weight_ptrs, + w13_scale_ptrs=w13_scale_ptrs, + w2_weight_ptrs=w2_weight_ptrs, + w2_scale_ptrs=w2_scale_ptrs, + ) + ) + cpuinfer.sync() + end_time = time.perf_counter_ns() + elapsed_ms = (end_time - begin_time) / 1000000 + total_weights = hidden_size * intermediate_size * expert_num * 3 + total_bytes = total_weights // group_size + total_weights // 2 + print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms") + print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s") + def split_expert_tensor(tensor, chunk): + """Split tensor by experts""" + return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)] + + # Split by experts first + gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes) + up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes) + down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes) + + gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems) + up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems) + down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems) + + # CPU TP count is always 2 in this test setup (one TP per NUMA node) + cpu_tp_count = 2 + + # Verify 
buffers for each TP part + for tp_idx in range(gpu_tp_count): + expected_w13_weights = [] + expected_w13_scales = [] + expected_w2_weights = [] + expected_w2_scales = [] + + weight13_per_tp = per_mat_weight_bytes // gpu_tp_count + scale13_per_tp = per_mat_scale_elems // gpu_tp_count + # Process each GPU expert + for expert_idx in range(gpu_experts): + # For w13 (gate and up), the slicing is straightforward + + start_weight = tp_idx * weight13_per_tp + end_weight = (tp_idx + 1) * weight13_per_tp + start_scale = tp_idx * scale13_per_tp + end_scale = (tp_idx + 1) * scale13_per_tp + + # Gate + gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight] + gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale] + + # Up + up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight] + up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale] + + # Down matrix needs special handling because it's sliced column-wise + # We need to reconstruct it from column slices + down_weight_tp_parts = [] + down_scale_tp_parts = [] + + # Iterate through each column to extract the corresponding parts + for col_idx in range(hidden_size): + col_weight_start = col_idx * (intermediate_size // 2) + col_scale_start = col_idx * (intermediate_size // group_size) + + # Direct mapping: each CPU TP corresponds to a GPU TP + tp_slice_weight_size = (intermediate_size // gpu_tp_count) // 2 + tp_slice_scale_size = (intermediate_size // gpu_tp_count) // group_size + + tp_weight_offset = col_weight_start + tp_idx * tp_slice_weight_size + tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size + + down_weight_tp_parts.append( + down_q_experts[expert_idx][tp_weight_offset:tp_weight_offset + tp_slice_weight_size] + ) + down_scale_tp_parts.append( + down_scale_experts[expert_idx][tp_scale_offset:tp_scale_offset + tp_slice_scale_size] + ) + + # Concatenate all column slices for this TP + down_weight_tp = torch.cat(down_weight_tp_parts) + down_scale_tp = torch.cat(down_scale_tp_parts) + + expected_w13_weights.append(gate_weight_tp) + expected_w13_weights.append(up_weight_tp) + expected_w13_scales.append(gate_scale_tp) + expected_w13_scales.append(up_scale_tp) + expected_w2_weights.append(down_weight_tp) + expected_w2_scales.append(down_scale_tp) + + # Concatenate all experts for this TP part + expected_w13_weight = torch.cat(expected_w13_weights) + expected_w13_scale = torch.cat(expected_w13_scales) + expected_w2_weight = torch.cat(expected_w2_weights) + expected_w2_scale = torch.cat(expected_w2_scales) + + print(f"=== Checking TP part {tp_idx} ===") + + # Assert all checks pass + assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}" + assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}" + assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}" + assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}" + + print(f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts") + + +if __name__ == "__main__": + main() diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp index 687005af..cd7727a5 100644 --- a/kt-kernel/ext_bindings.cpp +++ b/kt-kernel/ext_bindings.cpp @@ -36,6 +36,7 @@ static const bool _is_plain_ = false; #if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL) #include 
"operators/amx/awq-moe.hpp" +#include "operators/amx/k2-moe.hpp" #include "operators/amx/la/amx_kernels.hpp" #include "operators/amx/moe.hpp" #endif @@ -43,6 +44,7 @@ static const bool _is_plain_ = false; #include #include +#include #include "operators/kvcache/kvcache.h" #include "operators/llamafile/linear.h" @@ -225,7 +227,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) { using MoeClass = TP_MOE; using MoeBindings = MOEBindings; - py::class_>(moe_module, name) + auto moe_cls = py::class_>(moe_module, name); + + moe_cls .def(py::init()) .def("warm_up_task", &MoeBindings::WarmUpBindings::cpuinfer_interface) .def("load_weights_task", @@ -244,6 +248,53 @@ void bind_moe_module(py::module_& moe_module, const char* name) { .def("warm_up", &MoeClass::warm_up) .def("load_weights", &MoeClass::load_weights) .def("forward", &MoeClass::forward_binding); + +#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL) + if constexpr (std::is_same_v>) { + struct WriteWeightScaleToBufferBindings { + struct Args { + CPUInfer* cpuinfer; + MoeClass* moe; + int gpu_tp_count; + int gpu_experts_num; + std::vector w13_weight_ptrs; + std::vector w13_scale_ptrs; + std::vector w2_weight_ptrs; + std::vector w2_scale_ptrs; + }; + + static void inner(void* args) { + Args* args_ = (Args*)args; + args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, + args_->gpu_tp_count, args_->gpu_experts_num, + args_->w13_weight_ptrs, args_->w13_scale_ptrs, + args_->w2_weight_ptrs, args_->w2_scale_ptrs); + } + + static std::pair cpuinfer_interface(std::shared_ptr moe, + int gpu_tp_count, int gpu_experts_num, + py::list w13_weight_ptrs, py::list w13_scale_ptrs, + py::list w2_weight_ptrs, py::list w2_scale_ptrs) { + // Convert Python lists to std::vector + std::vector w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec; + + for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast(item)); + for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast(item)); + for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast(item)); + for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast(item)); + + Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num, + w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec}; + return std::make_pair((intptr_t)&inner, (intptr_t)args); + } + }; + + moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface, + py::arg("gpu_tp_count"), py::arg("gpu_experts_num"), + py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"), + py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs")); + } +#endif } PYBIND11_MODULE(kt_kernel_ext, m) { @@ -513,6 +564,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) { bind_moe_module>(moe_module, "AMXInt4_MOE"); bind_moe_module>(moe_module, "AMXInt4_1_MOE"); bind_moe_module>(moe_module, "AMXInt4_1KGroup_MOE"); + bind_moe_module>(moe_module, "AMXInt4_KGroup_MOE"); #endif #if defined(USE_MOE_KERNEL) bind_moe_module>(moe_module, "Int8_KERNEL_MOE"); diff --git a/kt-kernel/operators/amx/k2-moe.hpp b/kt-kernel/operators/amx/k2-moe.hpp new file mode 100644 index 00000000..c6f49244 --- /dev/null +++ b/kt-kernel/operators/amx/k2-moe.hpp @@ -0,0 +1,929 @@ +/** + * @Description : Skeleton for K2 AMX MoE operator. + * @Author : Codex + * @Date : 2024-07-22 + * @Version : 0.1.0 + * @LastEditors : Codex + * @LastEditTime : 2024-07-22 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
+ **/ +#ifndef CPUINFER_OPERATOR_AMX_K2_MOE_H +#define CPUINFER_OPERATOR_AMX_K2_MOE_H + +// #define DEBUG_K2_MOE + +#include +#include +#include +// #define FORWARD_TIME_PROFILE +// #define FORWARD_TIME_REPORT + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../cpu_backend/shared_mem_buffer.h" +#include "../../cpu_backend/worker_pool.h" +#include "../common.hpp" +#include "../moe-tp.hpp" +#include "la/amx.hpp" +#include "llama.cpp/ggml.h" + +template +class AMX_K2_MOE_TP { + private: + int tp_part_idx = 0; + + void* gate_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] + void* up_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] + void* down_proj_ = nullptr; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] + + ggml_bf16_t* m_local_input_ = nullptr; // [num_experts_per_tok * max_len * hidden_size] + ggml_bf16_t* m_local_gate_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size] + ggml_bf16_t* m_local_up_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size] + ggml_bf16_t* m_local_down_output_ = nullptr; // [num_experts_per_tok * max_len * hidden_size] + + std::vector> m_local_pos_; // [max_len, num_experts_per_tok] + std::vector m_local_num_; // [expert_num] + std::vector m_expert_id_map_; // [expert_num] + std::vector m_local_input_ptr_; // [expert_num] + std::vector m_local_gate_output_ptr_; // [expert_num] + std::vector m_local_up_output_ptr_; // [expert_num] + std::vector m_local_down_output_ptr_; // [expert_num] + + std::vector> gate_up_ba_; + std::vector> gate_bb_; + std::vector> gate_bc_; + std::vector> up_bb_; + std::vector> up_bc_; + std::vector> down_ba_; + std::vector> down_bb_; + std::vector> down_bc_; +#ifdef CHECK + char verify_bb[100000000]; + char check_bb[100000000]; + uint8_t compare_expers = 3; +#endif + +#ifdef CHECK + inline void load_check() { + // TODO: implement load_check for verification. + } + + void verify_load_right() { + // TODO: implement verification helpers. 
+ } +#endif + + inline void dump_buffer_b(const std::string &quantization_type, int expert_idx, const std::string &matrix_type, + typename T::BufferB *buffer) { + auto &quant_config = config_.quant_config; + int &group_size = quant_config.group_size; + + printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx, + matrix_type.c_str()); + + // Calculate dimensions based on matrix type + int rows, cols, num_groups; + size_t scale_elem_count; + if (matrix_type == "gate" || matrix_type == "up") { + rows = config_.intermediate_size; + cols = config_.hidden_size; + num_groups = cols / group_size; + scale_elem_count = num_groups * rows; + } else { // down + rows = config_.hidden_size; + cols = config_.intermediate_size; + num_groups = cols / group_size; + scale_elem_count = num_groups * rows; + } + + // Dump scales (as float) + printf(" Scales[first 16]: "); + for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) { + printf("%.6f ", buffer->d[i]); + } + printf("\n"); + + if (scale_elem_count > 16) { + printf(" Scales[last 16]: "); + int start_idx = std::max(0, (int)scale_elem_count - 16); + for (int i = start_idx; i < (int)scale_elem_count; i++) { + printf("%.6f ", buffer->d[i]); + } + printf("\n"); + } + // Dump quantized weights (as hex uint8) + size_t weight_size = (rows * cols) / 2; // INT4 packed + uint8_t *weight_ptr = (uint8_t *)buffer->b; + + printf(" Weights[first 32 bytes]: "); + for (int i = 0; i < std::min(32, (int)weight_size); i++) { + printf("%02x ", weight_ptr[i]); + } + printf("\n"); + + if (weight_size > 32) { + printf(" Weights[last 32 bytes]: "); + int start_idx = std::max(32, (int)weight_size - 32); + for (int i = start_idx; i < (int)weight_size; i++) { + printf("%02x ", weight_ptr[i]); + } + printf("\n"); + } + + printf(" Matrix dimensions: %dx%d, Groups: %d, Group size: %d, Scale elements: %zu\n", rows, cols, num_groups, + group_size, scale_elem_count); + printf("\n"); + fflush(stdout); + } + +#ifdef FORWARD_TIME_REPORT + std::chrono::time_point last_now; +#endif + + public: + using input_t = ggml_bf16_t; + using output_t = float; + GeneralMOEConfig config_; + static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE; + + AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_) { + auto& quant_config = config.quant_config; + int& group_size = quant_config.group_size; + if (quant_config.group_size == 0 || quant_config.zero_point) { + throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4"); + } + printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu())); + auto& load = config.load; + auto& save = config.save; + if (load && config.path == "") { + load = false; + } + + this->tp_part_idx = tp_part_idx_; + config_ = config; + gate_proj_ = config_.gate_proj; + up_proj_ = config_.up_proj; + down_proj_ = config_.down_proj; + + MemoryRequest mem_requests; + mem_requests.append_pointer( + &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size); + mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * + config_.max_len * config_.intermediate_size); + mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * + config_.max_len * config_.intermediate_size); + mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * + config_.max_len * config_.hidden_size); + + m_local_pos_.resize(config_.max_len); + for (int i 
= 0; i < config_.max_len; i++) { + m_local_pos_[i].resize(config_.num_experts_per_tok); + } + m_expert_id_map_.resize(config_.expert_num); + m_local_num_.resize(config_.expert_num); + m_local_input_ptr_.resize(config_.expert_num); + m_local_gate_output_ptr_.resize(config_.expert_num); + m_local_up_output_ptr_.resize(config_.expert_num); + m_local_down_output_ptr_.resize(config_.expert_num); + + for (size_t i = 0; i < config_.expert_num; i++) { + gate_up_ba_.push_back( + std::make_shared(config_.max_len, config_.hidden_size, group_size, nullptr)); + gate_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); + up_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); + down_ba_.push_back( + std::make_shared(config_.max_len, config_.intermediate_size, group_size, nullptr)); + down_bc_.push_back(std::make_shared(config_.max_len, config_.hidden_size, nullptr)); + + void* gate_bb_ptr = + std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size)); + gate_bb_.push_back(std::make_shared(config_.intermediate_size, config_.hidden_size, + group_size, gate_bb_ptr)); + + void* up_bb_ptr = + std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size)); + up_bb_.push_back( + std::make_shared(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr)); + + void* down_bb_ptr = + std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size)); + down_bb_.push_back(std::make_shared(config_.hidden_size, config_.intermediate_size, + group_size, down_bb_ptr)); + } + for (int i = 0; i < config_.expert_num; i++) { + mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); }, + T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size)); + mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); }, + T::BufferC::required_size(config_.max_len, config_.intermediate_size)); + mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); }, + T::BufferC::required_size(config_.max_len, config_.intermediate_size)); + mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); }, + T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size)); + mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); }, + T::BufferC::required_size(config_.max_len, config_.hidden_size)); + } + shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests); + } + + ~AMX_K2_MOE_TP() = default; + + void load_weights() { + auto& quant_config = config_.quant_config; + int& group_size = quant_config.group_size; + const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map; + auto pool = config_.pool->get_subpool(tp_part_idx); + + if (quant_config.group_size == 0 || quant_config.zero_point) { + throw std::runtime_error("Kimi AVX MOE only support KGroup Int4."); + } + if (config_.gate_scale == nullptr) { + throw std::runtime_error("Kimi AVX MOE only support load native weight."); + } + // load weight + int nth = T::recommended_nth(config_.intermediate_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % 
nth; + // gate part + gate_bb_[expert_idx]->from_raw_mat( + (uint8_t*)config_.gate_proj + + ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1), + ith, nth); + // up part + up_bb_[expert_idx]->from_raw_mat( + (uint8_t*)config_.up_proj + + ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1), + ith, nth); + }, + nullptr); + + nth = T::recommended_nth(config_.hidden_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % nth; + // down part + down_bb_[expert_idx]->from_raw_mat( + (uint8_t*)config_.down_proj + + ((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1), + ith, nth); + }, + nullptr); + + pool->do_work_stealing_job( + config_.expert_num, nullptr, + [this, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + size_t scale_elem_count = + (config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size; + + // convert scales from BF16 to FP32 + convert_or_copy(gate_bb_[expert_idx]->d, + (ggml_bf16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count), + scale_elem_count); + convert_or_copy(up_bb_[expert_idx]->d, + (ggml_bf16_t*)config_.up_scale + (logical_expert_id * scale_elem_count), + scale_elem_count); + convert_or_copy(down_bb_[expert_idx]->d, + (ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count), + scale_elem_count); + }, + nullptr); + // dump_buffer_b("native", 0, "down", down_bb_[0].get()); + } + + // Reconstruct weights for all experts to the output buffers + // This function handles the TP-specific portion of the reconstruction for all experts + void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int num_experts, const GeneralMOEConfig& full_config, + const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) const { + const int group_size = config_.quant_config.group_size; + auto pool = config_.pool->get_subpool(tp_part_idx); + + // Calculate sizes for CPU TP part (this instance) + size_t cpu_tp_weight_elem_count = (size_t)config_.intermediate_size * config_.hidden_size; + size_t cpu_tp_weight_bytes = cpu_tp_weight_elem_count / 2; // int4 packing + size_t cpu_tp_scale_elem_count = cpu_tp_weight_elem_count / group_size; + + // Calculate sizes for GPU TP part + size_t gpu_tp_weight_elem_count = (size_t)full_config.intermediate_size * full_config.hidden_size / gpu_tp_count; + size_t gpu_tp_weight_bytes = gpu_tp_weight_elem_count / 2; // int4 packing + size_t gpu_tp_scale_elem_count = gpu_tp_weight_elem_count / group_size; + + // Determine mapping: which GPU TP parts should this CPU TP part write to? 
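+ // e.g. cpu_tp_count=2, gpu_tp_count=4: each CPU TP fills two GPU TP slices; cpu_tp_count=4, gpu_tp_count=2: two CPU TPs fill halves of one GPU TP slice.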
+ // Since weights are col-major and we slice directly by memory order: + // - If cpu_tp_count >= gpu_tp_count: multiple(or one) CPU TPs write to one GPU TP + // - If cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs + if (cpu_tp_count >= gpu_tp_count) { + // Multiple CPU TPs map to one GPU TP + int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count); + int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count); + + // Get pointers for this GPU TP part + uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp]; + ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp]; + uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp]; + ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp]; + + // Calculate offset within the GPU TP buffer + size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes; + size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count; + + // Process only the first num_experts experts (GPU experts) + int nth = T::recommended_nth(config_.intermediate_size); + nth = 1; + pool->do_work_stealing_job( + nth * num_experts, nullptr, + [&, this](int task_id) { + int expert_id = task_id / nth; + // int ith = task_id % nth; + // auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); + + // Calculate base offsets for this expert in the GPU buffers + // For w13: each expert has gate+up, so the offset needs to account for 2x size + size_t w13_expert_base_weight = expert_id * 2 * gpu_tp_weight_bytes; + size_t w13_expert_base_scale = expert_id * 2 * gpu_tp_scale_elem_count; + size_t w2_expert_base_weight = expert_id * gpu_tp_weight_bytes; + size_t w2_expert_base_scale = expert_id * gpu_tp_scale_elem_count; + + // Gate (first part of w13 for this expert) + uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b; + float* gate_scale_src = gate_bb_[expert_id]->d; + std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight, + gate_weight_src, cpu_tp_weight_bytes); + convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale), + gate_scale_src, cpu_tp_scale_elem_count); + + // Up (second part of w13 for this expert, immediately after gate) + uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b; + float* up_scale_src = up_bb_[expert_id]->d; + std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight + gpu_tp_weight_bytes, + up_weight_src, cpu_tp_weight_bytes); + convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale + gpu_tp_scale_elem_count), + up_scale_src, cpu_tp_scale_elem_count); + + // Down (w2) - need to handle column-wise slicing + // The down matrix is transposed compared to gate/up, so we need to extract by columns + // When multiple CPU TPs map to one GPU TP, each CPU TP has a slice of intermediate dimension + // CPU TP internal layout: each column has config_.intermediate_size elements + // GPU expects: each column has full_config.intermediate_size elements + size_t cpu_tps_per_gpu = cpu_tp_count / gpu_tp_count; + + for (size_t col = 0; col < config_.hidden_size; col++) { + // GPU buffer column width is full_config.intermediate_size / gpu_tp_count + size_t gpu_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) >> 1); + size_t cpu_col_offset = col * (config_.intermediate_size >> 1); + size_t gpu_col_slice_offset = local_idx * (config_.intermediate_size >> 1); + + std::memcpy(w2_weight_dst + w2_expert_base_weight + gpu_col_offset + 
gpu_col_slice_offset, + (uint8_t*)down_bb_[expert_id]->b + cpu_col_offset, + config_.intermediate_size / 2); + + // Same for scales + size_t gpu_scale_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) / group_size); + size_t cpu_scale_col_offset = col * (config_.intermediate_size / group_size); + size_t gpu_scale_slice_offset = local_idx * (config_.intermediate_size / group_size); + + convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_expert_base_scale + gpu_scale_col_offset + gpu_scale_slice_offset), + down_bb_[expert_id]->d + cpu_scale_col_offset, + config_.intermediate_size / group_size); + } + }, + nullptr); + } else { + // cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs + // Each CPU TP part contains data for multiple GPU TP parts + int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count; + + // This CPU TP part writes to GPU TP indices: [start_gpu_tp, start_gpu_tp + gpu_tps_per_cpu_tp) + int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp; + + // Size of data per GPU TP within this CPU TP + size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp; + size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp; + + // Process all experts for this GPU TP + pool->do_work_stealing_job( + gpu_tps_per_cpu_tp * num_experts, nullptr, + [&, this](int task_id) { + int expert_id = task_id % num_experts; + int local_gpu_idx = task_id / num_experts; + int gpu_tp_idx = start_gpu_tp + local_gpu_idx; + + // Get pointers for this GPU TP part + uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[gpu_tp_idx]; + ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[gpu_tp_idx]; + uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[gpu_tp_idx]; + ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[gpu_tp_idx]; + + // Calculate offsets within CPU TP buffers + size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight; + size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale; + + // Calculate offsets for this expert in GPU buffers + // For w13: each expert has gate+up, so the offset needs to account for 2x size + size_t w13_gpu_expert_offset_weight = expert_id * 2 * gpu_tp_weight_bytes; + size_t w13_gpu_expert_offset_scale = expert_id * 2 * gpu_tp_scale_elem_count; + size_t w2_gpu_expert_offset_weight = expert_id * gpu_tp_weight_bytes; + size_t w2_gpu_expert_offset_scale = expert_id * gpu_tp_scale_elem_count; + + // Gate (first part of w13 for this expert) + uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight; + float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale; + std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight, + gate_weight_src, data_per_gpu_tp_weight); + convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale), + gate_scale_src, data_per_gpu_tp_scale); + + // Up (second part of w13 for this expert, immediately after gate) + uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight; + float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale; + std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight + gpu_tp_weight_bytes, + up_weight_src, data_per_gpu_tp_weight); + convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale + gpu_tp_scale_elem_count), + up_scale_src, data_per_gpu_tp_scale); + + // Down (w2) - need to handle column-wise slicing + // The down matrix is transposed compared to gate/up, so we need to extract by columns + for (size_t col = 0; col < config_.hidden_size; col++) { + // Calculate the 
offset within the column for this GPU TP part + size_t col_offset_weight = (col * config_.intermediate_size / 2) + (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size); + size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) + (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size); + + // Copy weights column by column + std::memcpy(w2_weight_dst + w2_gpu_expert_offset_weight + (col * (config_.intermediate_size / gpu_tps_per_cpu_tp) / 2), + (uint8_t*)down_bb_[expert_id]->b + col_offset_weight, + (config_.intermediate_size / gpu_tps_per_cpu_tp) / 2); + + // Copy scales column by column + convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_gpu_expert_offset_scale + col * ((config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size)), + down_bb_[expert_id]->d + col_offset_scale, + (config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size); + } + }, + nullptr); + } + } + + void warm_up() { + int qlen = config_.max_len; + std::vector input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); + std::vector output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); + std::vector expert_ids(qlen * config_.num_experts_per_tok); + std::vector weights(qlen * config_.num_experts_per_tok); + for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) { + expert_ids[i] = i % config_.expert_num; + weights[i] = 0.01; + } + forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data()); + } + + void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { + if (qlen > 1) { + forward_prefill(qlen, k, expert_ids, weights, input, output); + } else { + forward_decode(k, expert_ids, weights, input, output); + } + } + +#ifndef DIRECT_OR_POOL_BY_QLEN +#define DIRECT_OR_POOL_BY_QLEN(var, fn) \ + do { \ + if (qlen < 10) { \ + for (int i = 0; i < (var); i++) { \ + (fn)(i); \ + } \ + } else { \ + pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \ + } \ + } while (0) +#endif + + void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, + void* output) { + for (int i = 0; i < qlen; i++) + forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, (float*)output + i * config_.hidden_size); + } + + void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { + int qlen = 1; + auto pool = config_.pool->get_subpool(tp_part_idx); + auto& quant_config = config_.quant_config; + int& group_size = quant_config.group_size; +#ifdef FORWARD_TIME_PROFILE + auto start_time = std::chrono::high_resolution_clock::now(); + auto last = start_time; + // Per-stage elapsed times (in microseconds) + long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; + long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; + int max_local_num = 0; // records the maximum local num +#endif + + int activated_expert = 0; + for (int i = 0; i < k; i++) { + if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) { + continue; + } + m_expert_id_map_[activated_expert] = expert_ids[i]; + activated_expert++; + } + + size_t offset = 0; + for (int i = 0; i < activated_expert; i++) { + auto expert_idx = m_expert_id_map_[i]; + m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size; + m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size; + m_local_down_output_ptr_[expert_idx] =
m_local_down_output_ + offset * config_.hidden_size; + offset += qlen; + } + + gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + q_input_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + // calc gate & up + int nth = T::recommended_nth(config_.intermediate_size); + pool->do_work_stealing_job( + nth * activated_expert * 2, [](int _) { T::config(); }, + [this, nth, qlen](int task_id2) { + int& group_size = config_.quant_config.group_size; + int task_id = task_id2 / 2; + bool do_up = task_id2 % 2; + int expert_idx = m_expert_id_map_[task_id / nth]; + + int ith = task_id % nth; + if (do_up) { + amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0], + up_bb_[expert_idx], up_bc_[expert_idx], ith, nth); + up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth); + } else { + amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0], + gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth); + gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth); + } + }, + nullptr); + +#ifdef DEBUG_K2_MOE + if (activated_expert > 0) { + int print_elems = std::min(config_.intermediate_size, 16); + for (int dbg = 0; dbg < activated_expert; ++dbg) { + int sample_expert = m_expert_id_map_[dbg]; + ggml_bf16_t* gate_ptr = m_local_gate_output_ptr_[sample_expert]; + if (gate_ptr == nullptr) { + continue; + } + + printf("[K2][TP %d] gate_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems); + for (int idx = 0; idx < print_elems; idx++) { + float val = ggml_bf16_to_fp32(gate_ptr[idx]); + printf("%.6f ", val); + } + printf("\n"); + + int tail_start = config_.intermediate_size > print_elems ? 
config_.intermediate_size - print_elems : 0; + printf("[K2][TP %d] gate_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems); + for (int idx = 0; idx < print_elems; idx++) { + float val = ggml_bf16_to_fp32(gate_ptr[tail_start + idx]); + printf("%.6f ", val); + } + printf("\n"); + } + } +#endif + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + up_gate_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + // act + for (int task_id = 0; task_id < nth * activated_expert; task_id++) { + int expert_idx = m_expert_id_map_[task_id / nth]; + int ith = task_id % nth; + auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); + for (int i = 0; i < qlen; i++) { + ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; + ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; + for (int j = n_start; j < n_end; j += 32) { + __m512 gate_val0, gate_val1, up_val0, up_val1; + avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); + avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); + __m512 result0 = amx::act_fn(gate_val0, up_val0); + __m512 result1 = amx::act_fn(gate_val1, up_val1); + avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); + } + } + } + + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + act_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + // quant, get down a + pool->do_work_stealing_job( + activated_expert, nullptr, + [this, qlen](int task_id) { + int expert_idx = m_expert_id_map_[task_id]; + down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1); + }, + nullptr); +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + q_down_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + // * down + nth = T::recommended_nth(config_.hidden_size); + pool->do_work_stealing_job( + nth * activated_expert, [](int _) { T::config(); }, + [this, nth, qlen](int task_id) { + int& group_size = config_.quant_config.group_size; + int expert_idx = m_expert_id_map_[task_id / nth]; + int ith = task_id % nth; + amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth); + }, + nullptr); + +#ifdef DEBUG_K2_MOE + if (activated_expert > 0) { + int print_elems = std::min(config_.hidden_size, 16); + for (int dbg = 0; dbg < activated_expert; ++dbg) { + int sample_expert = m_expert_id_map_[dbg]; + ggml_bf16_t* down_ptr = m_local_down_output_ptr_[sample_expert]; + if (down_ptr == nullptr) { + continue; + } + + printf("[K2][TP %d] down_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems); + for (int idx = 0; idx < print_elems; idx++) { + float val = ggml_bf16_to_fp32(down_ptr[idx]); + printf("%.6f ", val); + } + printf("\n"); + + int tail_start = config_.hidden_size > print_elems ? 
config_.hidden_size - print_elems : 0; + printf("[K2][TP %d] down_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems); + for (int idx = 0; idx < print_elems; idx++) { + float val = ggml_bf16_to_fp32(down_ptr[tail_start + idx]); + printf("%.6f ", val); + } + printf("\n"); + } + } +#endif + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + down_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + // get output + for (int e = 0; e < config_.hidden_size; e += 32) { + __m512 x0 = _mm512_setzero_ps(); + __m512 x1 = _mm512_setzero_ps(); + for (int j = 0; j < k; j++) { + if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) { + continue; + } + __m512 weight = _mm512_set1_ps(weights[j]); + __m512 down_output0, down_output1; + avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[j]] + + m_local_pos_[0][j] * config_.hidden_size + e), + &down_output0, &down_output1); + x0 = _mm512_fmadd_ps(down_output0, weight, x0); + x1 = _mm512_fmadd_ps(down_output1, weight, x1); + } + auto f32out = (__m512*)((float*)output + e); + f32out[0] = x0; + f32out[1] = x1; + } + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + weight_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } + auto end_time = std::chrono::high_resolution_clock::now(); + auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); + // Print all per-stage timings once at the end of the function, along with max_local_num and qlen + printf( + "Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, " + "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n", + tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time, + forward_total_time); +#endif + } +}; + +template +class TP_MOE> : public TP_MOE_Common> { + public: + using TP_MOE_Common>::TP_MOE_Common; + + void load_weights() { + auto& config = this->config; + auto& tps = this->tps; + auto& tp_count = this->tp_count; + auto pool = config.pool; + const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map; + + if (config.gate_scale == nullptr) { + throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale"); + } + printf("From Packed Int4 with KGroup Scale\n"); + int& group_size = config.quant_config.group_size; + for (auto i = 0; i < tp_count; i++) { + auto& tpc = tps[i]->config_; + size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size; + tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; + tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; + tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; + + size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size; + + tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)]; + tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)]; + tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)]; + + if (tps[i]->config_.load == false) { + pool->get_subpool(i)->do_work_stealing_job( + tpc.expert_num, nullptr, + [&](int expert_id_) { // weights and scales are stored column-major.
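+ // Slice-size sketch (example sizes, not fixed by this file): with hidden_size = 7168, per-part intermediate_size = 1024 and group_size = 32,
+ // each expert's gate/up slice holds weight_elem_count = 1024 * 7168 int4 values (weight_elem_count / 2 bytes, copied as one contiguous block below)
+ // plus scales_elem_count = (7168 / 32) * 1024 bf16 scales; the down slice is instead gathered column by column further down.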
+ size_t expert_id = expert_map(physical_to_logical_map, expert_id_); + + // weight and scale TP-slicing for gate and up + memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1), + (uint8_t*)config.gate_proj + + ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1), + ((sizeof(uint8_t) * weight_elem_count) >> 1)); + + memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1), + (uint8_t*)config.up_proj + + ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1), + ((sizeof(uint8_t) * weight_elem_count) >> 1)); + + memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count), + (ggml_bf16_t*)config.gate_scale + + (expert_id * (config.hidden_size / group_size) * config.intermediate_size + + i * scales_elem_count), + sizeof(ggml_bf16_t) * scales_elem_count); + + memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count), + (ggml_bf16_t*)config.up_scale + + (expert_id * (config.hidden_size / group_size) * config.intermediate_size + + i * scales_elem_count), + sizeof(ggml_bf16_t) * scales_elem_count); + + // memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count) >> 1), + // (uint8_t*)config.down_proj + + // ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1), + // ((sizeof(uint8_t) * weight_elem_count) >> 1)); + + // memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count), + // (ggml_bf16_t*)config.down_scale + + // (expert_id * (config.intermediate_size / group_size) * config.hidden_size + + // i * scales_elem_count), + // sizeof(ggml_bf16_t) * scales_elem_count); + + // weight and scale TP-slicing for down (by column) + for (size_t col = 0; col < config.hidden_size; col++) { + memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1), + (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size + + col * config.intermediate_size + i * tpc.intermediate_size) >> + 1), + (sizeof(uint8_t) * tpc.intermediate_size) >> 1); + memcpy((ggml_bf16_t*)tpc.down_scale + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)), + (ggml_bf16_t*)config.down_scale + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) + + col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)), + sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size)); + } + }, + nullptr); + } + printf("TP %d load weight done.\n", i); + } + + DO_TPS_LOAD_WEIGHTS(pool); + + for (auto i = 0; i < tp_count; i++) { + auto& tpc = tps[i]->config_; + delete[] (uint8_t*)(tpc.gate_proj); + delete[] (uint8_t*)(tpc.up_proj); + delete[] (uint8_t*)(tpc.down_proj); + + delete[] (ggml_bf16_t*)(tpc.gate_scale); + delete[] (ggml_bf16_t*)(tpc.up_scale); + delete[] (ggml_bf16_t*)(tpc.down_scale); + } + + this->weights_loaded = true; + } + + void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num, + const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) { + if (this->weights_loaded == false) { + throw std::runtime_error("Not Loaded"); + } + if (this->tps.empty()) { + throw std::runtime_error("No TP parts initialized"); + } + + // Validate input vector sizes + if (w13_weight_ptrs.size() != gpu_tp_count || w13_scale_ptrs.size() != gpu_tp_count || + w2_weight_ptrs.size() != gpu_tp_count || w2_scale_ptrs.size() != gpu_tp_count) { + 
throw std::runtime_error("Pointer arrays size must match gpu_tp_count"); + } + + // Each TP part writes to its corresponding buffer + for (int tp_idx = 0; tp_idx < this->tp_count; tp_idx++) { + // Note: w13 combines gate and up projections + // Split w13 pointers for gate and up + this->tps[tp_idx]->write_weights_to_buffer( + gpu_tp_count, this->tp_count, + gpu_experts_num, this->config, + w13_weight_ptrs, w13_scale_ptrs, //gate + up use w13 + w2_weight_ptrs, w2_scale_ptrs); // down uses w2 + } + } + + void merge_results(int qlen, void* output, bool incremental) { + auto pool = this->config.pool; + auto merge_fn = [this, output, incremental](int token_nth) { + auto& local_output_numa = this->local_output_numa; + auto& tp_configs = this->tp_configs; + auto& tp_count = this->tp_count; + auto& config = this->config; + float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size; + if (incremental) { + for (int e = 0; e < config.hidden_size; e += 32) { + __m512 x0, x1; + avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1); + *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0); + *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1); + } + } + for (int i = 1; i < tp_count; i++) { + float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size; + for (int e = 0; e < tp_configs[i].hidden_size; e += 16) { + *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e))); + } + } + for (int e = 0; e < config.hidden_size; e += 32) { + __m512 x0 = *(__m512*)(merge_to + e); + __m512 x1 = *(__m512*)(merge_to + e + 16); + avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e)); + } + }; + for (int i = 0; i < qlen; i++) { + merge_fn(i); + } + } + + void merge_results(int qlen, void* output) { merge_results(qlen, output, false); } +}; + +#endif // CPUINFER_OPERATOR_AMX_K2_MOE_H diff --git a/kt-kernel/operators/amx/la/amx_buffers.hpp b/kt-kernel/operators/amx/la/amx_buffers.hpp index a2bc4883..819ae240 100644 --- a/kt-kernel/operators/amx/la/amx_buffers.hpp +++ b/kt-kernel/operators/amx/la/amx_buffers.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -344,9 +345,6 @@ struct BufferAKGroupImpl { static constexpr int K_STEP = K::K_STEP; static constexpr int K_BLOCK = K::K_BLOCK; - using index_t = Packed2DLayout::index_t; - Packed2DLayout pack; - static size_t required_size(int max_m, int k, int k_group_size) { ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size"); return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * (k / k_group_size); @@ -355,18 +353,12 @@ struct BufferAKGroupImpl { BufferAKGroupImpl(int max_m, int k, int k_group_size, void* ptr) : max_m(max_m), k(k), - k_group_size(k_group_size), - pack({{static_cast(K_STEP), 'c'}, - {static_cast(M_STEP), 'r'}, - {static_cast(k_group_size / K_STEP), 'c'}, - {static_cast(K_BLOCK / k_group_size), 'c'}, - {static_cast(max_m / M_STEP), 'r'}, - {static_cast(k / K_BLOCK), 'c'}}) { + k_group_size(k_group_size) { ASSERT_RELEASE(k % k_group_size == 0, "k must be multiple of k_group_size"); ASSERT_RELEASE(max_m % M_STEP == 0, "max_m must be multiple of M_STEP"); ASSERT_RELEASE(k % K_STEP == 0, "k must be multiple of K_STEP"); ASSERT_RELEASE(K_BLOCK % k_group_size == 0, "K_BLOCK must be multiple of k_group_size"); - ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK"); 
+ // ASSERT_RELEASE(k % K_BLOCK == 0, "k must be multiple of K_BLOCK"); k_group_count = k / k_group_size; set_data(ptr); @@ -922,6 +914,77 @@ struct BufferBInt4WithZeroImpl { float* get_min(int n, int n_begin) { return mins + n_begin; } }; +// BufferB for Signed Int4 with KGroup Scale (no zero point) +// Used for K2 MoE - signed int4 range: [-8, 7] +template +struct BufferBInt4KGroupImpl { + using dt = typename K::dt; + dt* b; // packed signed int4 weights, col majored + float* d; // scales only (no mins/zero-points), row majored + int n, k, k_group_size, k_group_count; + + static constexpr int N_STEP = K::N_STEP; + static constexpr int K_STEP = K::K_STEP; + static constexpr bool SCALE = true; + + // Size calculation: packed int4 weights + scales (NO mins) + static size_t required_size(int n, int k, int k_group_size) { + return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size); + } + + BufferBInt4KGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(n % N_STEP == 0); + assert(k % K_STEP == 0); + if (n % N_STEP || k % K_STEP || k % k_group_size) { + printf("BufferBInt4KGroupImpl: n: %d, k: %d, N_STEP: %d, K_STEP: %d, k_group_size: %d\n", n, k, N_STEP, + K_STEP, k_group_size); + throw std::runtime_error("n or k is not aligned to N_STEP or K_STEP"); + } + k_group_count = k / k_group_size; + b = reinterpret_cast(ptr); + d = reinterpret_cast(offset_pointer(b, n * k / 2)); + } + + // Load from packed signed int4 format + // Input: proj is packed int4 weights (2 int4 values per byte) + // Each int4 value is in range [-8, 7] (signed) + void from_raw_mat(uint8_t* proj, int ith, int nth) { + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + if (n_start >= n_end) { + return; + } + const size_t row_bytes = static_cast(k) / 2; + const size_t rows = static_cast(n_end - n_start); + uint8_t* dst_weights = reinterpret_cast(b) + n_start * row_bytes; + const uint8_t* src_weights = proj + n_start * row_bytes; + std::memcpy(dst_weights, src_weights, rows * row_bytes); + } + + // Get pointer to submatrix for computation + dt* get_submat(int n, int k, int n_begin, int k_begin) { + const size_t row_bytes = static_cast(k) / 2; + const size_t row_offset = static_cast(n_begin) * row_bytes; + const size_t col_offset = static_cast(k_begin) / 2; + return reinterpret_cast(reinterpret_cast(b) + row_offset + col_offset); + } + + // Get scale pointer for a specific row and k_group + float* get_scale(int n, int n_begin, int k, int k_begin) { + int k_group_idx = k_begin / k_group_size; + return d + n_begin * (k / k_group_size) + k_group_idx; + } + + // Split range for parallel processing + static std::pair split_range_n(int n, int ith, int nth) { + int n_per_thread = (n + nth - 1) / nth; + n_per_thread = (n_per_thread + N_STEP - 1) / N_STEP * N_STEP; + int n_start = std::min(ith * n_per_thread, n); + int n_end = std::min(n_start + n_per_thread, n); + return {n_start, n_end}; + } +}; + template struct BufferBInt4WithZeroKGroupImpl { using dt = typename K::dt; diff --git a/kt-kernel/operators/amx/la/amx_kernels.hpp b/kt-kernel/operators/amx/la/amx_kernels.hpp index 3e331834..a89d3ddf 100644 --- a/kt-kernel/operators/amx/la/amx_kernels.hpp +++ b/kt-kernel/operators/amx/la/amx_kernels.hpp @@ -1015,8 +1015,9 @@ struct GemmKernel224Int8 { static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba, BufferB* bb) { __m512i* c512 = (__m512i*)c; + int m_block_end = 
std::min(m - m_begin, M_STEP); if (k_block_begin == 0) { - for (int m_i = 0; m_i < m; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { c512[m_i * 2] = _mm512_setzero_si512(); c512[m_i * 2 + 1] = _mm512_setzero_si512(); } @@ -1028,7 +1029,7 @@ struct GemmKernel224Int8 { int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -1239,8 +1240,9 @@ struct GemmKernel224Int4 { BufferB* bb) { using K = GemmKernel224Int4; __m512i* c512 = (__m512i*)c; + int m_block_end = std::min(m - m_begin, M_STEP); if (k_block_begin == 0) { - for (int m_i = 0; m_i < m; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { c512[m_i * 2] = _mm512_setzero_si512(); c512[m_i * 2 + 1] = _mm512_setzero_si512(); } @@ -1250,7 +1252,7 @@ struct GemmKernel224Int4 { int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin); int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]); __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]); @@ -1533,8 +1535,9 @@ struct GemmKernel224Int4_1 { BufferB* bb) { using K = GemmKernel224Int4_1; __m512i* c512 = (__m512i*)c; + int m_block_end = std::min(m - m_begin, M_STEP); if (k_block_begin == 0) { - for (int m_i = 0; m_i < m; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { c512[m_i * 2] = _mm512_setzero_si512(); c512[m_i * 2 + 1] = _mm512_setzero_si512(); } @@ -1543,7 +1546,7 @@ struct GemmKernel224Int4_1 { int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin); int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]); __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]); @@ -2193,10 +2196,11 @@ struct GemmKernel224Int4KGroup { BufferB* bb, int k_group_size) { using K = GemmKernel224Int4KGroup; __m512i* c512 = (__m512i*)int_c; + int m_block_end = std::min(m - m_begin, M_STEP); // Initialize int_c to zero at the start of k_group if (k_block_begin % k_group_size == 0) { - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { c512[m_i * 2] = _mm512_setzero_si512(); c512[m_i * 2 + 1] = _mm512_setzero_si512(); } @@ -2205,7 +2209,7 @@ struct GemmKernel224Int4KGroup { if (k_offset == 0) { int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -2217,7 +2221,7 @@ struct GemmKernel224Int4KGroup { 
} else { int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -2471,8 +2475,9 @@ struct GemmKernel224Int4_1KGroup { BufferB* bb, int k_group_size) { using K = GemmKernel224Int4_1KGroup; __m512i* c512 = (__m512i*)int_c; + int m_block_end = std::min(m - m_begin, M_STEP); if (k_block_begin % k_group_size == 0) { - for (int m_i = 0; m_i < m; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { c512[m_i * 2] = _mm512_setzero_si512(); c512[m_i * 2 + 1] = _mm512_setzero_si512(); } @@ -2481,7 +2486,7 @@ struct GemmKernel224Int4_1KGroup { if (k_offset == 0) { int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -2493,7 +2498,7 @@ struct GemmKernel224Int4_1KGroup { } else { int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -2746,8 +2751,9 @@ struct GemmKernel224Int4_1_LowKGroup { BufferB* bb, int k_group_size) { using K = GemmKernel224Int4_1_LowKGroup; __m512i* c512 = (__m512i*)int_c; + int m_block_end = std::min(m - m_begin, M_STEP); if (k_block_begin % k_group_size == 0) { - for (int m_i = 0; m_i < m; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { c512[m_i * 2] = _mm512_setzero_si512(); c512[m_i * 2 + 1] = _mm512_setzero_si512(); } @@ -2756,7 +2762,7 @@ struct GemmKernel224Int4_1_LowKGroup { if (k_offset == 0) { int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -2768,7 +2774,7 @@ struct GemmKernel224Int4_1_LowKGroup { } else { int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin); __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP); - for (int m_i = 0; m_i < m && m_i < M_STEP; m_i++) { + for (int m_i = 0; m_i < m_block_end; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { @@ -2837,6 +2843,110 @@ struct GemmKernel224Int4_1_LowKGroup { } }; +// K2 Signed Int4 K-group quantization kernel (AVX only, no AMX) +// For K2 MoE - signed int4 range: [-8, 7] +struct GemmKernel224Int4SmallKGroup { + using dt = uint8_t; // packed int4 type + using output_t = int32_t; + static constexpr double ELEMENT_SIZE = 0.5; + static const int VNNI_BLK = 4; + + static const int M_STEP = 1; + static const int N_STEP = 32; + static const int K_STEP 
= 32; + + static inline const int N_BLOCK = 256; + // K_BLOCK should match k_group_size for proper scaling + static inline const int K_BLOCK = 7168; // Will be overridden by k_group_size + + static std::string name() { return "K2_INT4_KGROUP"; } + static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } + static std::pair split_range_n(int n, int ith, int nth) { + int n_start = N_BLOCK * ith; + int n_end = std::min(n, N_BLOCK * (ith + 1)); + return {n_start, n_end}; + } + static void config() {} + + alignas(64) static constexpr uint8_t hi_mask_arr[32] = { + 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, + 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0 + }; + + alignas(64) static constexpr uint8_t lo_mask_arr[32] = { + 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, + 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F + }; + + alignas(64) static constexpr uint8_t sign_xor_arr[32] = { + 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, + 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88 + }; + static __m256i hi_mask() { return *((__m256i*)(&hi_mask_arr[0])); } + static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); } + static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); } + + using BufferA = BufferAKGroupImpl; + using BufferB = BufferBInt4KGroupImpl; // Use new signed int4 buffer + using BufferC = BufferCReduceImpl; + + // K-group aware AVX kernel for signed int4 + static inline __m512i compressed_int4_to_int8_avx512(__m256i b256) { + b256 = _mm256_xor_si256(b256, sign_xor_mask()); + __m256i b_hi = _mm256_and_si256(b256, hi_mask()); + __m256i b_lo = _mm256_slli_epi16(_mm256_andnot_si256(hi_mask(), b256), 4); + + __m256i unpack_lo = _mm256_unpacklo_epi8(b_lo, b_hi); + __m256i unpack_hi = _mm256_unpackhi_epi8(b_lo, b_hi); + __m512i result = _mm512_inserti64x4(_mm512_castsi256_si512(unpack_lo), unpack_hi, 1); + const __m512i lane_shuffle = _mm512_set_epi64(7, 6, 3, 2, 5, 4, 1, 0); + return _mm512_permutexvar_epi64(lane_shuffle, result); + } + static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB *bb, BufferC* bc, int ith, int nth) { + auto [n_start, n_end] = split_range_n(n, ith, nth); + for (int m_begin = 0; m_begin < m; m_begin ++) { + float* c = bc->get_submat(m, n, m_begin, 0); + __m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0); + + for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin ++) { + __m256i* b256 = (__m256i*)bb->get_submat(n, k, n_block_begin, 0); + float* as = (float*)ba->get_scale(m, m_begin, k, 0); + float* bs = (float*)bb->get_scale(n, n_block_begin, k, 0); + + __m512 sum = _mm512_setzero_ps(); + #define WORK_K_BLOCK(k_block) \ + { \ + __m256 abscale0 = _mm256_set1_ps(as[(k_block)*2] * bs[(k_block)*2]); \ + __m256 abscale1 = _mm256_set1_ps(as[(k_block)*2+1] * bs[(k_block)*2+1]); \ + __m512 abscale = _mm512_insertf32x8(_mm512_castps256_ps512(abscale0), abscale1, 1); \ + __m512i mul = _mm512_setzero_si512(); \ + mul = _mm512_dpbssd_epi32(mul, a512[k_block], compressed_int4_to_int8_avx512(b256[k_block])); \ + sum = _mm512_add_ps(sum, _mm512_mul_ps(abscale, _mm512_cvtepi32_ps(mul))); \ + } + + for (int k_block = 0; k_block < k / 64; k_block += 2) { + WORK_K_BLOCK(k_block); + 
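+ // Second half of the 2x unroll: each WORK_K_BLOCK consumes 64 packed int8 activations and 64 packed int4 weights,
+ // i.e. two 32-element k-groups (assuming k_group_size == 32), applying the matching pair of per-group scales.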
WORK_K_BLOCK(k_block + 1); + } + + c[n_block_begin] = _mm512_reduce_add_ps(sum) / 16; + } + } + } +}; + +inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr ba, + std::shared_ptr bb, + std::shared_ptr bc, int ith, int nth) { + GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +} + +inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr ba, + std::shared_ptr bb, + std::shared_ptr bc, int ith, int nth) { + GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +} + // New k-group aware matrix multiplication function template void integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb, diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py index 78807eeb..0f89a75a 100644 --- a/kt-kernel/python/experts.py +++ b/kt-kernel/python/experts.py @@ -17,7 +17,7 @@ from typing import List, Optional from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer # Import backend implementations -from .utils.amx import AMXMoEWrapper +from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper from .utils.llamafile import LlamafileMoEWrapper from .utils.moe_kernel import GeneralMoEWrapper @@ -77,7 +77,7 @@ class KTMoEWrapper: chunked_prefill_size: Maximum prefill chunk size cpu_save: Whether to save weights to CPU memory max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0. - method: Backend method ("AMXINT4", "AMXINT8", "LLAMAFILE", "MOE_INT4", "MOE_INT8") + method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8") Returns: An instance of the appropriate backend implementation (e.g., AMXMoEWrapper) @@ -85,6 +85,8 @@ class KTMoEWrapper: # Select backend based on method if method in ["AMXINT4", "AMXINT8"]: backend_cls = AMXMoEWrapper + elif method == "RAWINT4": + backend_cls = RAWAMXMoEWrapper elif method == "LLAMAFILE": backend_cls = LlamafileMoEWrapper elif method in ["MOE_INT4", "MOE_INT8"]: diff --git a/kt-kernel/python/utils/__init__.py b/kt-kernel/python/utils/__init__.py index f71809b0..729699f2 100644 --- a/kt-kernel/python/utils/__init__.py +++ b/kt-kernel/python/utils/__init__.py @@ -4,13 +4,15 @@ Utilities for kt_kernel package. 
""" -from .amx import AMXMoEWrapper +from .amx import AMXMoEWrapper, RAWAMXMoEWrapper from .llamafile import LlamafileMoEWrapper -from .loader import SafeTensorLoader, GGUFLoader +from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader __all__ = [ "AMXMoEWrapper", + "RAWAMXMoEWrapper", "LlamafileMoEWrapper", "SafeTensorLoader", + "CompressedSafeTensorLoader", "GGUFLoader", ] diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py index 751da47c..b36ba3dc 100644 --- a/kt-kernel/python/utils/amx.py +++ b/kt-kernel/python/utils/amx.py @@ -4,16 +4,16 @@ import ctypes # Use relative imports for package structure from ..experts_base import BaseMoEWrapper -from .loader import SafeTensorLoader +from .loader import SafeTensorLoader, CompressedSafeTensorLoader from kt_kernel_ext.moe import MOEConfig try: - from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE + from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE _HAS_AMX_SUPPORT = True except (ImportError, AttributeError): _HAS_AMX_SUPPORT = False - AMXInt4_MOE, AMXInt8_MOE = None, None + AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None from typing import Optional @@ -301,3 +301,152 @@ class AMXMoEWrapper(BaseMoEWrapper): del self.gate_scales del self.up_scales del self.down_scales + + +class RAWAMXMoEWrapper(BaseMoEWrapper): + """Wrapper for RAWINT4 experts stored in compressed SafeTensor format.""" + + _compressed_loader_instance = None + + def __init__( + self, + layer_idx: int, + num_experts: int, + num_experts_per_tok: int, + hidden_size: int, + moe_intermediate_size: int, + num_gpu_experts: int, + cpuinfer_threads: int, + threadpool_count: int, + weight_path: str, + chunked_prefill_size: int, + cpu_save: bool = False, + max_deferred_experts_per_token: Optional[int] = None, + method: str = "RAWINT4", + ): + if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None: + raise RuntimeError("AMX backend with RAWINT4 support is not available.") + + super().__init__( + layer_idx=layer_idx, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + num_gpu_experts=num_gpu_experts, + cpuinfer_threads=cpuinfer_threads, + threadpool_count=threadpool_count, + weight_path=weight_path, + chunked_prefill_size=chunked_prefill_size, + cpu_save=cpu_save, + max_deferred_experts_per_token=max_deferred_experts_per_token, + method=method, + ) + + if RAWAMXMoEWrapper._compressed_loader_instance is None: + RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path) + self.loader = RAWAMXMoEWrapper._compressed_loader_instance + + self.gate_weights = None + self.up_weights = None + self.down_weights = None + self.gate_scales = None + self.up_scales = None + self.down_scales = None + + def load_weights_from_tensors( + self, + gate_proj: torch.Tensor, + up_proj: torch.Tensor, + down_proj: torch.Tensor, + physical_to_logical_map_cpu: torch.Tensor, + ): + raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.") + + def load_weights(self, physical_to_logical_map_cpu: torch.Tensor): + base_key = f"model.layers.{self.layer_idx}" + weights = self.loader.load_experts(base_key) + + self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous() + self.up_weights = torch.stack(weights["up"], dim=0).contiguous() + self.down_weights = torch.stack(weights["down"], dim=0).contiguous() + + self.gate_scales = torch.stack(weights["gate_scale"], 
dim=0).to(torch.bfloat16).contiguous() + self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous() + self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous() + + moe_config = MOEConfig( + self.num_experts, + self.num_experts_per_tok, + self.hidden_size, + self.moe_intermediate_size, + self.num_gpu_experts, + ) + moe_config.layer_idx = self.layer_idx + moe_config.pool = self.cpu_infer.backend_ + moe_config.max_len = self.chunked_prefill_size + + moe_config.quant_config.bits = 4 + moe_config.quant_config.group_size = 32 + moe_config.quant_config.zero_point = False + + moe_config.gate_proj = self.gate_weights.data_ptr() + moe_config.up_proj = self.up_weights.data_ptr() + moe_config.down_proj = self.down_weights.data_ptr() + moe_config.gate_scale = self.gate_scales.data_ptr() + moe_config.up_scale = self.up_scales.data_ptr() + moe_config.down_scale = self.down_scales.data_ptr() + + self.moe = AMXInt4_KGroup_MOE(moe_config) + + self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr())) + self.cpu_infer.sync() + + del self.gate_weights + del self.up_weights + del self.down_weights + del self.gate_scales + del self.up_scales + del self.down_scales + + def submit_write_weight_scale_to_buffer( + self, + gpu_tp_count: int, + gpu_experts_num: int, + w13_weight_ptrs, + w13_scale_ptrs, + w2_weight_ptrs, + w2_scale_ptrs, + ): + """ + Submit the write_weight_scale_to_buffer task for RAWINT4 KGroup AMX implementation. + + This method submits the C++-exposed task `write_weight_scale_to_buffer_task` to the + shared CPUInfer queue. The pointer lists should be plain integer lists (e.g. from + tensor.data_ptr()). + """ + if self.moe is None: + raise RuntimeError("MoE instance not initialized; cannot submit write_weight_scale_to_buffer task.") + + if not hasattr(self.moe, "write_weight_scale_to_buffer_task"): + raise NotImplementedError( + "write_weight_scale_to_buffer_task is not available for this backend implementation." + ) + + self.cpu_infer.submit( + self.moe.write_weight_scale_to_buffer_task( + gpu_tp_count, + gpu_experts_num, + w13_weight_ptrs, + w13_scale_ptrs, + w2_weight_ptrs, + w2_scale_ptrs, + ) + ) + + def sync_write_weight_scale_to_buffer(self): + """ + Block until previously submitted write_weight_scale_to_buffer tasks finish. + """ + # The CPUInfer.sync() call blocks until pending tasks complete. 
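+ # Hypothetical host-side call sequence (tensor names are illustrative, not part of this API):
+ #   wrapper.submit_write_weight_scale_to_buffer(gpu_tp_count, n_gpu_experts,
+ #       [t.data_ptr() for t in w13_w], [t.data_ptr() for t in w13_s],
+ #       [t.data_ptr() for t in w2_w], [t.data_ptr() for t in w2_s])
+ #   wrapper.sync_write_weight_scale_to_buffer()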
+ self.cpu_infer.sync() diff --git a/kt-kernel/python/utils/loader.py b/kt-kernel/python/utils/loader.py index b3b563c8..db689f40 100644 --- a/kt-kernel/python/utils/loader.py +++ b/kt-kernel/python/utils/loader.py @@ -237,6 +237,56 @@ class SafeTensorLoader: return name in self.tensor_file_map +class CompressedSafeTensorLoader(SafeTensorLoader): + """Loader for compressed SafeTensor layouts (RAWINT4 weights).""" + + def load_experts(self, base_key: str, device: str = "cpu"): + """Load raw expert weights stored in compressed safetensor format.""" + + experts_prefix = f"{base_key}.mlp.experts" + + expert_idx = 0 + while self.has_tensor(f"{experts_prefix}.{expert_idx}.up_proj.weight_packed"): + expert_idx += 1 + + if expert_idx == 0: + raise ValueError(f"No experts found for key {experts_prefix}") + + def load_projection(proj_name: str): + weight_entries = [] + scale_entries = [] + + for exp_id in range(expert_idx): + weight_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_packed" + scale_key = f"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_scale" + + if not self.has_tensor(weight_key): + raise KeyError(f"Missing tensor: {weight_key}") + if not self.has_tensor(scale_key): + raise KeyError(f"Missing tensor: {scale_key}") + + weight_tensor = self.load_tensor(weight_key, device).contiguous() + scale_tensor = self.load_tensor(scale_key, device).contiguous() + + weight_entries.append(weight_tensor) + scale_entries.append(scale_tensor) + + return weight_entries, scale_entries + + gate_weights, gate_scales = load_projection("gate") + up_weights, up_scales = load_projection("up") + down_weights, down_scales = load_projection("down") + + return { + "gate": gate_weights, + "up": up_weights, + "down": down_weights, + "gate_scale": gate_scales, + "up_scale": up_scales, + "down_scale": down_scales, + } + + class GGUFLoader: """ GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)