diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index d5e74de..5b40455 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -58,7 +58,7 @@ def local_chat(
     gguf_path: str | None = None,
     max_new_tokens: int = 300,
     cpu_infer: int = Config().cpu_infer,
-    use_cuda_graph: bool = True,
+    use_cuda_graph: bool = False,
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
@@ -160,6 +160,9 @@ def local_chat(
         input_tensor = tokenizer.apply_chat_template(
             messages, add_generation_prompt=True, return_tensors="pt"
         )
+
+        # input_tensor = torch.tensor([[0, 6657, 84646]], device=input_tensor.device)
+
         if force_think:
             token_thinks = torch.tensor([tokenizer.encode("\\n",add_special_tokens=False)],device=input_tensor.device)
             input_tensor = torch.cat(
@@ -181,4 +184,6 @@ def local_chat(
 
 
 if __name__ == "__main__":
-    fire.Fire(local_chat)
\ No newline at end of file
+    # fire.Fire(local_chat)
+    # local_chat(model_path="/mnt/data/model/DeepSeek-R1", gguf_path="/mnt/data/model/DeepseekV3-q4km-gguf", cpu_infer=33, force_think=False)
+    local_chat(model_path="/mnt/data/model/Moonlight-16B-A3B-Instruct", gguf_path="/mnt/data/model/Moonlight-16B-A3B-Instruct-GGUF", cpu_infer=33, force_think=False)
\ No newline at end of file
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 85378ee..b4c5402 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -441,10 +441,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
             # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
-            attn_output = attn_output.transpose(1, 2)
-            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
+            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
 
-            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
 
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 21b4830..04c04c5 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -450,9 +450,9 @@ class KExpertsTorch(KExpertsBase):
             self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
             self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
 
-        self.up = torch.cat(self.gate, dim=0)
+        self.up = torch.cat(self.up, dim=0)
         self.gate = torch.cat(self.gate, dim=0)
-        self.down = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.down, dim=0)
         return
 
     def unload(self):
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
index 6fb6586..4c8eca2 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
@@ -1,7 +1,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
   replace:
-    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
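
For reference, the sketch below is not part of the patch; it replays the shape annotations added in the attention.py hunk with dummy tensors. The sizes (bsz, q_len, num_heads, kv_lora_rank, v_head_dim) are made-up example values, and only the tensor shapes are checked, not numerical behaviour.

# Minimal shape check for the annotated MLA output-absorb path,
# assuming arbitrary example sizes chosen only for illustration.
import torch

bsz, q_len, num_heads = 1, 4, 16
kv_lora_rank, v_head_dim = 512, 128

attn_output = torch.randn(bsz, q_len, num_heads, kv_lora_rank)  # [bsz, q_len, num_heads, kv_lora_rank]
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)   # [num_heads, v_head_dim, kv_lora_rank]

attn_output = attn_output.transpose(1, 2)                       # [bsz, num_heads, q_len, kv_lora_rank]
attn_output = torch.matmul(attn_output, out_absorb.mT)          # [bsz, num_heads, q_len, v_head_dim]
attn_output = attn_output.reshape(bsz, q_len, num_heads * v_head_dim)  # [bsz, q_len, num_heads * v_head_dim]

print(attn_output.shape)  # torch.Size([1, 4, 2048])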