mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 23:34:35 +00:00
Add data loader to read special weights for fp8; Add special weight process script
This commit is contained in:
parent
7b7c6a657d
commit
581a524f65
10 changed files with 481 additions and 26 deletions
|
@ -76,7 +76,13 @@ class KLinearBase(ABC):
|
|||
keys = [self.key]
|
||||
|
||||
for key in keys:
|
||||
if key + ".weight" in self.gguf_loader.tensor_file_map:
|
||||
if self.gguf_loader.safetensor_loader is not None:
|
||||
# using safetensor_loader
|
||||
tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
|
||||
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
|
||||
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
|
||||
|
||||
elif key + ".weight" in self.gguf_loader.tensor_file_map:
|
||||
if key + ".bias" in self.gguf_loader.tensor_file_map:
|
||||
tensors = self.load_multi(key, ["weight", "bias"], device=device)
|
||||
tensor = tensors["weight"]
|
||||
|
@ -166,6 +172,8 @@ class KLinearTorch(KLinearBase):
|
|||
self.bias = None
|
||||
|
||||
class KLinearFP8(KLinearBase):
|
||||
# this kernel requires special handling for weight
|
||||
# Please load the weight file downloaded from KVCache.AI
|
||||
marlin_q_w: torch.Tensor
|
||||
marlin_s: torch.Tensor
|
||||
g_idx: torch.Tensor
|
||||
|
@ -191,26 +199,20 @@ class KLinearFP8(KLinearBase):
|
|||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.to(self.device)
|
||||
orig_shape = list(x.shape)
|
||||
orig_dtype = x.dtype
|
||||
x = x.reshape(-1, orig_shape[-1])
|
||||
orig_dtype = x.dtype
|
||||
x_quantized, scale_x = act_quant(x, self.block_size)
|
||||
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
|
||||
if self.bias is not None:
|
||||
y += self.bias
|
||||
return y.to(orig_dtype).reshape(orig_shape)
|
||||
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
|
||||
return y.to(dtype=orig_dtype)
|
||||
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
|
||||
if device is None: device = self.device
|
||||
if w is None:
|
||||
w = self.load_weight(device=device)
|
||||
if isinstance(w, nn.Parameter):
|
||||
self.weight = w.to(device)
|
||||
self.has_bias = False
|
||||
elif isinstance(w, tuple):
|
||||
### TODO fit weight_inv format
|
||||
if isinstance(w, tuple):
|
||||
self.weight = w[0].to(device)
|
||||
self.bias = w[1].to(device)
|
||||
self.has_bias = True
|
||||
self.weight_scale_inv = w[1].to(device)
|
||||
self.has_bias = False
|
||||
else:
|
||||
raise ValueError("Invalid weight type")
|
||||
self.weight = self.weight.to(device)
|
||||
|
@ -425,7 +427,8 @@ class KLinearCPUInfer(KLinearBase):
|
|||
LINEAR_MAP = {
|
||||
"KLinearMarlin": KLinearMarlin,
|
||||
"KLinearTorch": KLinearTorch,
|
||||
"KLinearCPUInfer": KLinearCPUInfer
|
||||
"KLinearCPUInfer": KLinearCPUInfer,
|
||||
"KLinearFP8": KLinearFP8,
|
||||
}
|
||||
|
||||
class KTransformersLinear(BaseInjectedModule, KLinearBase):
|
||||
|
@ -472,10 +475,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
|
|||
def forward(self, x):
|
||||
if self.mode == InferenceState.PREFILL:
|
||||
assert self.prefill_linear is not None, "cpu linear is not initialized"
|
||||
return self.prefill_linear.forward(x)
|
||||
y = self.prefill_linear.forward(x)
|
||||
else:
|
||||
assert self.generate_linear is not None, "gpu linear is not initialized"
|
||||
return self.generate_linear.forward(x)
|
||||
y = self.generate_linear.forward(x)
|
||||
return y
|
||||
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
|
||||
if not mode:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue