Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00
VLinearMarlin: pad input.shape[0] to avoid CUDA error
Fixes the following runtime error seen with the --no-use_cuda_graph option:

```
Traceback (most recent call last):
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 282, in run_engine
    engine.loop()
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/backend/interfaces/balance_serve.py", line 234, in loop
    self.model_runner.run(self.batch, self.query_manager)
  File "/home/aubrey/miniforge3/envs/kt/lib/python3.11/site-packages/ktransformers/server/balance_serve/inference/model_runner.py", line 220, in run
    self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
```
parent 8caecf37d8
commit d347aeb518
1 changed file with 27 additions and 14 deletions
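Why padding helps: the vLLM Marlin GEMM apparently assumes an aligned row count for some batch sizes, so an unaligned `size_m` can drive the kernel past the end of a buffer; the crash only surfaces with `--no-use_cuda_graph`, presumably because the CUDA-graph path already runs at fixed, pre-padded batch sizes. The sketch below shows the pad-then-slice pattern from this commit in isolation; `pad_rows` and the plain `@` matmul are illustrative stand-ins, not the actual `vLLMMarlin.gptq_marlin_gemm` call:

```python
import torch

def pad_rows(x: torch.Tensor, align: int) -> tuple[torch.Tensor, int]:
    """Zero-pad x along dim 0 up to the next multiple of `align`;
    return the padded tensor and the original row count."""
    size_m = x.shape[0]
    padded_m = ((size_m + align - 1) // align) * align  # round up
    if padded_m > size_m:
        pad = torch.zeros((padded_m - size_m, *x.shape[1:]),
                          dtype=x.dtype, device=x.device)
        x = torch.cat([x, pad], dim=0).contiguous()
    return x, size_m

x = torch.randn(70, 128)                # 70 rows: not 64-aligned
w = torch.randn(128, 256)
x_pad, orig_m = pad_rows(x, align=64)   # 70 -> 128 rows
y = (x_pad @ w)[:orig_m]                # kernel sees aligned rows;
                                        # slice the result back to 70
```

Because the padded rows are all zeros and are sliced off afterwards, the visible output is unchanged.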
```diff
@@ -504,6 +504,9 @@ class VLinearMarlin(KLinearBase):
         marlin_s = self.marlin_s.to(x.dtype)
         sms = -1

+        # padding x.shape[0] to avoid CUDA illegal memory access error
+        x, orig_size_m = self._pad_input(x)
+
         x = vLLMMarlin.gptq_marlin_gemm(
             x,
             self.marlin_q_w,
@@ -513,26 +516,15 @@ class VLinearMarlin(KLinearBase):
             self.workspace.scratch,
             self.num_bits,
             bsz_tensor,
             # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
             x.shape[0],
             self.n,
             x.shape[-1],
             sms,
             self.is_k_full,
         )
-        # x = KTransformersOps.gptq_marlin_gemm(
-        #     x,
-        #     self.marlin_q_w,
-        #     marlin_s,
-        #     self.g_idx,
-        #     self.sort_indices,
-        #     self.workspace.scratch,
-        #     self.num_bits,
-        #     x.shape[0],
-        #     self.n,
-        #     x.shape[-1],
-        #     self.is_k_full,
-        # )
+
+        x = x[:orig_size_m]
+
         if self.has_bias:
             x = x + self.bias
         orig_shape[-1] = self.n
@@ -548,6 +540,27 @@ class VLinearMarlin(KLinearBase):
         self.sort_indices = None
         self.workspace = None

+    def _pad_input(self, x):
+        size_m = x.shape[0]
+        size_k = x.shape[1]
+
+        # size_m and the align value depend on the VLinearMarlin implementation
+        if size_m > 1024:
+            align = 1024
+        elif size_m > 64:
+            align = 64
+        else:
+            align = 1
+
+        padded_size_m = ((size_m + align - 1) // align) * align
+
+        if padded_size_m > size_m:
+            pad_len = padded_size_m - size_m
+            pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)
+            x = torch.cat([x, pad_tensor], dim=0).contiguous()
+        return x, size_m
+
 class KLinearMarlin(KLinearBase):
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
```
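The alignment tiers in `_pad_input` mean a batch of up to 64 rows passes through untouched, 65 to 1024 rows round up to the next multiple of 64, and anything larger rounds up to the next multiple of 1024. A quick plain-Python check of that round-up arithmetic (values chosen for illustration):

```python
def padded_size(size_m: int) -> int:
    # Mirrors the tiering in VLinearMarlin._pad_input from this commit.
    if size_m > 1024:
        align = 1024
    elif size_m > 64:
        align = 64
    else:
        align = 1
    return ((size_m + align - 1) // align) * align

assert padded_size(64)   == 64    # <= 64: align is 1, no padding
assert padded_size(65)   == 128   # rounds up to a multiple of 64
assert padded_size(1000) == 1024
assert padded_size(1025) == 2048  # rounds up to a multiple of 1024
```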