mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-07 21:19:51 +00:00
Merge 8c8cb207aa
into ee2ede0412
This commit is contained in:
commit
e204a0bb6b
2 changed files with 25 additions and 11 deletions
|
@ -42,18 +42,22 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
|
|||
|
||||
def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
|
||||
# Increase buffer sizes to be safe
|
||||
self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0)
|
||||
self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
|
||||
self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
|
||||
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
|
||||
# Make sure this buffer is large enough
|
||||
self.paged_kv_indices_buf = torch.empty((max_pages * 2,), dtype=torch.int32, device=device)
|
||||
self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
|
||||
self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)
|
||||
|
||||
|
||||
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer, use_cuda_graph=use_cuda_graph,
|
||||
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
|
||||
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
|
||||
qo_indptr=self.qo_indptr_buf,
|
||||
kv_indptr=self.paged_kv_indptr_buf,
|
||||
kv_indices=self.paged_kv_indices_buf,
|
||||
kv_len_arr=self.paged_kv_len_buf,
|
||||
bsz_tensor=self.bsz_tensor_buf,
|
||||
backend = "fa2",
|
||||
)
|
||||
|
@ -145,4 +149,4 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
|
|||
minibatch = batch.minibatch
|
||||
self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
|
||||
minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)
|
||||
|
||||
|
||||
|
|
|
@ -52,8 +52,9 @@ class SamplingOptions():
|
|||
self.is_all_greedy = False
|
||||
|
||||
class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, device=torch.device('cuda')):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -63,7 +64,11 @@ class Sampler(nn.Module):
|
|||
if sampling_config == None:
|
||||
sampling_config = SamplingOptions()
|
||||
|
||||
logits = logits.contiguous()
|
||||
# Ensure all tensors are on the same device
|
||||
device = logits.device
|
||||
logits = logits.contiguous().to(device)
|
||||
sampling_config.temperatures = sampling_config.temperatures.to(device)
|
||||
|
||||
origin_logits = logits.clone()
|
||||
if sampling_config.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
|
@ -71,7 +76,8 @@ class Sampler(nn.Module):
|
|||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
else:
|
||||
# Post process logits
|
||||
logits.div_(sampling_config.temperatures)
|
||||
safe_temperatures = sampling_config.temperatures.masked_fill(sampling_config.temperatures == 0, 1.0)
|
||||
logits.div_(safe_temperatures)
|
||||
max_top_k_round, batch_size = 32, logits.shape[0]
|
||||
if sampling_config.need_min_p_sampling:
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
@ -82,8 +88,10 @@ class Sampler(nn.Module):
|
|||
batch_next_token_ids = min_p_sampling_from_probs(
|
||||
probs, sampling_config.min_ps
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
|
||||
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
||||
if temperature_0_idx.numel() > 0:
|
||||
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
||||
else:
|
||||
# TODO: use different kernel when don't need top_k or top_p
|
||||
# @TODO get probs
|
||||
|
@ -94,7 +102,9 @@ class Sampler(nn.Module):
|
|||
sampling_config.top_ps,
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
|
||||
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
||||
if temperature_0_idx.numel() > 0:
|
||||
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
||||
|
||||
return batch_next_token_ids.to(torch.int32), probs
|
||||
return batch_next_token_ids.to(torch.int32), probs
|
||||
|
|
Loading…
Add table
Reference in a new issue