Merge pull request #1320 from aubreyli/no_cuda_graph_err
Some checks are pending
Book-CI / test (push) Waiting to run
Deploy / deploy (macos-latest) (push) Waiting to run
Deploy / deploy (ubuntu-latest) (push) Waiting to run
Deploy / deploy (windows-latest) (push) Waiting to run

VLinearMarlin: padding to input.shape[0] to avoid CUDA error
This commit is contained in:
Atream 2025-05-18 02:45:05 -06:00 committed by GitHub
commit 01311d251d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -504,6 +504,9 @@ class VLinearMarlin(KLinearBase):
marlin_s = self.marlin_s.to(x.dtype)
sms = -1
# padding x.shape[0] to avoid CUDA illegal memory access error
x, orig_size_m = self._pad_input(x)
x = vLLMMarlin.gptq_marlin_gemm(
    x,
    self.marlin_q_w,
@ -513,26 +516,15 @@ class VLinearMarlin(KLinearBase):
    self.workspace.scratch,
    self.num_bits,
    bsz_tensor,
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
    x.shape[0],
    self.n,
    x.shape[-1],
    sms,
    self.is_k_full,
)
# x = KTransformersOps.gptq_marlin_gemm(
# x,
x = x[:orig_size_m]
# self.marlin_q_w,
# marlin_s,
# self.g_idx,
# self.sort_indices,
# self.workspace.scratch,
# self.num_bits,
# x.shape[0],
# self.n,
# x.shape[-1],
# self.is_k_full,
# )
if self.has_bias:
    x = x + self.bias
orig_shape[-1] = self.n
@ -548,6 +540,27 @@ class VLinearMarlin(KLinearBase):
self.sort_indices = None
self.workspace = None
def _pad_input(self, x):
size_m = x.shape[0]
size_k = x.shape[1]
# size_m and align value depends on VLinearMarlin implementation
if size_m > 1024:
align = 1024
elif size_m > 64:
align = 64
else:
align = 1
padded_size_m = ((size_m + align - 1) // align) * align
if padded_size_m > size_m:
pad_len = padded_size_m - size_m
pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)
x = torch.cat([x, pad_tensor], dim = 0).contiguous()
return x, size_m
class KLinearMarlin(KLinearBase):
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor