mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-04 03:29:49 +00:00
support smt and qlm4
This commit is contained in:
parent
712ad1fa3c
commit
48bc6185b5
9 changed files with 65 additions and 74 deletions
|
@ -80,27 +80,11 @@ class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):
|
|||
|
||||
freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))
|
||||
|
||||
|
||||
with torch.cuda.stream(current_stream):
|
||||
residual = torch.zeros_like(hidden_states)
|
||||
for i, decode_layer in enumerate(self.model.layers):
|
||||
if self.model.transfer_map is not None and i in self.model.transfer_map:
|
||||
prev_stream = torch.cuda.current_stream()
|
||||
cur_device = self.model.transfer_map[i]
|
||||
if cur_device not in self.model.stream_device_map:
|
||||
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
|
||||
torch.cuda.set_device(cur_device)
|
||||
self.model.stream_device_map[cur_device].wait_stream(prev_stream)
|
||||
torch.cuda.set_stream(self.model.stream_device_map[cur_device])
|
||||
hidden_states = hidden_states.to(
|
||||
self.model.transfer_map[i], non_blocking=True
|
||||
)
|
||||
|
||||
batch.minibatch.position_ids = (
|
||||
batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
|
||||
if batch.minibatch.position_ids is not None
|
||||
else None
|
||||
)
|
||||
router_input = hidden_states.clone()
|
||||
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
|
||||
freqs_cis,
|
||||
|
@ -110,9 +94,9 @@ class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):
|
|||
|
||||
hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
if i < self.model.config.first_k_dense_replace:
|
||||
hidden_states = decode_layer.feed_forward(router_input, hidden_states, num_tokens_tensors)
|
||||
hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)
|
||||
else:
|
||||
hidden_states = decode_layer.feed_forward(hidden_states, num_tokens_tensors, cuda_graph_idx)
|
||||
hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors, cuda_graph_idx)
|
||||
# hidden_states = hidden_states.squeeze(0)
|
||||
|
||||
forward_batch_output = ForwardBatchOutput()
|
||||
|
|
|
@ -625,7 +625,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
|
|||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
# **kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
@ -635,7 +635,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
|
|||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
|
|
@ -522,7 +522,9 @@ class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
|
|||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
freqs_cis = freqs_cis * self.attention_scaling
|
||||
return freqs_cis
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
@ -568,15 +568,31 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
|
|||
|
||||
def apply_rotary_pos_emb(
|
||||
self,
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
||||
unsqueeze_dim=2
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
# Keep half or full tensor for later concatenation
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
rotary_dim = cos.shape[-1]
|
||||
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
|
||||
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
|
||||
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
|
||||
|
||||
# Apply rotary embeddings on the first half or full tensor
|
||||
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
|
||||
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
|
||||
|
||||
# Concatenate back to full shape
|
||||
q_embed = torch.cat([q_embed, q_pass], dim=-1)
|
||||
k_embed = torch.cat([k_embed, k_pass], dim=-1)
|
||||
return q_embed, k_embed
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
@ -587,18 +603,20 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
|
|||
position_ids: torch.Tensor = None,
|
||||
):
|
||||
|
||||
if self.use_qk_norm:
|
||||
query_states = self.q_norm(query_states)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states, bsz_tensors)
|
||||
key_states = self.k_proj(hidden_states, bsz_tensors)
|
||||
value_states = self.v_proj(hidden_states, bsz_tensors)
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
if self.use_qk_norm:
|
||||
query_states = self.q_norm(query_states, bsz_tensors)
|
||||
key_states = self.k_norm(key_states, bsz_tensors)
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
|
||||
# cos, sin = freqs_cis
|
||||
"""
|
||||
|
@ -607,14 +625,14 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
|
|||
print(cos.shape)
|
||||
print(sin.shape)
|
||||
"""
|
||||
if freqs_cis:
|
||||
if freqs_cis is not None:
|
||||
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
|
||||
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
|
||||
|
||||
k_cache = kv_cache.get_k_cache(self.layer_idx)
|
||||
v_cache = kv_cache.get_v_cache(self.layer_idx)
|
||||
|
@ -623,6 +641,6 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
|
|||
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
|
||||
|
||||
|
||||
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
|
||||
attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)
|
||||
|
||||
return attn_output
|
|
@ -1840,31 +1840,13 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
|
|||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
if bsz_tensor is None:
|
||||
router_logits = self.gate(hidden_states)
|
||||
else:
|
||||
router_logits = self.gate(hidden_states, bsz_tensor)
|
||||
|
||||
if router_logits.device.type == "xpu":
|
||||
# TODO: support self.moe_primary_router_apply_softmax False case
|
||||
from ipex_llm.transformers.models.common import moe_softmax_topk
|
||||
selected_experts, routing_weights = moe_softmax_topk(
|
||||
router_logits.half(), self.top_k, self.norm_topk_prob
|
||||
)
|
||||
else:
|
||||
routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
if self.norm_topk_prob:
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
# only for generate phase
|
||||
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
|
||||
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
|
||||
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
|
||||
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
|
||||
|
@ -1873,29 +1855,29 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
|
|||
y.resize_(*orig_shape)
|
||||
return y
|
||||
|
||||
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = (
|
||||
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
# )
|
||||
|
||||
|
||||
if isinstance(self.experts, KExpertsBase):
|
||||
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
|
||||
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
|
||||
elif hidden_states.size(0) > 10:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer(hidden_states, selected_experts, routing_weights)
|
||||
self.moe_infer(hidden_states, topk_idx, topk_weight)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
else:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
|
||||
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
# y += y_
|
||||
y += y_
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -64,7 +64,7 @@ class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule):
|
|||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config)
|
||||
self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size)
|
||||
def forward(self, x, bsz_tensor):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
|
||||
return down_proj
|
|
@ -60,7 +60,7 @@
|
|||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.balance_serve_attention.KSmallthinkerAttention # optimized MLA implementation
|
||||
class: ktransformers.operators.balance_serve_attention.KGlm4MoeAttention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
|
|
@ -462,6 +462,8 @@ class BalanceServeInterface(BackendInterfaceBase):
|
|||
profiler.create_and_start_timer("prefill")
|
||||
|
||||
query_add = sched_ext.QueryAdd()
|
||||
input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
|
||||
99052, 101052, 11314]], device='cuda')
|
||||
query_add.query_token = input_ids[0].tolist()
|
||||
query_length = input_ids[0].shape[0]
|
||||
query_add.query_length = query_length
|
||||
|
|
|
@ -149,7 +149,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
|
||||
parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
|
||||
parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
|
||||
parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
|
||||
parser.add_argument("--max_tokens", type=int, default=500, help="max decode tokens")
|
||||
|
||||
args = parser.parse_args()
|
||||
SERVER_URL = args.api_url
|
||||
|
@ -161,5 +161,8 @@ if __name__ == "__main__":
|
|||
prompt = ktansformer_prompt1024 * 2
|
||||
elif args.prompt_lens == 4096:
|
||||
prompt = ktansformer_prompt1024 * 4
|
||||
|
||||
prompt = "介绍秦始皇"
|
||||
|
||||
asyncio.run(main(args.concurrent, prompt, max_tokens, model))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue