Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-08 05:29:29 +00:00
Merge 8c8cb207aa into ee2ede0412
This commit is contained in: commit e204a0bb6b
2 changed files with 25 additions and 11 deletions
@@ -42,18 +42,22 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
     def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
         self.use_cuda_graph = use_cuda_graph
-        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+        # Increase buffer sizes to be safe
+        self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0)
         self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
         self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
-        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
+        # Make sure this buffer is large enough
+        self.paged_kv_indices_buf = torch.empty((max_pages * 2,), dtype=torch.int32, device=device)
         self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
         self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)


         self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
-            qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
-            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
+            qo_indptr=self.qo_indptr_buf,
+            kv_indptr=self.paged_kv_indptr_buf,
+            kv_indices=self.paged_kv_indices_buf,
+            kv_len_arr=self.paged_kv_len_buf,
             bsz_tensor=self.bsz_tensor_buf,
             backend = "fa2",
         )
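Note on the sizes above: these buffers are capture-time maxima for the CUDA-graph path, so they must cover the worst case the scheduler can ever produce. The workspace appears to be a raw byte arena for flashinfer, the indptr buffers need one slot per request plus a leading zero, and paged_kv_indices_buf must hold an entry for every KV page any batch can reference (the patch doubles it as headroom). A minimal sketch of how such capacities might be derived, assuming hypothetical max_seq_len and page_size values that are not taken from this commit:

    import torch

    # Assumed example values, not from the ktransformers config.
    max_batch_size = 32
    max_seq_len = 32768
    page_size = 64

    # Worst case: every request holds a full-length paged KV cache.
    pages_per_request = (max_seq_len + page_size - 1) // page_size
    max_pages = max_batch_size * pages_per_request

    # indptr-style buffers carry a leading 0 plus one entry per request,
    # hence the +1 / +2 slack in the real code.
    qo_indptr_buf = torch.zeros(max_batch_size + 2, dtype=torch.int32)
    paged_kv_indices_buf = torch.zeros(max_pages * 2, dtype=torch.int32)  # 2x headroom, as in the patch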
@@ -145,4 +149,4 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         minibatch = batch.minibatch
         self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
                           minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)

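For context, wrapper.plan consumes CSR-style indexing tensors describing where each request's query tokens and KV pages live, and those tensors must fit inside the buffers allocated in init_wrapper. A small sketch, with made-up lengths, of how q_indptr / kv_indptr / kv_indices / kv_len could be laid out for a decode batch of two requests; this is illustrative plain torch, not the project's scheduler code, and page_size is an assumed value:

    import torch

    device = torch.device("cuda")
    page_size = 64                      # assumed value for the example
    kv_lens = torch.tensor([100, 300])  # tokens already cached per request

    # One new query token per request during decode.
    q_indptr = torch.tensor([0, 1, 2], dtype=torch.int32, device=device)

    # Pages needed per request, then a prefix sum with a leading zero.
    pages = (kv_lens + page_size - 1) // page_size          # -> [2, 5]
    kv_indptr = torch.zeros(len(kv_lens) + 1, dtype=torch.int32, device=device)
    kv_indptr[1:] = torch.cumsum(pages, dim=0)

    # Flat list of page ids owned by the requests (0..6 here just for brevity).
    kv_indices = torch.arange(int(pages.sum()), dtype=torch.int32, device=device)
    kv_len = kv_lens.to(torch.int32).to(device)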
@@ -52,8 +52,9 @@ class SamplingOptions():
         self.is_all_greedy = False

 class Sampler(nn.Module):
-    def __init__(self):
+    def __init__(self, device=torch.device('cuda')):
         super().__init__()
+        self.device = device

     def forward(
         self,
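Usage-wise, callers that build the sampler can now pin it to an explicit device; a one-line sketch (the default stays CUDA, and the handle is kept on self.device for later use):

    sampler = Sampler(device=torch.device("cuda:0"))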
@@ -63,7 +64,11 @@ class Sampler(nn.Module):
         if sampling_config == None:
             sampling_config = SamplingOptions()

-        logits = logits.contiguous()
+        # Ensure all tensors are on the same device
+        device = logits.device
+        logits = logits.contiguous().to(device)
+        sampling_config.temperatures = sampling_config.temperatures.to(device)
+
         origin_logits = logits.clone()
         if sampling_config.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
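The device alignment matters because the per-request sampling tensors can arrive on a different device than the logits, and an elementwise op across devices raises immediately. A minimal standalone repro and fix sketch, separate from the server's actual code path:

    import torch

    logits = torch.randn(4, 32000, device="cuda")
    temperatures = torch.full((4, 1), 0.7)            # accidentally left on CPU

    # logits.div_(temperatures)                       # RuntimeError: tensors on different devices
    temperatures = temperatures.to(logits.device)     # what the patch does for sampling_config.temperatures
    logits.div_(temperatures)                         # now fine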
@@ -71,7 +76,8 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
         else:
             # Post process logits
-            logits.div_(sampling_config.temperatures)
+            safe_temperatures = sampling_config.temperatures.masked_fill(sampling_config.temperatures == 0, 1.0)
+            logits.div_(safe_temperatures)
             max_top_k_round, batch_size = 32, logits.shape[0]
             if sampling_config.need_min_p_sampling:
                 probs = torch.softmax(logits, dim=-1)
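Why the masked_fill: a temperature of 0 is the usual encoding for greedy decoding, but dividing logits by 0 produces inf/NaN rows that poison the subsequent softmax. Replacing those zeros with 1.0 keeps the division finite, and the temperature-0 rows are overwritten later with argmax of origin_logits anyway. A small standalone illustration:

    import torch

    logits = torch.randn(3, 8)
    temperatures = torch.tensor([[0.7], [0.0], [1.2]])    # request 1 wants greedy decoding

    bad = logits / temperatures                           # row 1 becomes +/-inf (or NaN)
    safe_temperatures = temperatures.masked_fill(temperatures == 0, 1.0)
    good = logits / safe_temperatures                     # finite everywhere

    greedy_rows = torch.where(temperatures.squeeze(-1) == 0)[0]
    print(torch.isinf(bad).any(), torch.isfinite(good).all(), greedy_rows)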
@ -82,8 +88,10 @@ class Sampler(nn.Module):
|
||||||
batch_next_token_ids = min_p_sampling_from_probs(
|
batch_next_token_ids = min_p_sampling_from_probs(
|
||||||
probs, sampling_config.min_ps
|
probs, sampling_config.min_ps
|
||||||
)
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
|
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
|
||||||
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
if temperature_0_idx.numel() > 0:
|
||||||
|
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
|
||||||
else:
|
else:
|
||||||
# TODO: use different kernel when don't need top_k or top_p
|
# TODO: use different kernel when don't need top_k or top_p
|
||||||
# @TODO get probs
|
# @TODO get probs
|
||||||
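For reference, min-p filtering keeps only tokens whose probability is at least min_p times the row's maximum probability, then renormalizes and samples; min_p_sampling_from_probs does this on-device, and the synchronize plus numel() guard above just make the host-side greedy override safe when no request actually has temperature 0. A plain-torch sketch of the same semantics, illustrative only and not the fused kernel:

    import torch

    def min_p_sample(probs: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
        # probs: (batch, vocab), already softmax-normalized; min_ps: (batch,)
        thresholds = probs.max(dim=-1, keepdim=True).values * min_ps.unsqueeze(-1)
        filtered = torch.where(probs >= thresholds, probs, torch.zeros_like(probs))
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)
        return torch.multinomial(filtered, num_samples=1).squeeze(-1).to(torch.int32)

    probs = torch.softmax(torch.randn(2, 16), dim=-1)
    ids = min_p_sample(probs, torch.tensor([0.05, 0.1]))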
@@ -94,7 +102,9 @@ class Sampler(nn.Module):
                     sampling_config.top_ps,
                     filter_apply_order="joint",
                 )
+                torch.cuda.synchronize()
                 temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
-                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
+                if temperature_0_idx.numel() > 0:
+                    batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)

         return batch_next_token_ids.to(torch.int32), probs
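The non-min-p branch relies on flashinfer's fused top-k/top-p sampling, where filter_apply_order="joint" presumably applies both filters in one pass before a single renormalized draw, followed by the same greedy override for temperature-0 rows. A plain-torch sketch of that joint filtering, again only to show the semantics rather than the kernel itself:

    import torch

    def top_k_top_p_sample(logits, top_ks, top_ps):
        # logits: (batch, vocab); top_ks: (batch,) int; top_ps: (batch,) float
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs)
        cum_before = sorted_probs.cumsum(dim=-1) - sorted_probs   # exclusive cumsum, keeps the top-1 token
        keep = (ranks < top_ks.unsqueeze(-1)) & (cum_before <= top_ps.unsqueeze(-1))
        filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)
        choice = torch.multinomial(filtered, 1)
        return sorted_idx.gather(-1, choice).squeeze(-1).to(torch.int32)

    ids = top_k_top_p_sample(torch.randn(2, 16), torch.tensor([5, 8]), torch.tensor([0.9, 0.95]))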