GLM4 and SmallThinker

qiyuxinlin 2025-07-25 16:56:36 +00:00
parent c7307aa0ae
commit 9e1560bb82
7 changed files with 58 additions and 37 deletions

View file

@@ -10,10 +10,10 @@ message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}")
 project(balance_serve VERSION 0.1.0)
 set(CMAKE_CXX_STANDARD 20)
-# set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
-# set(CMAKE_BUILD_TYPE "Debug")
-set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
-set(CMAKE_BUILD_TYPE "Release")
+set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
+set(CMAKE_BUILD_TYPE "Debug")
+# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
+# set(CMAKE_BUILD_TYPE "Release")
 if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)

View file

@@ -69,6 +69,12 @@ static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
   return _mm512_mul_ps(act_val, up_val);
 }
 
+static inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) {
+  __m512 zero_vec = _mm512_setzero_ps();
+  __m512 act_val = _mm512_max_ps(zero_vec, gate_val);
+  return _mm512_mul_ps(act_val, up_val);
+}
+
 struct AMX_MOEConfig {
   int expert_num;
   int routed_expert_num;
@@ -337,18 +343,35 @@ public:
           gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
           up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
           auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
-          for (int i = 0; i < m_local_num_[expert_idx]; i++) {
-            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
-            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
-            for (int j = n_start; j < n_end; j += 32) {
-              __m512 gate_val0, gate_val1, up_val0, up_val1;
-              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
-              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
-              __m512 result0 = act_fn(gate_val0, up_val0);
-              __m512 result1 = act_fn(gate_val1, up_val1);
-              avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
-            }
-          }
+          if (config_.use_silu) {
+            for (int i = 0; i < m_local_num_[expert_idx]; i++) {
+              ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
+              ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
+              for (int j = n_start; j < n_end; j += 32) {
+                __m512 gate_val0, gate_val1, up_val0, up_val1;
+                avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
+                avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
+                __m512 result0 = act_fn(gate_val0, up_val0);
+                __m512 result1 = act_fn(gate_val1, up_val1);
+                avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
+              }
+            }
+          }
+          else {
+            for (int i = 0; i < m_local_num_[expert_idx]; i++) {
+              ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
+              ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
+              for (int j = n_start; j < n_end; j += 32) {
+                __m512 gate_val0, gate_val1, up_val0, up_val1;
+                avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);
+                avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);
+                __m512 result0 = relu_act_fn(gate_val0, up_val0);
+                __m512 result1 = relu_act_fn(gate_val1, up_val1);
+                avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));
+              }
+            }
+          }
         },
         nullptr);
     backend->do_work_stealing_job(
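
For readers skimming the kernel change: act_fn is the existing SiLU-gated path (assumed from the surrounding file; only its tail is visible in the first hunk), while the new relu_act_fn computes max(0, gate) * up, and the config_.use_silu flag picks between them. A minimal NumPy sketch of the per-element math the two AVX-512 paths implement; this is an illustrative reference, not code from the repo:

import numpy as np

def silu_gate(gate: np.ndarray, up: np.ndarray) -> np.ndarray:
    # Assumed semantics of act_fn: silu(gate) * up
    return gate / (1.0 + np.exp(-gate)) * up

def relu_gate(gate: np.ndarray, up: np.ndarray) -> np.ndarray:
    # Semantics of relu_act_fn above: max(0, gate) * up
    return np.maximum(0.0, gate) * up

def gated_activation(gate: np.ndarray, up: np.ndarray, use_silu: bool) -> np.ndarray:
    # Mirrors the config_.use_silu branch in the expert loop
    return silu_gate(gate, up) if use_silu else relu_gate(gate, up)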

View file

@@ -194,7 +194,7 @@ class KExpertsCPU(KExpertsBase):
             64,
             10,
             1024,
-            self.config.model_type != "smallthinker",
+            self.config.hidden_act == 'silu',
             gate_ptr,
             up_ptr,
             down_ptr,
@@ -215,7 +215,7 @@ class KExpertsCPU(KExpertsBase):
             self.config.hidden_size,
             self.config.moe_intermediate_size,
             max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
-            self.config.model_type != "smallthinker",
+            self.config.hidden_act == 'silu',
             gate_ptr,
             up_ptr,
             down_ptr,
@@ -234,7 +234,7 @@ class KExpertsCPU(KExpertsBase):
             self.config.hidden_size,
             self.config.moe_intermediate_size,
             max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,
-            self.config.model_type != "smallthinker",
+            self.config.hidden_act == 'silu',
             gate_ptr,
             up_ptr,
             down_ptr,
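
The flag passed to the CPU MoE kernel now keys off the activation recorded in the model config rather than the model type, so any architecture whose hidden_act is not 'silu' (SmallThinker among them, going by the check it replaces) takes the ReLU path. A hedged sketch of deriving the same boolean straight from a model directory; the helper name is made up for illustration:

from transformers import AutoConfig

def use_silu_flag(model_dir: str) -> bool:
    # Hypothetical helper: reproduce `self.config.hidden_act == 'silu'`
    # from a HuggingFace-style model directory.
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    return getattr(config, "hidden_act", "silu") == "silu"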

View file

@@ -88,10 +88,17 @@ class KLinearBase(ABC):
         if isinstance(self.gguf_loader, SafeTensorLoader):
             # using safetensor_loader
             tensor = self.gguf_loader.load_tensor(key+'.weight')
+            try:
+                bias = self.gguf_loader.load_tensor(key+'.bias')
+            except:
+                bias = None
             if self.gguf_loader.has_tensor(key+'.weight_scale_inv'):
                 weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv')
                 return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
-            return nn.Parameter(tensor)
+            if bias is not None:
+                return nn.Parameter(tensor), nn.Parameter(bias)
+            else:
+                return nn.Parameter(tensor)
         elif self.gguf_loader.has_tensor(key + ".weight") or "kv_b_proj" in key:
             if key + ".bias" in self.gguf_loader.tensor_file_map:
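
Note that this branch now has three return shapes: a lone weight, (weight, weight_scale_inv), or (weight, bias). A minimal caller-side sketch that normalizes the tuple-or-tensor return; this helper is hypothetical, not code from the repo:

def unpack_loaded(loaded):
    # Hypothetical normalizer for the loader branch above: it may return a
    # single nn.Parameter, or a (weight, extra) pair where `extra` is either
    # a bias or a weight_scale_inv tensor, depending on the checkpoint.
    if isinstance(loaded, tuple):
        weight, extra = loaded
    else:
        weight, extra = loaded, None
    return weight, extra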

View file

@@ -138,17 +138,11 @@ class ArgumentParser:
         self.cfg.server_ip = args.host
         self.cfg.server_port = args.port
         self.cfg.user_force_think = args.force_think
-        if args.model_name == "Qwen3MoeForCausalLM":
-            model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        elif args.model_name == "Glm4MoeForCausalLM":
-            model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        elif args.model_name == "SmallThinkerForCausalLM":
-            model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-            model_config._attn_implementation = "eager"
-        else:
-            try:
-                model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-            except:
+        try:
+            model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        except:
+            try:
+                model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            except:
                 raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")

View file

@@ -129,13 +129,13 @@ class Engine:
         self.sched_client = SchedulerClient(args.sched_port)
         self.updates = []
-        print(f"args.model_name: {args.model_name}")
-        if args.model_name == "Qwen3MoeForCausalLM":
+        print(f"args.architectures: {args.architectures}")
+        if args.architectures == "Qwen3MoeForCausalLM":
             config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        elif args.model_name == "Glm4MoeForCausalLM":
+        elif args.architectures == "Glm4MoeForCausalLM":
             config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        elif args.model_name == "SmallThinkerForCausalLM":
+        elif args.architectures == "SmallThinkerForCausalLM":
             config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
             config._attn_implementation = "eager"
             config.moe_intermediate_size = config.moe_ffn_hidden_size
@@ -143,7 +143,7 @@ class Engine:
             try:
                 config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
             except:
-                raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
+                raise ValueError(f"Model {args.architectures} not supported. Please check your model directory or model name.")
@@ -463,8 +463,6 @@ class BalanceServeInterface(BackendInterfaceBase):
         profiler.create_and_start_timer("prefill")
         query_add = sched_ext.QueryAdd()
-        # input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
-        #     99052, 101052, 11314]], device='cuda')
         query_add.query_token = input_ids[0].tolist()
         query_length = input_ids[0].shape[0]
         query_add.query_length = query_length
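
args.architectures presumably carries the architecture string from the model's config.json (e.g. "Glm4MoeForCausalLM"), which is what the Engine now dispatches on instead of a user-supplied model name. A hedged sketch of recovering that value from a local model directory; the helper is illustrative only:

from transformers import AutoConfig

def read_architecture(model_dir: str) -> str:
    # Hypothetical helper: fetch the first entry of `architectures` from
    # config.json, e.g. "Qwen3MoeForCausalLM" or "SmallThinkerForCausalLM".
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    architectures = getattr(config, "architectures", None) or []
    return architectures[0] if architectures else ""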

View file

@@ -162,7 +162,6 @@ if __name__ == "__main__":
     elif args.prompt_lens == 4096:
         prompt = ktansformer_prompt1024 * 4
-    prompt = "介绍秦始皇"
     asyncio.run(main(args.concurrent, prompt, max_tokens, model))