VLinearMarlin: padding to input.shape[0] to avoid CUDA error

Fix the following runtime error with --no-use_cuda_graph option

Traceback (most recent call last):
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 282, in run_engine
    engine.loop()
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 234, in loop
    self.model_runner.run(self.batch, self.query_manager)
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/balance_serve/inference/model_runner.py", line 220, in run
    self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
                            ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
This commit is contained in:
Aubrey Li 2025-05-18 15:11:37 +08:00
parent 8caecf37d8
commit d347aeb518

View file

@ -504,6 +504,9 @@ class VLinearMarlin(KLinearBase):
marlin_s = self.marlin_s.to(x.dtype)
sms = -1
# padding x.shape[0] to avoid CUDA illegal memory access error
x, orig_size_m = self._pad_input(x)
x = vLLMMarlin.gptq_marlin_gemm(
x,
self.marlin_q_w,
@ -513,26 +516,15 @@ class VLinearMarlin(KLinearBase):
self.workspace.scratch,
self.num_bits,
bsz_tensor,
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
x.shape[0],
self.n,
x.shape[-1],
sms,
self.is_k_full,
)
# x = KTransformersOps.gptq_marlin_gemm(
# x,
# self.marlin_q_w,
# marlin_s,
# self.g_idx,
# self.sort_indices,
# self.workspace.scratch,
# self.num_bits,
# x.shape[0],
# self.n,
# x.shape[-1],
# self.is_k_full,
# )
x = x[:orig_size_m]
if self.has_bias:
x = x + self.bias
orig_shape[-1] = self.n
@ -548,6 +540,27 @@ class VLinearMarlin(KLinearBase):
self.sort_indices = None
self.workspace = None
def _pad_input(self, x):
size_m = x.shape[0]
size_k = x.shape[1]
# size_m and align value depends on VLinearMarlin implementation
if size_m > 1024:
align = 1024
elif size_m > 64:
align = 64
else:
align = 1
padded_size_m = ((size_m + align - 1) // align) * align
if padded_size_m > size_m:
pad_len = padded_size_m - size_m
pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)
x = torch.cat([x, pad_tensor], dim = 0).contiguous()
return x, size_m
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor