support GLM 4.7 (#1791)
Some checks failed
Book-CI / test-2 (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled

support GLM 4.7
This commit is contained in:
Oql 2026-01-13 17:36:25 +08:00 committed by GitHub
parent 667030d6e6
commit 6277da4c2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2336 additions and 144 deletions

View file

@ -2,12 +2,14 @@
Test write_weight_scale_to_buffer for AMX MOE operators.
Supports:
- FP8: FP8 weights (1 byte) + float32 scales
- FP8: FP8 weights (1 byte) + float32 scales (block-wise)
- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales
- BF16: Native BF16 weights (2 bytes), no scales
Usage:
python test_write_buffer.py # Run all modes
python test_write_buffer.py fp8 # Run FP8 only
python test_write_buffer.py fp8_perchannel # Run FP8 per-channel only
python test_write_buffer.py bf16 # Run BF16 only
"""
@ -41,6 +43,17 @@ def build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, int
return cfg
def build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
cfg.max_len = 1
cfg.quant_config.bits = 8 # FP8
cfg.quant_config.group_size = 0 # Not used for per-channel
cfg.quant_config.zero_point = False
cfg.quant_config.per_channel = True
cfg.pool = cpuinfer.backend_
return cfg
def build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
cfg.max_len = 1
@ -83,6 +96,33 @@ def allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size)
}
def allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size):
"""Allocate FP8 per-channel weights and scales for testing"""
per_mat_weight_bytes = hidden_size * intermediate_size
per_mat_scale_elems_gate_up = intermediate_size # one scale per output channel
per_mat_scale_elems_down = hidden_size
gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)
return {
"gate_q": gate_q,
"up_q": up_q,
"down_q": down_q,
"gate_scale": gate_scale,
"up_scale": up_scale,
"down_scale": down_scale,
"per_mat_weight_bytes": per_mat_weight_bytes,
"per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
"per_mat_scale_elems_down": per_mat_scale_elems_down,
}
def allocate_weights_bf16(expert_num, hidden_size, intermediate_size):
"""Allocate BF16 weights for testing (no scales)"""
# BF16 weights: 2 bytes per element
@ -312,6 +352,195 @@ def test_fp8_write_buffer(gpu_tp_count):
return True
def test_fp8_perchannel_write_buffer(gpu_tp_count):
"""Test write_weight_scale_to_buffer with FP8 per-channel weights"""
torch.manual_seed(123)
expert_num = 256
gpu_experts = expert_num
num_experts_per_tok = 8
hidden_size = 3072
intermediate_size = 1536
cpuinfer = make_cpu_infer()
cfg = build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
weights = allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size)
cfg.gate_proj = weights["gate_q"].data_ptr()
cfg.up_proj = weights["up_q"].data_ptr()
cfg.down_proj = weights["down_q"].data_ptr()
cfg.gate_scale = weights["gate_scale"].data_ptr()
cfg.up_scale = weights["up_scale"].data_ptr()
cfg.down_scale = weights["down_scale"].data_ptr()
moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg)
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
cpuinfer.sync()
per_mat_weight_bytes = weights["per_mat_weight_bytes"]
per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
gpu_n_w13 = intermediate_size // gpu_tp_count
scale_elems_per_expert_per_tp_gate_up = gpu_n_w13
scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down
total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down
w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w13_scale_bufs = [
torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
]
w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]
print(f"[FP8_PERCHANNEL] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
print(f"[FP8_PERCHANNEL] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
print(f"[FP8_PERCHANNEL] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")
def get_expert_ptrs(expert_id):
w13_weight_ptrs = []
w13_scale_ptrs = []
w2_weight_ptrs = []
w2_scale_ptrs = []
for tp_idx in range(gpu_tp_count):
w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down
w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)
w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)
return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs
for _ in range(2):
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
begin_time = time.perf_counter_ns()
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
end_time = time.perf_counter_ns()
elapsed_ms = (end_time - begin_time) / 1e6
total_bytes = (
hidden_size * intermediate_size * gpu_experts * 3
+ (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
)
print(f"[FP8_PERCHANNEL] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
print(f"[FP8_PERCHANNEL] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
def split_expert_tensor(tensor, chunk):
return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
gate_q = weights["gate_q"]
up_q = weights["up_q"]
down_q = weights["down_q"]
gate_scale = weights["gate_scale"]
up_scale = weights["up_scale"]
down_scale = weights["down_scale"]
gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)
gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)
for tp_idx in range(gpu_tp_count):
expected_w13_weights = []
expected_w13_scales = []
expected_w2_weights = []
expected_w2_scales = []
weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count
for expert_id in range(gpu_experts):
start_weight = tp_idx * weight13_per_tp
end_weight = (tp_idx + 1) * weight13_per_tp
start_scale = tp_idx * scale13_per_tp
end_scale = (tp_idx + 1) * scale13_per_tp
gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]
down_weight_tp_parts = []
tp_slice_weight_size = intermediate_size // gpu_tp_count
for row_idx in range(hidden_size):
row_weight_start = row_idx * intermediate_size
tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size
down_weight_tp_parts.append(
down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
)
down_weight_tp = torch.cat(down_weight_tp_parts)
down_scale_tp = down_scale_experts[expert_id]
expected_w13_weights.append(gate_weight_tp)
expected_w13_weights.append(up_weight_tp)
expected_w13_scales.append(gate_scale_tp)
expected_w13_scales.append(up_scale_tp)
expected_w2_weights.append(down_weight_tp)
expected_w2_scales.append(down_scale_tp)
expected_w13_weight = torch.cat(expected_w13_weights)
expected_w13_scale = torch.cat(expected_w13_scales)
expected_w2_weight = torch.cat(expected_w2_weights)
expected_w2_scale = torch.cat(expected_w2_scales)
if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
raise AssertionError(f"[FP8_PERCHANNEL] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}")
if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
raise AssertionError(f"[FP8_PERCHANNEL] w13 scale mismatch for TP {tp_idx}")
if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
raise AssertionError(f"[FP8_PERCHANNEL] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}")
if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
raise AssertionError(f"[FP8_PERCHANNEL] w2 scale mismatch for TP {tp_idx}")
print(f"[FP8_PERCHANNEL] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
return True
def test_bf16_write_buffer(gpu_tp_count):
"""Test write_weight_scale_to_buffer with BF16 weights (no scales)"""
torch.manual_seed(123)
@ -478,6 +707,8 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int):
"""Test write_weight_scale_to_buffer with specified mode and TP count"""
if quant_mode == "fp8":
return test_fp8_write_buffer(gpu_tp_count)
elif quant_mode == "fp8_perchannel":
return test_fp8_perchannel_write_buffer(gpu_tp_count)
elif quant_mode == "bf16":
return test_bf16_write_buffer(gpu_tp_count)
else:
@ -487,7 +718,7 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int):
def main(quant_modes=None):
"""Run tests for specified quant modes"""
if quant_modes is None:
quant_modes = ["fp8", "bf16"]
quant_modes = ["fp8", "fp8_perchannel", "bf16"]
tp_values = [1, 2, 4]
all_passed = True
@ -525,10 +756,10 @@ def main(quant_modes=None):
if __name__ == "__main__":
if len(sys.argv) > 1:
mode = sys.argv[1].lower()
if mode in ["fp8", "bf16"]:
if mode in ["fp8", "fp8_perchannel", "bf16"]:
main([mode])
else:
print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'")
print(f"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'")
sys.exit(1)
else:
main()