fix precision bug imported by position_ids in 0.2.0

This commit is contained in:
Atream 2025-02-17 09:23:14 +00:00
parent b84524622e
commit 038bc30888
10 changed files with 471 additions and 45 deletions

View file

@ -30,6 +30,7 @@ from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@ -170,9 +171,16 @@ def local_chat(
torch.set_default_dtype(
torch.bfloat16
) # TODO: Remove this, replace dtype using config
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
)
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
)
if __name__ == "__main__":