mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 23:34:35 +00:00)

GLM4 and SmallThinker

This commit is contained in:
parent c7307aa0ae
commit 9e1560bb82

7 changed files with 58 additions and 37 deletions
@@ -10,10 +10,10 @@ message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}")
project(balance_serve VERSION 0.1.0)

set(CMAKE_CXX_STANDARD 20)
# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
# set(CMAKE_BUILD_TYPE "Debug")
set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
set(CMAKE_BUILD_TYPE "Release")
set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
set(CMAKE_BUILD_TYPE "Debug")
# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
# set(CMAKE_BUILD_TYPE "Release")

if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)

@@ -69,6 +69,12 @@ static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
  return _mm512_mul_ps(act_val, up_val);
}

static inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) {
  __m512 zero_vec = _mm512_setzero_ps();
  __m512 act_val = _mm512_max_ps(zero_vec, gate_val);
  return _mm512_mul_ps(act_val, up_val);
}

struct AMX_MOEConfig {
  int expert_num;
  int routed_expert_num;

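For orientation, here is a minimal NumPy sketch of the two gated activations these AVX-512 helpers compute: act_fn appears to be the silu-gated product (consistent with the use_silu flag used below), while the new relu_act_fn is the relu-gated product. The Python names are illustrative only, not part of the repo.

import numpy as np

def silu_gated(gate, up):
    # silu(x) = x * sigmoid(x), applied to the gate projection, then multiplied by up
    return (gate * (1.0 / (1.0 + np.exp(-gate)))) * up

def relu_gated(gate, up):
    # relu_act_fn above: max(0, gate) * up
    return np.maximum(0.0, gate) * up

gate = np.random.randn(32).astype(np.float32)
up = np.random.randn(32).astype(np.float32)
print(silu_gated(gate, up)[:4])
print(relu_gated(gate, up)[:4])
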
@@ -337,6 +343,7 @@ public:
  gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
  up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
  auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
  if (config_.use_silu) {
    for (int i = 0; i < m_local_num_[expert_idx]; i++) {
      ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
      ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];

@@ -349,6 +356,22 @@ public:
          avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
        }
      }
    }
    else {
      for (int i = 0; i < m_local_num_[expert_idx]; i++) {
        ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
        ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
        for (int j = n_start; j < n_end; j += 32) {
          __m512 gate_val0, gate_val1, up_val0, up_val1;
          avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
          avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
          __m512 result0 = relu_act_fn(gate_val0, up_val0);
          __m512 result1 = relu_act_fn(gate_val1, up_val1);
          avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
        }
      }
    }

  },
  nullptr);
backend->do_work_stealing_job(

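The new else branch mirrors the silu path but calls relu_act_fn: each expert row is walked 32 bf16 values at a time, converted to two 16-wide fp32 vectors, activated, and written back into the gate buffer. A rough Python analogue of that loop structure, with assumed shapes (a sketch, not the kernel):

import numpy as np

def relu_gate_rows(gate_out, up_out, n_start, n_end):
    # gate_out / up_out: [rows, intermediate_size]; this thread owns columns [n_start, n_end)
    for i in range(gate_out.shape[0]):
        for j in range(n_start, n_end, 32):             # the kernel steps 32 bf16 values per iteration
            g = gate_out[i, j:j + 32]
            u = up_out[i, j:j + 32]
            gate_out[i, j:j + 32] = np.maximum(0.0, g) * u  # result stored back into the gate buffer

gate = np.random.randn(4, 64).astype(np.float32)
up = np.random.randn(4, 64).astype(np.float32)
relu_gate_rows(gate, up, 0, 64)
print(gate[0, :4])
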
@@ -194,7 +194,7 @@ class KExpertsCPU(KExpertsBase):
            64,
            10,
            1024,
            self.config.model_type != "smallthinker",
            self.config.hidden_act == 'silu',
            gate_ptr,
            up_ptr,
            down_ptr,

@@ -215,7 +215,7 @@ class KExpertsCPU(KExpertsBase):
            self.config.hidden_size,
            self.config.moe_intermediate_size,
            max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
            self.config.model_type != "smallthinker",
            self.config.hidden_act == 'silu',
            gate_ptr,
            up_ptr,
            down_ptr,

@@ -234,7 +234,7 @@ class KExpertsCPU(KExpertsBase):
            self.config.hidden_size,
            self.config.moe_intermediate_size,
            max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
            self.config.model_type != "smallthinker",
            self.config.hidden_act == 'silu',
            gate_ptr,
            up_ptr,
            down_ptr,

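The two boolean arguments in these calls come from the Hugging Face model config. A hedged sketch of where those values live (the model path is a placeholder, and the exact purpose of the first flag is not visible in these hunks):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("/path/to/model", trust_remote_code=True)  # placeholder path
not_smallthinker = cfg.model_type != "smallthinker"       # first boolean forwarded above
use_silu = getattr(cfg, "hidden_act", "silu") == "silu"   # second boolean; selects silu vs relu gating
print(cfg.model_type, not_smallthinker, use_silu)
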
@@ -88,9 +88,16 @@ class KLinearBase(ABC):
        if isinstance(self.gguf_loader, SafeTensorLoader):
            # using safetensor_loader
            tensor = self.gguf_loader.load_tensor(key+'.weight')
            try:
                bias = self.gguf_loader.load_tensor(key+'.bias')
            except:
                bias = None
            if self.gguf_loader.has_tensor(key+'.weight_scale_inv'):
                weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv')
                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
            if bias is not None:
                return nn.Parameter(tensor), nn.Parameter(bias)
            else:
                return nn.Parameter(tensor)

        elif self.gguf_loader.has_tensor(key + ".weight") or "kv_b_proj" in key:

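Because this branch can now return a lone weight, a (weight, bias) pair, or a (weight, weight_scale_inv) pair, callers have to cope with a variable-arity result. A hypothetical caller-side sketch (the return contract is inferred from the hunk above, not quoted from the repo):

def unpack_loaded(result):
    # result: a single nn.Parameter, or a (weight, bias) / (weight, weight_scale_inv) tuple,
    # matching the three return shapes in the SafeTensorLoader branch above
    if isinstance(result, tuple):
        weight, extra = result
    else:
        weight, extra = result, None
    return weight, extra
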
@@ -138,17 +138,11 @@ class ArgumentParser:
        self.cfg.server_ip = args.host
        self.cfg.server_port = args.port
        self.cfg.user_force_think = args.force_think

        if args.model_name == "Qwen3MoeForCausalLM":
            model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.model_name == "Glm4MoeForCausalLM":
            model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.model_name == "SmallThinkerForCausalLM":
            model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            model_config._attn_implementation = "eager"
        else:
            try:
                model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            except:
                try:
                    model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
                except:
                    raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")

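Purely as an illustration of the same dispatch, the if/elif chain above could be expressed as a lookup table; the repo keeps the explicit branching and the nested AutoConfig/Glm4MoeConfig fallback, so this is a sketch under assumed imports, not the project's code:

from transformers import AutoConfig

def load_model_config(model_name, model_dir, classes):
    # classes: {"Qwen3MoeForCausalLM": Qwen3MoeConfig, "Glm4MoeForCausalLM": Glm4MoeConfig, ...},
    # built from the config classes already imported in this file
    cls = classes.get(model_name, AutoConfig)
    cfg = cls.from_pretrained(model_dir, trust_remote_code=True)
    if model_name == "SmallThinkerForCausalLM":
        cfg._attn_implementation = "eager"  # SmallThinker is forced onto eager attention
    return cfg
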
@@ -129,13 +129,13 @@ class Engine:
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []

        print(f"args.model_name: {args.model_name}")
        print(f"args.architectures: {args.architectures}")

        if args.model_name == "Qwen3MoeForCausalLM":
        if args.architectures == "Qwen3MoeForCausalLM":
            config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.model_name == "Glm4MoeForCausalLM":
        elif args.architectures == "Glm4MoeForCausalLM":
            config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        elif args.model_name == "SmallThinkerForCausalLM":
        elif args.architectures == "SmallThinkerForCausalLM":
            config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            config._attn_implementation = "eager"
            config.moe_intermediate_size = config.moe_ffn_hidden_size

@@ -143,7 +143,7 @@ class Engine:
            try:
                config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            except:
                raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
                raise ValueError(f"Model {args.architectures} not supported. Please check your model directory or model name.")

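args.architectures, which replaces args.model_name as the dispatch key here, presumably mirrors the architectures entry of the checkpoint's Hugging Face config.json. A small sketch of reading that field (the path is a placeholder):

import json

with open("/path/to/model/config.json") as f:    # placeholder checkpoint path
    arch = json.load(f)["architectures"][0]      # e.g. "Glm4MoeForCausalLM" or "SmallThinkerForCausalLM"
print(arch)
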
@@ -463,8 +463,6 @@ class BalanceServeInterface(BackendInterfaceBase):
        profiler.create_and_start_timer("prefill")

        query_add = sched_ext.QueryAdd()
        # input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
        # 99052, 101052, 11314]], device='cuda')
        query_add.query_token = input_ids[0].tolist()
        query_length = input_ids[0].shape[0]
        query_add.query_length = query_length

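For context, a hedged sketch of how a prompt typically becomes the two QueryAdd fields set above; the toy token ids are taken from the commented-out debug tensor in the hunk, and sched_ext is assumed to be the scheduler binding imported in the surrounding file:

import torch

input_ids = torch.tensor([[151331, 151333, 98964]])   # toy prompt ids, shape [1, seq_len]
query_add = sched_ext.QueryAdd()
query_add.query_token = input_ids[0].tolist()          # prompt token ids as a plain Python list
query_add.query_length = input_ids[0].shape[0]         # number of prompt tokens
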
@@ -162,7 +162,6 @@ if __name__ == "__main__":
    elif args.prompt_lens == 4096:
        prompt = ktansformer_prompt1024 * 4

    prompt = "介绍秦始皇"

    asyncio.run(main(args.concurrent, prompt, max_tokens, model))