mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 12:40:02 +00:00
Merge pull request #1320 from aubreyli/no_cuda_graph_err
VLinearMarlin: padding to input.shape[0] to avoid CUDA error
This commit is contained in:
commit
01311d251d
1 changed files with 27 additions and 14 deletions
|
@ -504,6 +504,9 @@ class VLinearMarlin(KLinearBase):
|
||||||
marlin_s = self.marlin_s.to(x.dtype)
|
marlin_s = self.marlin_s.to(x.dtype)
|
||||||
sms = -1
|
sms = -1
|
||||||
|
|
||||||
|
# padding x.shape[0] to avoid CUDA illegal memory access error
|
||||||
|
x, orig_size_m = self._pad_input(x)
|
||||||
|
|
||||||
x = vLLMMarlin.gptq_marlin_gemm(
|
x = vLLMMarlin.gptq_marlin_gemm(
|
||||||
x,
|
x,
|
||||||
self.marlin_q_w,
|
self.marlin_q_w,
|
||||||
|
@ -513,26 +516,15 @@ class VLinearMarlin(KLinearBase):
|
||||||
self.workspace.scratch,
|
self.workspace.scratch,
|
||||||
self.num_bits,
|
self.num_bits,
|
||||||
bsz_tensor,
|
bsz_tensor,
|
||||||
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
|
|
||||||
x.shape[0],
|
x.shape[0],
|
||||||
self.n,
|
self.n,
|
||||||
x.shape[-1],
|
x.shape[-1],
|
||||||
sms,
|
sms,
|
||||||
self.is_k_full,
|
self.is_k_full,
|
||||||
)
|
)
|
||||||
# x = KTransformersOps.gptq_marlin_gemm(
|
|
||||||
# x,
|
x = x[:orig_size_m]
|
||||||
# self.marlin_q_w,
|
|
||||||
# marlin_s,
|
|
||||||
# self.g_idx,
|
|
||||||
# self.sort_indices,
|
|
||||||
# self.workspace.scratch,
|
|
||||||
# self.num_bits,
|
|
||||||
# x.shape[0],
|
|
||||||
# self.n,
|
|
||||||
# x.shape[-1],
|
|
||||||
# self.is_k_full,
|
|
||||||
# )
|
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
x = x + self.bias
|
x = x + self.bias
|
||||||
orig_shape[-1] = self.n
|
orig_shape[-1] = self.n
|
||||||
|
@ -548,6 +540,27 @@ class VLinearMarlin(KLinearBase):
|
||||||
self.sort_indices = None
|
self.sort_indices = None
|
||||||
self.workspace = None
|
self.workspace = None
|
||||||
|
|
||||||
|
def _pad_input(self, x):
|
||||||
|
|
||||||
|
size_m = x.shape[0]
|
||||||
|
size_k = x.shape[1]
|
||||||
|
|
||||||
|
# size_m and align value depends on VLinearMarlin implementation
|
||||||
|
if size_m > 1024:
|
||||||
|
align = 1024
|
||||||
|
elif size_m > 64:
|
||||||
|
align = 64
|
||||||
|
else:
|
||||||
|
align = 1
|
||||||
|
|
||||||
|
padded_size_m = ((size_m + align - 1) // align) * align
|
||||||
|
|
||||||
|
if padded_size_m > size_m:
|
||||||
|
pad_len = padded_size_m - size_m
|
||||||
|
pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)
|
||||||
|
x = torch.cat([x, pad_tensor], dim = 0).contiguous()
|
||||||
|
return x, size_m
|
||||||
|
|
||||||
class KLinearMarlin(KLinearBase):
|
class KLinearMarlin(KLinearBase):
|
||||||
marlin_q_w: torch.Tensor
|
marlin_q_w: torch.Tensor
|
||||||
marlin_s: torch.Tensor
|
marlin_s: torch.Tensor
|
||||||
|
|
Loading…
Add table
Reference in a new issue