mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
fix some bugs
This commit is contained in:
parent
d2cf81423f
commit
921061666c
3 changed files with 7 additions and 7 deletions
|
@ -138,7 +138,7 @@ class KLinearTorch(KLinearBase):
|
|||
self.weight = None
|
||||
self.has_bias = False
|
||||
|
||||
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
out_device = x.device
|
||||
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
|
||||
|
@ -201,7 +201,7 @@ class KLinearQ8(KLinearBase):
|
|||
self.bias = None
|
||||
self.loaded = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
out_device = x.device
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue