From d347aeb5189bff9450a81023976adb269b19c2a9 Mon Sep 17 00:00:00 2001 From: Aubrey Li Date: Sun, 18 May 2025 15:11:37 +0800 Subject: [PATCH] VLinearMarlin: padding to input.shape[0] to avoid CUDA error Fix the following runtime error with --no-use_cuda_graph option Traceback (most recent call last): File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap self.run() File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/multiprocessing/process.py", line 108, in run self._target(*self._args, **self._kwargs) File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 282, in run_engine engine.loop() File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 234, in loop self.model_runner.run(self.batch, self.query_manager) File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/balance_serve/inference/model_runner.py", line 220, in run self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start] ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: CUDA error: an illegal memory access was encountered CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. --- ktransformers/operators/linear.py | 41 ++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 9ce45d1..19500db 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -504,6 +504,9 @@ class VLinearMarlin(KLinearBase): marlin_s = self.marlin_s.to(x.dtype) sms = -1 + # padding x.shape[0] to avoid CUDA illegal memory access error + x, orig_size_m = self._pad_input(x) + x = vLLMMarlin.gptq_marlin_gemm( x, self.marlin_q_w, @@ -513,26 +516,15 @@ class VLinearMarlin(KLinearBase): self.workspace.scratch, self.num_bits, bsz_tensor, - # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device), x.shape[0], self.n, x.shape[-1], sms, self.is_k_full, ) - # x = KTransformersOps.gptq_marlin_gemm( - # x, - # self.marlin_q_w, - # marlin_s, - # self.g_idx, - # self.sort_indices, - # self.workspace.scratch, - # self.num_bits, - # x.shape[0], - # self.n, - # x.shape[-1], - # self.is_k_full, - # ) + + x = x[:orig_size_m] + if self.has_bias: x = x + self.bias orig_shape[-1] = self.n @@ -548,6 +540,27 @@ class VLinearMarlin(KLinearBase): self.sort_indices = None self.workspace = None + def _pad_input(self, x): + + size_m = x.shape[0] + size_k = x.shape[1] + + # size_m and align value depends on VLinearMarlin implementation + if size_m > 1024: + align = 1024 + elif size_m > 64: + align = 64 + else: + align = 1 + + padded_size_m = ((size_m + align - 1) // align) * align + + if padded_size_m > size_m: + pad_len = padded_size_m - size_m + pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device) + x = torch.cat([x, pad_tensor], dim = 0).contiguous() + return x, size_m + class KLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor