Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)
Add fp8 linear kernel; Add empty cache to fit in 16G VRAM; By 'wkGCaSS - Zhihu https://zhuanlan.zhihu.com/p/25491611225'
This commit is contained in:
parent b4fb633991, commit 7b7c6a657d

5 changed files with 331 additions and 2 deletions
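Only the FP8 linear kernel itself appears in the hunks below; the 16G-VRAM change named in the commit message lives in the other changed files. Judging by the message, it amounts to releasing cached allocator blocks after large weight loads. A minimal sketch of that pattern, with a hypothetical `load_and_trim` helper that is not part of this commit:

```python
import torch

def load_and_trim(module, weights):
    # Hypothetical helper illustrating the "empty cache" idea from the
    # commit message: after materializing a large weight on the GPU, hand
    # unused cached blocks back to the driver so peak usage fits in 16 GB.
    module.load(weights)        # assumed KLinearBase-style load()
    torch.cuda.empty_cache()    # release cached, currently-unused blocks
```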
```diff
@@ -25,6 +25,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
 )
 from ktransformers.operators.base_operator import BaseInjectedModule
 from transformers.configuration_utils import PretrainedConfig
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
 from abc import ABC, abstractmethod
 import sys, os
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
```
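The added import is the only change in this hunk: it pulls ktransformers' Triton FP8 GEMM and its quantization helpers into the linear operators module. For orientation, here is a pure-PyTorch sketch of what a block-wise `act_quant` conventionally computes (per-block absmax scales, values cast to float8_e4m3fn); the actual Triton kernel may differ in layout, rounding, and scale dtype:

```python
import torch

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    # Reference sketch, not the Triton kernel: split the last dim into
    # blocks, scale each block by absmax / 448 (the e4m3fn maximum), cast
    # to FP8, and return one float32 scale per block.
    assert x.shape[-1] % block_size == 0
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True).float() / 448.0
    scale = scale.clamp(min=1e-12)  # avoid div-by-zero on all-zero blocks
    q = (blocks / scale).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)
```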
```diff
@@ -164,7 +165,65 @@ class KLinearTorch(KLinearBase):
         if self.has_bias:
             self.bias = None
 
+
+class KLinearFP8(KLinearBase):
+    marlin_q_w: torch.Tensor
+    marlin_s: torch.Tensor
+    g_idx: torch.Tensor
+    sort_indices: torch.Tensor
+    has_bias: bool
+    weight: torch.Tensor
+    scale_w: torch.Tensor
+    bias: torch.Tensor
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "cuda",
+        block_size: int = 128,
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.dtype = torch.get_default_dtype()
+        self.block_size = block_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.to(self.device)
+        orig_shape = list(x.shape)
+        orig_dtype = x.dtype
+        x = x.reshape(-1, orig_shape[-1])
+        x_quantized, scale_x = act_quant(x, self.block_size)
+        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight.scale)
+        if self.bias is not None:
+            y += self.bias
+        return y.to(orig_dtype).reshape(orig_shape)
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if device is None: device = self.device
+        if w is None:
+            w = self.load_weight(device=device)
+        if isinstance(w, nn.Parameter):
+            self.weight = w.to(device)
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            self.weight = w[0].to(device)
+            self.bias = w[1].to(device)
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+        self.weight = self.weight.to(device)
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+
+    def unload(self):
+        if self.weight is not None:
+            self.weight = None
+        if self.has_bias:
+            self.bias = None
+
 
 class KLinearMarlin(KLinearBase):
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
```
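The forward path quantizes activations on the fly and passes the block-quantized weight with its per-tile scales (`weight.scale`) to `fp8_gemm`. Numerically this is just a matmul against the dequantized weight, as the hedged reference below spells out, assuming one scale per `block_size x block_size` weight tile. One sharp edge in the diff: `self.bias` is only assigned when `load()` receives a (weight, bias) tuple, so forward's `if self.bias is not None` check presumes the attribute exists; the reference therefore defaults `bias` to None.

```python
import torch

def fp8_linear_ref(x, w_q, s_w, bias=None, block_size=128):
    # Hypothetical reference for KLinearFP8.forward: dequantize the
    # block-quantized weight, then run a plain matmul. fp8_gemm fuses this;
    # s_w is assumed to hold one scale per (block_size x block_size) tile.
    out_f, in_f = w_q.shape
    # Expand per-tile scales back to per-element granularity.
    s_full = s_w.repeat_interleave(block_size, 0)[:out_f] \
                .repeat_interleave(block_size, 1)[:, :in_f]
    w = w_q.to(torch.float32) * s_full
    y = x.to(torch.float32) @ w.T
    if bias is not None:  # defaulting bias to None avoids AttributeError
        y = y + bias.to(torch.float32)
    return y.to(x.dtype)
```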