mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 04:30:03 +00:00
Merge pull request #1320 from aubreyli/no_cuda_graph_err
VLinearMarlin: padding to input.shape[0] to avoid CUDA error
This commit is contained in:
commit
01311d251d
1 changed files with 27 additions and 14 deletions
|
@ -504,6 +504,9 @@ class VLinearMarlin(KLinearBase):
|
|||
marlin_s = self.marlin_s.to(x.dtype)
|
||||
sms = -1
|
||||
|
||||
# padding x.shape[0] to avoid CUDA illegal memory access error
|
||||
x, orig_size_m = self._pad_input(x)
|
||||
|
||||
x = vLLMMarlin.gptq_marlin_gemm(
|
||||
x,
|
||||
self.marlin_q_w,
|
||||
|
@ -513,26 +516,15 @@ class VLinearMarlin(KLinearBase):
|
|||
self.workspace.scratch,
|
||||
self.num_bits,
|
||||
bsz_tensor,
|
||||
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
|
||||
x.shape[0],
|
||||
self.n,
|
||||
x.shape[-1],
|
||||
sms,
|
||||
self.is_k_full,
|
||||
)
|
||||
# x = KTransformersOps.gptq_marlin_gemm(
|
||||
# x,
|
||||
# self.marlin_q_w,
|
||||
# marlin_s,
|
||||
# self.g_idx,
|
||||
# self.sort_indices,
|
||||
# self.workspace.scratch,
|
||||
# self.num_bits,
|
||||
# x.shape[0],
|
||||
# self.n,
|
||||
# x.shape[-1],
|
||||
# self.is_k_full,
|
||||
# )
|
||||
|
||||
x = x[:orig_size_m]
|
||||
|
||||
if self.has_bias:
|
||||
x = x + self.bias
|
||||
orig_shape[-1] = self.n
|
||||
|
@ -548,6 +540,27 @@ class VLinearMarlin(KLinearBase):
|
|||
self.sort_indices = None
|
||||
self.workspace = None
|
||||
|
||||
def _pad_input(self, x):
|
||||
|
||||
size_m = x.shape[0]
|
||||
size_k = x.shape[1]
|
||||
|
||||
# size_m and align value depends on VLinearMarlin implementation
|
||||
if size_m > 1024:
|
||||
align = 1024
|
||||
elif size_m > 64:
|
||||
align = 64
|
||||
else:
|
||||
align = 1
|
||||
|
||||
padded_size_m = ((size_m + align - 1) // align) * align
|
||||
|
||||
if padded_size_m > size_m:
|
||||
pad_len = padded_size_m - size_m
|
||||
pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)
|
||||
x = torch.cat([x, pad_tensor], dim = 0).contiguous()
|
||||
return x, size_m
|
||||
|
||||
class KLinearMarlin(KLinearBase):
|
||||
marlin_q_w: torch.Tensor
|
||||
marlin_s: torch.Tensor
|
||||
|
|
Loading…
Add table
Reference in a new issue