fix some bugs

This commit is contained in:
root 2025-04-17 00:48:09 +08:00
parent d2cf81423f
commit 921061666c
3 changed files with 7 additions and 7 deletions

View file

@ -138,7 +138,7 @@ class KLinearTorch(KLinearBase):
self.weight = None
self.has_bias = False
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
dtype = x.dtype
out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
@ -201,7 +201,7 @@ class KLinearQ8(KLinearBase):
self.bias = None
self.loaded = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None) -> torch.Tensor:
orig_dtype = x.dtype
out_device = x.device