Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 14:51:06 +00:00)
1) Linear and MLP operators support qlen>1; 2) All operators now share a single memory buffer; 3) Refactor CPUInfer submit/sync logic.
commit c1cc7d2cd2 (parent 442e13bc97)
21 changed files with 749 additions and 731 deletions
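The third item in the commit message is visible throughout the diff below: work is no longer handed to the backend as `CPUInfer.submit(bound_method, arg1, arg2, ...)`; instead the caller invokes the operator method itself and passes the result to `submit()`, then waits with `sync()` as before. A minimal sketch of the new calling convention, assuming a built `cpuinfer_ext` extension on the Python path and a `moe` operator plus contiguous fp16 tensors prepared exactly as in the validation loop shown in the diff:

# Sketch only: cpuinfer_ext must be built first, and moe, expert_ids, weights,
# input, output are assumed to be set up as in the validation loop below.
import cpuinfer_ext

CPUInfer = cpuinfer_ext.CPUInfer(48)   # worker-thread count used by the tests

# Old style (removed by this commit): pass the bound method and its arguments.
# CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(),
#                 weights.data_ptr(), input.data_ptr(), output.data_ptr())

# New style: calling moe.forward(...) appears to package the task; submit() enqueues it.
CPUInfer.submit(
    moe.forward(
        qlen,
        n_routed_experts,
        expert_ids.data_ptr(),
        weights.data_ptr(),
        input.data_ptr(),
        output.data_ptr()
    )
)
CPUInfer.sync()   # block until the enqueued task has completed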
@@ -6,7 +6,7 @@ Author : chenht2022
 Date : 2024-07-25 10:32:05
 Version : 1.0.0
 LastEditors : chenht2022
-LastEditTime : 2024-07-25 10:34:06
+LastEditTime : 2024-08-06 10:38:05
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
 import os, sys
@@ -15,25 +15,64 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
 import cpuinfer_ext
 import torch
 
-with torch.inference_mode(mode=True):
-    expert_num = 10
-    hidden_size = 5120
-    intermediate_size = 1536
-    stride = 32
-    group_min_len = 10
-    group_max_len = 1024
-    gate_type = 1 # ggml_type::GGML_TYPE_F16
-    up_type = 1 # ggml_type::GGML_TYPE_F16
-    down_type = 1 # ggml_type::GGML_TYPE_F16
-    hidden_type = 1 # ggml_type::GGML_TYPE_F16
-    n_routed_experts = 6
-    qlen = 30
-    layer_num = 10
-    CPUInfer = cpuinfer_ext.CPUInfer(48)
-    validation_iter = 100
-    warm_up_iter = 1000
-    test_iter = 10000
+expert_num = 160
+hidden_size = 5120
+intermediate_size = 1536
+stride = 32
+group_min_len = 10
+group_max_len = 1024
+gate_type = 1 # ggml_type::GGML_TYPE_F16
+up_type = 1 # ggml_type::GGML_TYPE_F16
+down_type = 1 # ggml_type::GGML_TYPE_F16
+hidden_type = 1 # ggml_type::GGML_TYPE_F16
+n_routed_experts = 6
+qlen = 30
+layer_num = 10
+CPUInfer = cpuinfer_ext.CPUInfer(48)
+validation_iter = 100
+
+def act_fn(x):
+    return x / (1.0 + torch.exp(-x))
+
+def mlp_torch(input, gate_proj, up_proj, down_proj):
+    gate_buf = torch.mm(input, gate_proj.t())
+    up_buf = torch.mm(input, up_proj.t())
+    intermediate = act_fn(gate_buf) * up_buf
+    ret = torch.mm(intermediate, down_proj.t())
+    return ret
+
+def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
+    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
+    cnts.scatter_(1, expert_ids, 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = expert_ids.view(-1).argsort()
+    sorted_tokens = input[idxs // expert_ids.shape[1]]
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+
+    new_x = torch.empty_like(outs)
+    new_x[idxs] = outs
+    t_output = (
+        new_x.view(*expert_ids.shape, -1)
+        .type(weights.dtype)
+        .mul_(weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return t_output
 
+with torch.inference_mode(mode=True):
     moes = []
     gate_projs = []
     up_projs = []
@@ -51,63 +90,32 @@ with torch.inference_mode(mode=True):
 
     # validation
    for i in range(validation_iter):
-        moe = moes[i % layer_num]
-        expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
+        expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()
         weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
-        input = torch.randn((qlen, 1, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((qlen, 1, hidden_size), dtype=torch.float16).contiguous()
+        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
+        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
         input = input / 100
-
-        CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
+        moe = moes[i % layer_num]
+        CPUInfer.submit(
+            moe.forward(
+                qlen,
+                n_routed_experts,
+                expert_ids.data_ptr(),
+                weights.data_ptr(),
+                input.data_ptr(),
+                output.data_ptr()
+            )
+        )
         CPUInfer.sync()
         # print('cpuinfer output', output)
 
-        def act_fn(x):
-            return x / (1.0 + torch.exp(-x))
-        t_output = torch.zeros((qlen, 1, hidden_size), dtype=torch.float32).contiguous()
         gate_proj = gate_projs[i%layer_num]
         up_proj = up_projs[i%layer_num]
         down_proj = down_projs[i%layer_num]
-        for token_idx in range(qlen):
-            for i, expert_id in enumerate(expert_ids[token_idx]):
-                gate_buf = torch.mm(input[token_idx], gate_proj[expert_id].t())
-                up_buf = torch.mm(input[token_idx], up_proj[expert_id].t())
-                intermediate = act_fn(gate_buf) * up_buf
-                expert_output = torch.mm(intermediate, down_proj[expert_id].t())
-                t_output[token_idx] += weights[token_idx][i] * expert_output
+        t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)
         # print('torch output', t_output)
 
         diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
         print('diff = ', diff)
         assert(diff < 0.001)
-
-    # warm up
-    for i in range(warm_up_iter):
-        moe = moes[i % layer_num]
-        expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
-        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
-        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-
-    # test
-    total_time = 0
-    for i in range(test_iter):
-        moe = moes[i % layer_num]
-        expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
-        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
-        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
-        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
-        input = input / 100
-        start = time.perf_counter()
-        CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
-        CPUInfer.sync()
-        end = time.perf_counter()
-        total_time += end - start
-    print('Time: ', total_time)
-    print('Iteration: ', test_iter)
-    print('Time per iteration: ', total_time / test_iter)
-    print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
+    print("All tasks completed.")