Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-06 12:40:02 +00:00)

Commit: support smt and glm4

parent 712ad1fa3c
commit 48bc6185b5

9 changed files with 65 additions and 74 deletions
@@ -80,27 +80,11 @@ class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):

         freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))
         with torch.cuda.stream(current_stream):
             residual = torch.zeros_like(hidden_states)
             for i, decode_layer in enumerate(self.model.layers):
-                if self.model.transfer_map is not None and i in self.model.transfer_map:
-                    prev_stream = torch.cuda.current_stream()
-                    cur_device = self.model.transfer_map[i]
-                    if cur_device not in self.model.stream_device_map:
-                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
-                    torch.cuda.set_device(cur_device)
-                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)
-                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])
-                    hidden_states = hidden_states.to(
-                        self.model.transfer_map[i], non_blocking=True
-                    )
-
-                    batch.minibatch.position_ids = (
-                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
-                        if batch.minibatch.position_ids is not None
-                        else None
-                    )
-                router_input = hidden_states.clone()
                 hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
                 hidden_states = decode_layer.self_attn(hidden_states, self.cache,
                                                        freqs_cis,
@@ -110,9 +94,9 @@ class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):

                 hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
                 if i < self.model.config.first_k_dense_replace:
-                    hidden_states = decode_layer.feed_forward(router_input, hidden_states, num_tokens_tensors)
+                    hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)
                 else:
-                    hidden_states = decode_layer.feed_forward(hidden_states, num_tokens_tensors, cuda_graph_idx)
+                    hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors, cuda_graph_idx)
                 # hidden_states = hidden_states.squeeze(0)

             forward_batch_output = ForwardBatchOutput()
@@ -625,7 +625,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):

             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
-            **kwargs,
+            # **kwargs,
         )

         hidden_states = outputs.last_hidden_state
@@ -635,7 +635,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

         return CausalLMOutputWithPast(
             loss=loss,
@@ -522,7 +522,9 @@ class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):

         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
-            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-            freqs_cis = freqs_cis * self.attention_scaling
-        return freqs_cis
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
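
Note: the rotary embedding now returns a (cos, sin) pair instead of a complex freqs_cis tensor. Below is a minimal, standalone sketch of both formulations with assumed toy sizes (not the module above); the complex view rotates adjacent dimension pairs, while the cos/sin path pairs the first and second halves of the head dimension, so the two agree only up to a fixed permutation of dimensions.

import torch

# Toy sketch: rotary position factors for a small head_dim; sizes are assumptions.
head_dim, seq_len, base = 8, 4, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
pos = torch.arange(seq_len).float()
freqs = torch.outer(pos, inv_freq)                      # (seq_len, head_dim/2)

# Complex formulation (old path): one complex rotation per dimension pair.
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # (seq_len, head_dim/2), complex64

# cos/sin formulation (new path): duplicate the angles across the full head_dim.
emb = torch.cat((freqs, freqs), dim=-1)                 # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()
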
@@ -568,15 +568,31 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):

     def apply_rotary_pos_emb(
         self,
-        xq: torch.Tensor,
-        xk: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
+        unsqueeze_dim=2
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
-        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
-        return xq_out.type_as(xq), xk_out.type_as(xk)
+        # Keep half or full tensor for later concatenation
+        cos = freqs_cis[0]
+        sin = freqs_cis[1]
+        rotary_dim = cos.shape[-1]
+
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
+
+        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+        # Apply rotary embeddings on the first half or full tensor
+        q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+        k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+        # Concatenate back to full shape
+        q_embed = torch.cat([q_embed, q_pass], dim=-1)
+        k_embed = torch.cat([k_embed, k_pass], dim=-1)
+        return q_embed, k_embed

     def forward(self,
                 hidden_states: torch.Tensor,
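
Note: the new apply_rotary_pos_emb relies on a rotate_half helper that is not shown in this hunk. For reference, the usual definition (as in Hugging Face transformers) is sketched below; it is not part of this commit.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
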
@@ -587,18 +603,20 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):

                 position_ids: torch.Tensor = None,
                 ):

-        if self.use_qk_norm:
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
         q_len, _ = hidden_states.size()
         query_states = self.q_proj(hidden_states, bsz_tensors)
         key_states = self.k_proj(hidden_states, bsz_tensors)
         value_states = self.v_proj(hidden_states, bsz_tensors)

-        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states, bsz_tensors)
+            key_states = self.k_norm(key_states, bsz_tensors)
+
+        query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
+        key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)

         # cos, sin = freqs_cis
         """
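
Note: the view() calls above split each flat projection into per-head tensors, with separate query and key/value head counts (grouped-query attention). A standalone sketch of that reshape with assumed toy sizes:

import torch

q_len, head_dim = 6, 16
num_attention_heads, num_key_value_heads = 8, 2   # assumed GQA-style config values

q_flat = torch.randn(q_len, num_attention_heads * head_dim)
k_flat = torch.randn(q_len, num_key_value_heads * head_dim)
v_flat = torch.randn(q_len, num_key_value_heads * head_dim)

# Same reshape as the view() calls in the hunk: tokens x heads x head_dim.
q = q_flat.view(q_len, num_attention_heads, head_dim)
k = k_flat.view(q_len, num_key_value_heads, head_dim)
v = v_flat.view(q_len, num_key_value_heads, head_dim)
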
@@ -607,14 +625,14 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):

         print(cos.shape)
         print(sin.shape)
         """
-        if freqs_cis:
+        if freqs_cis is not None:
             query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)


-        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
+        query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
+        key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)

         k_cache = kv_cache.get_k_cache(self.layer_idx)
         v_cache = kv_cache.get_v_cache(self.layer_idx)
@@ -623,6 +641,6 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):

         attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)


-        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
+        attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)

         return attn_output
@@ -1840,31 +1840,13 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):

         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]

+        topk_idx, topk_weight = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

-        if bsz_tensor is None:
-            router_logits = self.gate(hidden_states)
-        else:
-            router_logits = self.gate(hidden_states, bsz_tensor)
-
-        if router_logits.device.type == "xpu":
-            # TODO: support self.moe_primary_router_apply_softmax False case
-            from ipex_llm.transformers.models.common import moe_softmax_topk
-            selected_experts, routing_weights = moe_softmax_topk(
-                router_logits.half(), self.top_k, self.norm_topk_prob
-            )
-        else:
-            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-            if self.norm_topk_prob:
-                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-            # we cast back to the input dtype
-            routing_weights = routing_weights.to(hidden_states.dtype)
-
         # only for generate phase
         if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
-            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
-            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
+            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
+            y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
             # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

             y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
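
Note: the hunk above replaces an inline softmax/top-k routing block with a gate that returns (topk_idx, topk_weight) directly. A standalone sketch of what such a top-k softmax router computes, with assumed toy sizes (not the actual Glm4Moe gate):

import torch
import torch.nn.functional as F

num_tokens, hidden, num_experts, top_k = 4, 32, 8, 2   # assumed toy sizes
router = torch.nn.Linear(hidden, num_experts, bias=False)

x = torch.randn(num_tokens, hidden)
router_logits = router(x)                                   # (num_tokens, num_experts)
probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(probs, top_k, dim=-1)    # per-token expert ids and weights
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)  # optional renormalization
topk_weight = topk_weight.to(x.dtype)
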
@@ -1873,29 +1855,29 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):

             y.resize_(*orig_shape)
             return y

-        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
+        y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
         # y_ = (
         #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
         # )


         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
-                self.moe_infer(hidden_states, selected_experts, routing_weights)
+                self.moe_infer(hidden_states, topk_idx, topk_weight)
                 .view(*orig_shape)
                 .to(device=hidden_states.device)
             )
         else:
             # TODO may bugs here
             y = (
-                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
+                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
                 .view(*orig_shape)
                 .to(device=hidden_states.device)
             )
-        # y += y_
+        y += y_
         return y

     @torch.no_grad()
@@ -64,7 +64,7 @@ class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule):

                  generate_device: str = "cuda",
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-        self.orig_module.__init__(orig_module.config)
+        self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size)
     def forward(self, x, bsz_tensor):
         down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
         return down_proj
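
Note: the KGlm4MoeMLP forward above is a gated (SwiGLU-style) MLP. A self-contained sketch with plain nn.Linear layers standing in for the injected projections; the sizes and the SiLU activation are assumptions for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 256   # assumed toy sizes

gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(4, hidden_size)
# down_proj(act(gate_proj(x)) * up_proj(x)), the same shape of computation as the forward above.
y = down_proj(F.silu(gate_proj(x)) * up_proj(x))
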
@@ -60,7 +60,7 @@

 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
   replace:
-    class: ktransformers.operators.balance_serve_attention.KSmallthinkerAttention # optimized MLA implementation
+    class: ktransformers.operators.balance_serve_attention.KGlm4MoeAttention # optimized MLA implementation
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
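
Note: the injection rule above selects modules by matching their dotted path against a regular expression and swapping in the listed class. A minimal sketch of that matching step in Python; the module names below are made-up examples.

import re

pattern = r"^model\.layers\..*\.self_attn$"
names = [
    "model.layers.0.self_attn",          # matches: would be replaced by KGlm4MoeAttention
    "model.layers.17.self_attn",         # matches
    "model.layers.0.self_attn.q_proj",   # no match: the rule targets the attention module itself
]
for name in names:
    print(name, bool(re.match(pattern, name)))
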
@@ -462,6 +462,8 @@ class BalanceServeInterface(BackendInterfaceBase):

         profiler.create_and_start_timer("prefill")

         query_add = sched_ext.QueryAdd()
+        input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
+                                   99052, 101052, 11314]], device='cuda')
         query_add.query_token = input_ids[0].tolist()
         query_length = input_ids[0].shape[0]
         query_add.query_length = query_length
@@ -149,7 +149,7 @@ if __name__ == "__main__":

     parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
-    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
+    parser.add_argument("--max_tokens", type=int, default=500, help="max decode tokens")

     args = parser.parse_args()
     SERVER_URL = args.api_url
@@ -161,5 +161,8 @@ if __name__ == "__main__":

         prompt = ktansformer_prompt1024 * 2
     elif args.prompt_lens == 4096:
         prompt = ktansformer_prompt1024 * 4
+
+    prompt = "介绍秦始皇"
+
     asyncio.run(main(args.concurrent, prompt, max_tokens, model))