fix precision bug introduced by position_ids in 0.2.0

Atream 2025-02-17 09:23:14 +00:00
parent b84524622e
commit 038bc30888
10 changed files with 471 additions and 45 deletions

@@ -17,6 +17,7 @@ from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
+from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 warm_uped = False
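
Note: the new import assumes flashinfer is installed. A guarded-import sketch (an assumption about packaging, not what this commit does) would keep the module importable without the optional dependency:

    # Hypothetical optional-dependency guard; this commit imports unconditionally.
    try:
        from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
        flashinfer_available = True
    except ImportError:
        MLAWrapperSingleton = None
        flashinfer_available = False
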
@@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False):
+                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
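
The widened signature means callers now thread the MLA geometry through explicitly. A hedged usage sketch; the numeric values are illustrative DeepSeek-style MLA dimensions (128 heads, 512-dim compressed KV latent, 64-dim rotary key part, 192 = 128 + 64 query head dim), not values taken from this commit:

    generated = prefill_and_generate(
        model, tokenizer, input_ids,
        max_new_tokens=512,
        use_cuda_graph=True,
        use_flashinfer_mla=True,  # new flag added by this commit
        num_heads=128,            # illustrative: MLA attention heads
        head_dim_ckv=512,         # illustrative: compressed KV latent dim
        head_dim_kpe=64,          # illustrative: rotary (positional) key dim
        q_head_dim=192,           # illustrative: nope + rope head dim, used for scaling
    )
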
@@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         )
     else:
         past_key_values = None
-    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
+    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
     generated_ids = torch.zeros(
         batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
     )
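
The dtype change from torch.long to torch.int32 lines up with the flashinfer integration: flashinfer's paged-attention plans index the KV cache with int32 buffers, and a captured CUDA graph replays against tensors of fixed dtype, so building positions as int32 from prefill onward keeps every step an exact copy rather than a per-step conversion (this reading is inferred from the plan_all call below; the commit message does not spell out the mechanism). A minimal sketch of the prefill bookkeeping after the fix, assuming torch_device and inputs from the surrounding code:

    import torch

    seq_length = inputs.shape[1]  # prompt length after tokenization
    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
    position_ids = cache_position.unsqueeze(0)  # shape [1, seq_length], int32 throughout
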
@@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     generated_ids[:, seq_length] = next_token
     tokens.append(int(next_token))
     inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
-    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
+    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
     position_ids = cache_position.unsqueeze(0)
     seq_length += 1
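
The decode path mirrors the prefill change: the one-element position tensor rebuilt each step must match the dtype of the buffer the CUDA graph captured. A sketch of why the match matters, with assumed buffer names (CUDAGraphRunner's real internals may differ):

    # Replaying a CUDA graph reuses captured tensors, so each step writes the new
    # position into the captured buffer in place; identical dtypes make that an
    # exact device-to-device copy instead of a value conversion.
    static_cache_position = torch.zeros(1, device=torch_device, dtype=torch.int32)
    step_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
    static_cache_position.copy_(step_position)
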
@@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+            if i > 1 and use_flashinfer_mla:
+                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
+                                             num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                                             q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
             next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
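
Reading the new block: from the second generated token on (i > 1), the flashinfer MLA plan is refreshed before decoding so its kernels see the KV cache's new length; position_ids holds the last 0-based position, so squeeze(1) + 1 yields the current sequence length, and q_head_dim ** (-0.5) is the standard 1/sqrt(d) softmax scale. An annotated restating of the call (the three leading None arguments are passed through as-is; their meaning is not shown in this diff):

    kv_len = position_ids.squeeze(1) + 1  # last 0-based position -> KV length, int32
    sm_scale = q_head_dim ** (-0.5)       # 1/sqrt(d) attention scaling
    MLAWrapperSingleton.plan_all(None, None, None, kv_len,
                                 num_heads, head_dim_ckv, head_dim_kpe,
                                 past_key_values.page_size, sm_scale,
                                 torch.bfloat16, torch.bfloat16)
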