Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 22:05:30 +00:00
fix precision bug introduced by position_ids in 0.2.0
This commit is contained in:
parent b84524622e
commit 038bc30888

10 changed files with 471 additions and 45 deletions
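In short: the patch threads a use_flashinfer_mla flag plus four MLA geometry parameters (num_heads, head_dim_ckv, head_dim_kpe, q_head_dim) through prefill_and_generate, switches cache_position (and the position_ids derived from it) from torch.long to torch.int32, and re-plans the flashinfer MLA wrapper on every decode step after the first. The hunks below are from the file containing prefill_and_generate; the other changed files are not shown on this page.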
@@ -17,6 +17,7 @@ from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
+from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

 warm_uped = False

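The only change in this hunk is the new import: MLAWrapperSingleton, the flashinfer-backed wrapper for MLA (multi-head latent attention) attention, is what the decode loop re-plans each step in the final hunk below.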
@@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False):
+                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
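With the extended signature, a caller opting into the flashinfer MLA path supplies the attention geometry up front. A hypothetical configuration sketch; the dimension values below are assumptions chosen to match a DeepSeek-style MLA layout and are not taken from this diff:

    # Hypothetical kwargs for the new flashinfer MLA path (values assumed).
    mla_kwargs = dict(
        use_flashinfer_mla=True,   # new flag from this commit
        num_heads=128,             # assumed MLA attention head count
        head_dim_ckv=512,          # assumed compressed-KV head dim (kv_lora_rank)
        head_dim_kpe=64,           # assumed positional-encoding head dim
        q_head_dim=192,            # assumed query head dim (128 nope + 64 rope)
    )
    # generated = prefill_and_generate(model, tokenizer, input_ids,
    #                                  max_new_tokens=512, **mla_kwargs)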
@@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
         )
     else:
         past_key_values = None
-    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
+    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
    generated_ids = torch.zeros(
         batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
     )
@@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
     generated_ids[:, seq_length] = next_token
     tokens.append(int(next_token))
     inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
-    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
+    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
     position_ids = cache_position.unsqueeze(0)
     seq_length += 1

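Both dtype changes (the arange for prefill above and the per-step scalar tensor here) keep cache_position, and therefore position_ids, in int32 end to end. A minimal sketch of why a mixed-dtype position buffer is dangerous once buffers become static for CUDA graph capture; the failure mode shown is an assumption, only the dtype switch itself comes from the diff:

    import torch

    # A graph-captured input buffer keeps the dtype it was captured with;
    # copy_() into it converts silently instead of raising, so an int64
    # position produced by an unpatched site is coerced without warning
    # and can drift from what the attention kernels were planned for.
    static_cache_position = torch.zeros(1, dtype=torch.int32)
    step = torch.tensor([7], dtype=torch.long)   # int64 from elsewhere
    static_cache_position.copy_(step)            # silent conversion, no error
    assert static_cache_position.dtype == torch.int32
    # Keeping every producer on int32 removes the mismatch entirely.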
@@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
             warm_uped = True
             cuda_graph_runner = CUDAGraphRunner()
             cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-
+        if i > 1 and use_flashinfer_mla:
+            MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
+                                         num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                                         q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
         next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
         inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
         generated_ids[:, cache_position] = next_token.int()
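Reading the new plan_all call: position_ids.squeeze(1) + 1 appears to be the per-batch KV length including the token about to be decoded, past_key_values.page_size is the paged-KV page size, q_head_dim ** (-0.5) is the standard softmax scale 1/sqrt(d), and the trailing torch.bfloat16 pair presumably fixes the query and KV dtypes. For example, under the assumed q_head_dim of 192:

    q_head_dim = 192                     # assumed value, not from this diff
    sm_scale = q_head_dim ** (-0.5)      # 1 / sqrt(192)
    print(f"{sm_scale:.6f}")             # ~0.072169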