From 71286ec1c0275bd0dfb57267b64a3e0f9455d68a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AE=81=E9=B9=8F=E6=B6=9B?=
Date: Sat, 1 Mar 2025 21:52:48 +0800
Subject: [PATCH] Update local_chat.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix: the condition config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM" is always true
---
 ktransformers/local_chat.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index c60546a..4acaf86 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -169,7 +169,7 @@ def local_chat(
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
 
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
@@ -181,4 +181,4 @@ def local_chat(
 
 
 if __name__ == "__main__":
-    fire.Fire(local_chat)
\ No newline at end of file
+    fire.Fire(local_chat)
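
Note (not part of the patch): a minimal Python sketch of the pitfall this commit fixes, using a hypothetical "arch" variable in place of config.architectures[0]. In Python, x == "A" or "B" parses as (x == "A") or "B", and a non-empty string literal is always truthy, so the original condition held for every architecture.

    arch = "LlamaForCausalLM"  # hypothetical stand-in for config.architectures[0]

    # Buggy form: parses as (arch == "DeepseekV2ForCausalLM") or "DeepseekV3ForCausalLM".
    # The non-empty string on the right makes the whole expression truthy for any arch.
    buggy = arch == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM"
    print(bool(buggy))  # True, even though arch matches neither name

    # Fixed form, as in the patch: compare arch against each literal explicitly.
    fixed = arch == "DeepseekV2ForCausalLM" or arch == "DeepseekV3ForCausalLM"
    print(fixed)  # False

    # An equivalent, more idiomatic alternative is a membership test:
    print(arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"))  # False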