mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-05 15:40:13 +00:00
support GLM 4.7 (#1791)
Some checks failed
Book-CI / test-2 (push) Has been cancelled
Book-CI / test (push) Has been cancelled
Book-CI / test-1 (push) Has been cancelled
Deploy / deploy (macos-latest) (push) Has been cancelled
Deploy / deploy (ubuntu-latest) (push) Has been cancelled
Deploy / deploy (windows-latest) (push) Has been cancelled
support GLM 4.7
parent 667030d6e6
commit 6277da4c2b
14 changed files with 2336 additions and 144 deletions
@@ -2,12 +2,14 @@
 Test write_weight_scale_to_buffer for AMX MOE operators.

 Supports:
-- FP8: FP8 weights (1 byte) + float32 scales
+- FP8: FP8 weights (1 byte) + float32 scales (block-wise)
+- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales
 - BF16: Native BF16 weights (2 bytes), no scales

 Usage:
     python test_write_buffer.py                 # Run all modes
     python test_write_buffer.py fp8             # Run FP8 only
+    python test_write_buffer.py fp8_perchannel  # Run FP8 per-channel only
     python test_write_buffer.py bf16            # Run BF16 only
 """
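The docstring now distinguishes two FP8 flavors: block-wise scales (one float32 per block of weights) versus per-channel scales (one float32 per output channel). A rough sketch of what that means for the scale bookkeeping, using this test's matrix sizes; the group size of 128 is an assumption for illustration, not something this diff specifies:

    # Hypothetical comparison of scale counts for one gate/up matrix (illustration only).
    hidden_size, intermediate_size = 3072, 1536   # sizes used later in this test
    group_size = 128                              # assumed block size for the block-wise mode

    fp8_weight_bytes = intermediate_size * hidden_size        # 1 byte per FP8 weight
    blockwise_scale_count = fp8_weight_bytes // group_size    # one float32 scale per block
    perchannel_scale_count = intermediate_size                # one float32 scale per output channel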
@@ -41,6 +43,17 @@ def build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, int
     return cfg


+def build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
+    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
+    cfg.max_len = 1
+    cfg.quant_config.bits = 8  # FP8
+    cfg.quant_config.group_size = 0  # Not used for per-channel
+    cfg.quant_config.zero_point = False
+    cfg.quant_config.per_channel = True
+    cfg.pool = cpuinfer.backend_
+    return cfg
+
+
 def build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
     cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
     cfg.max_len = 1
@@ -83,6 +96,33 @@ def allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size)
     }


+def allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size):
+    """Allocate FP8 per-channel weights and scales for testing"""
+    per_mat_weight_bytes = hidden_size * intermediate_size
+    per_mat_scale_elems_gate_up = intermediate_size  # one scale per output channel
+    per_mat_scale_elems_down = hidden_size
+
+    gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
+    up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
+    down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
+
+    gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
+    up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
+    down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)
+
+    return {
+        "gate_q": gate_q,
+        "up_q": up_q,
+        "down_q": down_q,
+        "gate_scale": gate_scale,
+        "up_scale": up_scale,
+        "down_scale": down_scale,
+        "per_mat_weight_bytes": per_mat_weight_bytes,
+        "per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
+        "per_mat_scale_elems_down": per_mat_scale_elems_down,
+    }
+
+
 def allocate_weights_bf16(expert_num, hidden_size, intermediate_size):
     """Allocate BF16 weights for testing (no scales)"""
     # BF16 weights: 2 bytes per element
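The allocation above encodes the per-channel convention this test relies on: each expert matrix stores raw FP8 bytes row-major as [out, in], with one float32 scale per output row (intermediate_size scales for gate/up, hidden_size for down). A minimal dequantization sketch under that reading; the e4m3 interpretation is an assumption for illustration, and the AMX kernels' actual decode path is not shown here:

    import torch

    def dequant_perchannel(q_bytes: torch.Tensor, scale: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
        # q_bytes: uint8 buffer of FP8 weights for one matrix, row-major (rows = output channels)
        # scale:   float32 tensor with one entry per output row
        fp8 = q_bytes.view(rows, cols).view(torch.float8_e4m3fn)  # reinterpret raw bytes as FP8
        return fp8.to(torch.float32) * scale.view(rows, 1)        # broadcast one scale per row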
@@ -312,6 +352,195 @@ def test_fp8_write_buffer(gpu_tp_count):
     return True


+def test_fp8_perchannel_write_buffer(gpu_tp_count):
+    """Test write_weight_scale_to_buffer with FP8 per-channel weights"""
+    torch.manual_seed(123)
+
+    expert_num = 256
+    gpu_experts = expert_num
+    num_experts_per_tok = 8
+    hidden_size = 3072
+    intermediate_size = 1536
+
+    cpuinfer = make_cpu_infer()
+    cfg = build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
+    weights = allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size)
+
+    cfg.gate_proj = weights["gate_q"].data_ptr()
+    cfg.up_proj = weights["up_q"].data_ptr()
+    cfg.down_proj = weights["down_q"].data_ptr()
+    cfg.gate_scale = weights["gate_scale"].data_ptr()
+    cfg.up_scale = weights["up_scale"].data_ptr()
+    cfg.down_scale = weights["down_scale"].data_ptr()
+
+    moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg)
+
+    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
+    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
+    cpuinfer.sync()
+
+    per_mat_weight_bytes = weights["per_mat_weight_bytes"]
+    per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
+    per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]
+
+    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
+    gpu_n_w13 = intermediate_size // gpu_tp_count
+    scale_elems_per_expert_per_tp_gate_up = gpu_n_w13
+    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down
+
+    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
+    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
+    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down
+
+    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
+    w13_scale_bufs = [
+        torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
+    ]
+    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
+    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]
+
+    print(f"[FP8_PERCHANNEL] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
+    print(f"[FP8_PERCHANNEL] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
+    print(f"[FP8_PERCHANNEL] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")
+
+    def get_expert_ptrs(expert_id):
+        w13_weight_ptrs = []
+        w13_scale_ptrs = []
+        w2_weight_ptrs = []
+        w2_scale_ptrs = []
+        for tp_idx in range(gpu_tp_count):
+            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
+            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
+            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
+            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down
+
+            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
+            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)
+            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
+            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)
+        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs
+
+    for _ in range(2):
+        for expert_id in range(gpu_experts):
+            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
+            cpuinfer.submit(
+                moe.write_weight_scale_to_buffer_task(
+                    gpu_tp_count=gpu_tp_count,
+                    expert_id=expert_id,
+                    w13_weight_ptrs=w13_weight_ptrs,
+                    w13_scale_ptrs=w13_scale_ptrs,
+                    w2_weight_ptrs=w2_weight_ptrs,
+                    w2_scale_ptrs=w2_scale_ptrs,
+                )
+            )
+        cpuinfer.sync()
+
+    begin_time = time.perf_counter_ns()
+    for expert_id in range(gpu_experts):
+        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
+        cpuinfer.submit(
+            moe.write_weight_scale_to_buffer_task(
+                gpu_tp_count=gpu_tp_count,
+                expert_id=expert_id,
+                w13_weight_ptrs=w13_weight_ptrs,
+                w13_scale_ptrs=w13_scale_ptrs,
+                w2_weight_ptrs=w2_weight_ptrs,
+                w2_scale_ptrs=w2_scale_ptrs,
+            )
+        )
+    cpuinfer.sync()
+    end_time = time.perf_counter_ns()
+    elapsed_ms = (end_time - begin_time) / 1e6
+
+    total_bytes = (
+        hidden_size * intermediate_size * gpu_experts * 3
+        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
+    )
+    print(f"[FP8_PERCHANNEL] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
+    print(f"[FP8_PERCHANNEL] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
+
+    def split_expert_tensor(tensor, chunk):
+        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
+
+    gate_q = weights["gate_q"]
+    up_q = weights["up_q"]
+    down_q = weights["down_q"]
+    gate_scale = weights["gate_scale"]
+    up_scale = weights["up_scale"]
+    down_scale = weights["down_scale"]
+
+    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
+    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
+    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)
+    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
+    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
+    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)
+
+    for tp_idx in range(gpu_tp_count):
+        expected_w13_weights = []
+        expected_w13_scales = []
+        expected_w2_weights = []
+        expected_w2_scales = []
+
+        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
+        scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count
+
+        for expert_id in range(gpu_experts):
+            start_weight = tp_idx * weight13_per_tp
+            end_weight = (tp_idx + 1) * weight13_per_tp
+            start_scale = tp_idx * scale13_per_tp
+            end_scale = (tp_idx + 1) * scale13_per_tp
+
+            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
+            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
+            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
+            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]
+
+            down_weight_tp_parts = []
+            tp_slice_weight_size = intermediate_size // gpu_tp_count
+
+            for row_idx in range(hidden_size):
+                row_weight_start = row_idx * intermediate_size
+                tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size
+                down_weight_tp_parts.append(
+                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
+                )
+
+            down_weight_tp = torch.cat(down_weight_tp_parts)
+            down_scale_tp = down_scale_experts[expert_id]
+
+            expected_w13_weights.append(gate_weight_tp)
+            expected_w13_weights.append(up_weight_tp)
+            expected_w13_scales.append(gate_scale_tp)
+            expected_w13_scales.append(up_scale_tp)
+            expected_w2_weights.append(down_weight_tp)
+            expected_w2_scales.append(down_scale_tp)
+
+        expected_w13_weight = torch.cat(expected_w13_weights)
+        expected_w13_scale = torch.cat(expected_w13_scales)
+        expected_w2_weight = torch.cat(expected_w2_weights)
+        expected_w2_scale = torch.cat(expected_w2_scales)
+
+        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
+            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
+            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
+            raise AssertionError(f"[FP8_PERCHANNEL] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}")
+
+        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
+            raise AssertionError(f"[FP8_PERCHANNEL] w13 scale mismatch for TP {tp_idx}")
+
+        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
+            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
+            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
+            raise AssertionError(f"[FP8_PERCHANNEL] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}")
+
+        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
+            raise AssertionError(f"[FP8_PERCHANNEL] w2 scale mismatch for TP {tp_idx}")
+
+    print(f"[FP8_PERCHANNEL] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
+    return True
+
+
 def test_bf16_write_buffer(gpu_tp_count):
     """Test write_weight_scale_to_buffer with BF16 weights (no scales)"""
     torch.manual_seed(123)
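For reference, the total_bytes figure that feeds the throughput print works out to roughly 3.6 GB per pass at this test's sizes (256 experts, hidden_size 3072, intermediate_size 1536); the same formula, just evaluated:

    expert_num, hidden_size, intermediate_size = 256, 3072, 1536

    weight_bytes = hidden_size * intermediate_size * expert_num * 3       # gate + up + down, 1 byte per FP8 weight
    scale_bytes = (intermediate_size * 2 + hidden_size) * expert_num * 4  # float32 scales for gate, up, down
    total_bytes = weight_bytes + scale_bytes                              # 3,623,878,656 + 6,291,456 = 3,630,170,112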
@@ -478,6 +707,8 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int):
     """Test write_weight_scale_to_buffer with specified mode and TP count"""
     if quant_mode == "fp8":
         return test_fp8_write_buffer(gpu_tp_count)
+    elif quant_mode == "fp8_perchannel":
+        return test_fp8_perchannel_write_buffer(gpu_tp_count)
     elif quant_mode == "bf16":
         return test_bf16_write_buffer(gpu_tp_count)
     else:
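With this branch in place, the per-channel path can be exercised directly; for example (assuming the test module's imports and helpers are available):

    test_with_tp("fp8_perchannel", gpu_tp_count=2)  # returns True on success, raises AssertionError on a mismatch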
@@ -487,7 +718,7 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int):
 def main(quant_modes=None):
     """Run tests for specified quant modes"""
     if quant_modes is None:
-        quant_modes = ["fp8", "bf16"]
+        quant_modes = ["fp8", "fp8_perchannel", "bf16"]

     tp_values = [1, 2, 4]
     all_passed = True
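The remainder of main() is unchanged and not shown in this hunk. For orientation only, a driver loop of the following shape would be consistent with the pieces visible here (the mode list, tp_values, all_passed, and test_with_tp); this is a sketch, not the file's actual code:

    for quant_mode in quant_modes:
        for tp in tp_values:
            try:
                all_passed = test_with_tp(quant_mode, tp) and all_passed
            except AssertionError as exc:
                print(exc)
                all_passed = False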
@@ -525,10 +756,10 @@ def main(quant_modes=None):
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         mode = sys.argv[1].lower()
-        if mode in ["fp8", "bf16"]:
+        if mode in ["fp8", "fp8_perchannel", "bf16"]:
             main([mode])
         else:
-            print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'")
+            print(f"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'")
             sys.exit(1)
     else:
         main()