Fix TypeError when invoke KLinearCPUInfer.forward()

Fix the following error:

  File "/home/aubrey/work/ktransformers/ktransformers/operators/linear.py", line 825, in forward
    y = self.generate_linear.forward(x, bsz_tensor)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: KLinearCPUInfer.forward() takes 2 positional arguments but 3 were given
This commit is contained in:
Aubrey Li 2025-04-07 12:02:27 +08:00
parent 6ca743ed7a
commit 12a4c631df

View file

@ -699,7 +699,7 @@ class KLinearCPUInfer(KLinearBase):
self.group_max_len = group_max_len self.group_max_len = group_max_len
self.out_device = out_device self.out_device = out_device
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
origin_shape = x.shape # [batch_size, q_len, hidden_size] origin_shape = x.shape # [batch_size, q_len, hidden_size]
if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing(): if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
out_device = x.device out_device = x.device