mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
Fix TypeError when invoke KLinearCPUInfer.forward()
Fix the following error: File "/home/aubrey/work/ktransformers/ktransformers/operators/linear.py", line 825, in forward y = self.generate_linear.forward(x, bsz_tensor) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: KLinearCPUInfer.forward() takes 2 positional arguments but 3 were given
This commit is contained in:
parent
6ca743ed7a
commit
12a4c631df
1 changed files with 1 additions and 1 deletions
|
@ -699,7 +699,7 @@ class KLinearCPUInfer(KLinearBase):
|
||||||
self.group_max_len = group_max_len
|
self.group_max_len = group_max_len
|
||||||
self.out_device = out_device
|
self.out_device = out_device
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
|
||||||
origin_shape = x.shape # [batch_size, q_len, hidden_size]
|
origin_shape = x.shape # [batch_size, q_len, hidden_size]
|
||||||
if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
|
if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
|
||||||
out_device = x.device
|
out_device = x.device
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue