diff --git a/ktransformers/models/custom_modeling_glm4_moe.py b/ktransformers/models/custom_modeling_glm4_moe.py
index b5986c3..f33c177 100644
--- a/ktransformers/models/custom_modeling_glm4_moe.py
+++ b/ktransformers/models/custom_modeling_glm4_moe.py
@@ -80,27 +80,11 @@ class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):
 
         freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))
 
+        with torch.cuda.stream(current_stream):
         residual = torch.zeros_like(hidden_states)
 
         for i, decode_layer in enumerate(self.model.layers):
-            if self.model.transfer_map is not None and i in self.model.transfer_map:
-                prev_stream = torch.cuda.current_stream()
-                cur_device = self.model.transfer_map[i]
-                if cur_device not in self.model.stream_device_map:
-                    self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
-                torch.cuda.set_device(cur_device)
-                self.model.stream_device_map[cur_device].wait_stream(prev_stream)
-                torch.cuda.set_stream(self.model.stream_device_map[cur_device])
-                hidden_states = hidden_states.to(
-                    self.model.transfer_map[i], non_blocking=True
-                )
-                batch.minibatch.position_ids = (
-                    batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
-                    if batch.minibatch.position_ids is not None
-                    else None
-                )
-            router_input = hidden_states.clone()
             hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
             hidden_states = decode_layer.self_attn(hidden_states, self.cache, freqs_cis,
@@ -110,9 +94,9 @@ class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):
             hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
             if i < self.model.config.first_k_dense_replace:
-                hidden_states = decode_layer.feed_forward(router_input, hidden_states, num_tokens_tensors)
+                hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)
             else:
-                hidden_states = decode_layer.feed_forward(hidden_states, num_tokens_tensors, cuda_graph_idx)
+                hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors, cuda_graph_idx)
 
         # hidden_states = hidden_states.squeeze(0)
         forward_batch_output = ForwardBatchOutput()
diff --git a/ktransformers/models/modeling_glm4_moe.py b/ktransformers/models/modeling_glm4_moe.py
index 9709ada..32727a8 100644
--- a/ktransformers/models/modeling_glm4_moe.py
+++ b/ktransformers/models/modeling_glm4_moe.py
@@ -625,7 +625,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
-            **kwargs,
+            # **kwargs,
         )
 
         hidden_states = outputs.last_hidden_state
@@ -635,7 +635,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
         return CausalLMOutputWithPast(
             loss=loss,
diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py
index 2a60a56..968c7b9 100644
--- a/ktransformers/operators/RoPE.py
+++ b/ktransformers/operators/RoPE.py
@@ -522,7 +522,9 @@ class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
-            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-            freqs_cis = freqs_cis * self.attention_scaling
-        return freqs_cis
\ No newline at end of file
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
\ No newline at end of file
diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py
index 2b06263..d493329 100644
--- a/ktransformers/operators/balance_serve_attention.py
+++ b/ktransformers/operators/balance_serve_attention.py
@@ -568,15 +568,31 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
 
     def apply_rotary_pos_emb(
         self,
-        xq: torch.Tensor,
-        xk: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
+        unsqueeze_dim=2
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-        xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
-        xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
-        return xq_out.type_as(xq), xk_out.type_as(xk)
+
+        # Keep half or full tensor for later concatenation
+        cos = freqs_cis[0]
+        sin = freqs_cis[1]
+        rotary_dim = cos.shape[-1]
+
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
+
+        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+        # Apply rotary embeddings on the first half or full tensor
+        q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+        k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+        # Concatenate back to full shape
+        q_embed = torch.cat([q_embed, q_pass], dim=-1)
+        k_embed = torch.cat([k_embed, k_pass], dim=-1)
+        return q_embed, k_embed
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -587,18 +603,20 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
                 position_ids: torch.Tensor = None,
                 ):
 
-        if self.use_qk_norm:
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
-
         q_len, _ = hidden_states.size()
         query_states = self.q_proj(hidden_states, bsz_tensors)
         key_states = self.k_proj(hidden_states, bsz_tensors)
         value_states = self.v_proj(hidden_states, bsz_tensors)
-        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
+
+        if self.use_qk_norm:
+            query_states = self.q_norm(query_states, bsz_tensors)
+            key_states = self.k_norm(key_states, bsz_tensors)
+
+
+        query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
+        key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
 
         # cos, sin = freqs_cis
         """
@@ -607,14 +625,14 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
        print(cos.shape)
        print(sin.shape)
        """
-        if freqs_cis:
+        if freqs_cis is not None:
 
            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0),
                                                                 freqs_cis)
 
-            query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
-            key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
-            value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
+            query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)
+            key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
+            value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)
 
         k_cache = kv_cache.get_k_cache(self.layer_idx)
         v_cache = kv_cache.get_v_cache(self.layer_idx)
@@ -623,6 +641,6 @@ class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
 
         attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
 
-        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
+        attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)
 
         return attn_output
\ No newline at end of file
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 69bd5d0..466c0d6 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -1840,31 +1840,13 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
 
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
+        topk_idx, topk_weight = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if bsz_tensor is None:
-            router_logits = self.gate(hidden_states)
-        else:
-            router_logits = self.gate(hidden_states, bsz_tensor)
-
-        if router_logits.device.type == "xpu":
-            # TODO: support self.moe_primary_router_apply_softmax False case
-            from ipex_llm.transformers.models.common import moe_softmax_topk
-            selected_experts, routing_weights = moe_softmax_topk(
-                router_logits.half(), self.top_k, self.norm_topk_prob
-            )
-        else:
-            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-            if self.norm_topk_prob:
-                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-            # we cast back to the input dtype
-            routing_weights = routing_weights.to(hidden_states.dtype)
-
         # only for generate phase
         if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
             # TODO: this branch cause jit bug
-            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
-            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
+            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
+            y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
             # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
 
             y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
@@ -1873,29 +1855,29 @@ class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
             y.resize_(*orig_shape)
             return y
 
-        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
+        y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)
         # y_ = (
         #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
         # )
 
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
-                self.moe_infer(hidden_states, selected_experts, routing_weights)
+                self.moe_infer(hidden_states, topk_idx, topk_weight)
                 .view(*orig_shape)
                 .to(device=hidden_states.device)
             )
         else:
             # TODO may bugs here
             y = (
-                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
+                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
                 .view(*orig_shape)
                 .to(device=hidden_states.device)
             )
-        # y += y_
+        y += y_
 
         return y
 
     @torch.no_grad()
diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py
index ea6f36d..6d3e812 100644
--- a/ktransformers/operators/mlp.py
+++ b/ktransformers/operators/mlp.py
@@ -64,7 +64,7 @@ class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule):
                  generate_device: str = "cuda",
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-        self.orig_module.__init__(orig_module.config)
+        self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size)
     def forward(self, x, bsz_tensor):
         down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
         return down_proj
\ No newline at end of file
diff --git a/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
index 58dc887..56345df 100644
--- a/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
+++ b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
@@ -60,7 +60,7 @@
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
   replace:
-    class: ktransformers.operators.balance_serve_attention.KSmallthinkerAttention # optimized MLA implementation
+    class: ktransformers.operators.balance_serve_attention.KGlm4MoeAttention # optimized MLA implementation
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py
index 4a0ab23..9a869a7 100644
--- a/ktransformers/server/backend/interfaces/balance_serve.py
+++ b/ktransformers/server/backend/interfaces/balance_serve.py
@@ -462,6 +462,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         profiler.create_and_start_timer("prefill")
 
         query_add = sched_ext.QueryAdd()
+        input_ids = torch.tensor([[151331, 151333,  98964, 117392, 103408,  99668,   3837,  99073,  99444,
+                                    99052, 101052,  11314]], device='cuda')
         query_add.query_token = input_ids[0].tolist()
         query_length = input_ids[0].shape[0]
         query_add.query_length = query_length
diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py
index 6f435b4..a70e7e4 100644
--- a/ktransformers/tests/test_speed.py
+++ b/ktransformers/tests/test_speed.py
@@ -149,7 +149,7 @@ if __name__ == "__main__":
     parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
-    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
+    parser.add_argument("--max_tokens", type=int, default=500, help="max decode tokens")
     args = parser.parse_args()
 
     SERVER_URL = args.api_url
@@ -161,5 +161,8 @@ if __name__ == "__main__":
 
         prompt = ktansformer_prompt1024 * 2
     elif args.prompt_lens == 4096:
         prompt = ktansformer_prompt1024 * 4
+
+    prompt = "介绍秦始皇"
+
     asyncio.run(main(args.concurrent, prompt, max_tokens, model))