From 8c8cb207aa0ebe6d00ede36489adfe6ce23bd436 Mon Sep 17 00:00:00 2001
From: Jesse CreateThis
Date: Sun, 6 Jul 2025 19:45:06 +0000
Subject: [PATCH] Apply magikRUKKOLA's patch from issue #1417

---
 .../models/custom_modeling_deepseek_v3.py | 14 +++++++-----
 .../inference/sampling/sampler.py         | 22 ++++++++++++++-----
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/ktransformers/models/custom_modeling_deepseek_v3.py b/ktransformers/models/custom_modeling_deepseek_v3.py
index 589f6c3..e6a8fdd 100644
--- a/ktransformers/models/custom_modeling_deepseek_v3.py
+++ b/ktransformers/models/custom_modeling_deepseek_v3.py
@@ -42,18 +42,22 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
 
 
     def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
         self.use_cuda_graph = use_cuda_graph
-        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+        # Increase buffer sizes to be safe
+        self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0)
         self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
         self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
-        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
+        # Make sure this buffer is large enough
+        self.paged_kv_indices_buf = torch.empty((max_pages * 2,), dtype=torch.int32, device=device)
         self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
         self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)
 
         self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
-            qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
-            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
+            qo_indptr=self.qo_indptr_buf,
+            kv_indptr=self.paged_kv_indptr_buf,
+            kv_indices=self.paged_kv_indices_buf,
+            kv_len_arr=self.paged_kv_len_buf,
             bsz_tensor=self.bsz_tensor_buf,
             backend = "fa2",
         )
@@ -145,4 +149,4 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         minibatch = batch.minibatch
         self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_len, num_heads, head_dim_ckv,
                           head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)
-    
\ No newline at end of file
+
diff --git a/ktransformers/server/balance_serve/inference/sampling/sampler.py b/ktransformers/server/balance_serve/inference/sampling/sampler.py
index f491c97..95df61a 100644
--- a/ktransformers/server/balance_serve/inference/sampling/sampler.py
+++ b/ktransformers/server/balance_serve/inference/sampling/sampler.py
@@ -52,8 +52,9 @@ class SamplingOptions():
             self.is_all_greedy = False
 
 class Sampler(nn.Module):
-    def __init__(self):
+    def __init__(self, device=torch.device('cuda')):
         super().__init__()
+        self.device = device
 
     def forward(
         self,
@@ -63,7 +64,11 @@ class Sampler(nn.Module):
 
         if sampling_config == None:
             sampling_config = SamplingOptions()
-        logits = logits.contiguous()
+        # Ensure all tensors are on the same device
+        device = logits.device
+        logits = logits.contiguous().to(device)
+        sampling_config.temperatures = sampling_config.temperatures.to(device)
+
         origin_logits = logits.clone()
         if sampling_config.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
@@ -71,7 +76,8 @@
             batch_next_token_ids = torch.argmax(logits, -1)
         else:
             # Post process logits
-            logits.div_(sampling_config.temperatures)
+            safe_temperatures = sampling_config.temperatures.masked_fill(sampling_config.temperatures == 0, 1.0)
+            logits.div_(safe_temperatures)
             max_top_k_round, batch_size = 32, logits.shape[0]
             if sampling_config.need_min_p_sampling:
                 probs = torch.softmax(logits, dim=-1)
@@ -82,8 +88,10 @@
                 batch_next_token_ids = min_p_sampling_from_probs(
                     probs, sampling_config.min_ps
                 )
+                torch.cuda.synchronize()
                 temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
-                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
+                if temperature_0_idx.numel() > 0:
+                    batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
             else:
                 # TODO: use different kernel when don't need top_k or top_p
                 # @TODO get probs
@@ -94,7 +102,9 @@
                     sampling_config.top_ps,
                     filter_apply_order="joint",
                 )
+                torch.cuda.synchronize()
                 temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
-                batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
+                if temperature_0_idx.numel() > 0:
+                    batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
 
-        return batch_next_token_ids.to(torch.int32), probs
\ No newline at end of file
+        return batch_next_token_ids.to(torch.int32), probs
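
For context, the sampler hunks above guard against division by zero when a request uses temperature 0 (greedy decoding): zero temperatures are masked to 1.0 before the in-place division, and those rows are overwritten afterwards with an argmax over the untouched logits. Below is a minimal, self-contained Python sketch of that logic outside the ktransformers code; the function name, the 1-D temperatures tensor, and the plain torch.multinomial draw (standing in for the repo's min-p / top-k / top-p kernels) are illustrative assumptions, not part of the patch.

import torch

def sample_with_zero_temperature_guard(logits: torch.Tensor,
                                        temperatures: torch.Tensor) -> torch.Tensor:
    # Keep the untouched logits so temperature-0 (greedy) requests can fall
    # back to a plain argmax at the end, mirroring the origin_logits clone above.
    origin_logits = logits.clone()

    # Mask zero temperatures to 1.0 before dividing in place, so no row of
    # logits becomes inf/nan.
    safe_temperatures = temperatures.masked_fill(temperatures == 0, 1.0)
    logits.div_(safe_temperatures.unsqueeze(-1))

    # Illustrative sampling step; the actual sampler calls its min-p or
    # top-k/top-p kernels here instead.
    probs = torch.softmax(logits, dim=-1)
    batch_next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

    # Overwrite the sampled ids for greedy (temperature == 0) requests.
    temperature_0_idx = torch.where(temperatures == 0)[0]
    if temperature_0_idx.numel() > 0:
        batch_next_token_ids[temperature_0_idx] = torch.argmax(
            origin_logits[temperature_0_idx], -1
        )
    return batch_next_token_ids.to(torch.int32)

For example, calling sample_with_zero_temperature_guard(torch.randn(4, 32000), torch.tensor([0.0, 0.7, 1.0, 0.0])) returns greedy token ids for rows 0 and 3 and sampled ids for the other two rows, without producing inf/nan logits for the zero-temperature rows.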