mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
📝 add doc support and fix bug in qwen2
This commit is contained in:
parent
8bad019ef2
commit
c74453d8ca
11 changed files with 170 additions and 12 deletions
|
@ -54,15 +54,15 @@ class KLinearBase(ABC):
|
|||
|
||||
self.has_bias = False
|
||||
self.dtype = torch.get_default_dtype()
|
||||
# if orig_module is not None:
|
||||
# self.in_features = orig_module.in_features
|
||||
# self.out_features = orig_module.out_features
|
||||
# else:
|
||||
shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
|
||||
if len(shape) == 1:
|
||||
print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
|
||||
self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
|
||||
self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
|
||||
if orig_module is not None:
|
||||
self.in_features = orig_module.in_features
|
||||
self.out_features = orig_module.out_features
|
||||
else:
|
||||
shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
|
||||
if len(shape) == 1:
|
||||
print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
|
||||
self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
|
||||
self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -136,12 +136,19 @@ class KLinearTorch(KLinearBase):
|
|||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
|
||||
if device is None: device = self.device
|
||||
if w is None: w = self.load_weight(device=device)
|
||||
# else: self.out_features = w.shape[0], self.in_features = w.shape[1]
|
||||
|
||||
if isinstance(w, nn.Parameter):
|
||||
self.w = w.to(dtype=self.dtype).T
|
||||
try:
|
||||
self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
|
||||
except:
|
||||
self.w = w.to(dtype=self.dtype).T
|
||||
self.has_bias = False
|
||||
elif isinstance(w, tuple):
|
||||
self.w = w[0].to(dtype=self.dtype).T
|
||||
try:
|
||||
self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
|
||||
except:
|
||||
self.w = w[0].to(dtype=self.dtype).T
|
||||
self.bias = w[1].to(dtype=self.dtype)
|
||||
self.has_bias = True
|
||||
else:
|
||||
|
@ -187,7 +194,8 @@ class KLinearMarlin(KLinearBase):
|
|||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
|
||||
if device is None: device = self.device
|
||||
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
|
||||
if w is None: w = self.load_weight(device=device)
|
||||
if w is None:
|
||||
w = self.load_weight(device=device)
|
||||
|
||||
if isinstance(w, nn.Parameter):
|
||||
# pad weight
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue