mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
commit 3986e2d2cf
31 changed files with 1713 additions and 114 deletions
@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
from typing import Dict, Tuple, Optional, Union
import numpy as np

#class KLinearBase(BaseInjectedModule, ABC):
class KLinearBase(ABC):
@ -176,16 +178,182 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None

class KLinearQ8(KLinearBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.compute_dtype = torch.float32
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self.bias = None
        self.loaded = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        out_device = x.device

        x = x.to(device=self.device, dtype=self.compute_dtype)

        # Dequantize the weight and multiply in floating point,
        # mimicking the behavior of the original unquantized layer.
        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
        out = x @ weight_dequant.T

        if self.has_bias:
            out = out + self.bias

        return out.to(dtype=orig_dtype, device=out_device)
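For reference, this forward path is simulated quantization: the int8 weight is dequantized to float32 and multiplied with a plain matmul, so no int8 kernels are involved. A minimal standalone sketch (shapes and values made up for illustration):

import torch

x = torch.randn(4, 512)                                      # (batch, in_features)
Wq = torch.randint(-127, 128, (256, 512), dtype=torch.int8)  # (out, in), as stored
scales = torch.rand(512) * 0.01                              # one scale per input column
W_hat = Wq.to(torch.float32) * scales.view(1, 512)           # dequantize via broadcasting
y = x @ W_hat.T                                              # (4, 256), ordinary fp32 matmul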

    def _dequantize_weight(self, q_matrix, scales, bits=8):
        """
        Dequantize a low-precision matrix back to floating point.

        Args:
            q_matrix (torch.Tensor): Quantized int matrix
            scales (torch.Tensor): Scale factors for each column
            bits (int): Quantization bits used (8 or 4)

        Returns:
            torch.Tensor: Dequantized floating-point matrix
        """
        # Ensure inputs are torch tensors
        if not isinstance(q_matrix, torch.Tensor):
            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
        if not isinstance(scales, torch.Tensor):
            scales = torch.tensor(scales, dtype=torch.float32)

        # Convert to the expected dtypes if needed
        if q_matrix.dtype != torch.int8:
            q_matrix = q_matrix.to(torch.int8)
        if scales.dtype != torch.float32:
            scales = scales.to(torch.float32)

        # For Q4, ensure the values stay within the 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -7, 7)

        rows, cols = q_matrix.shape
        dequant_matrix = q_matrix.to(torch.float32)
        # Broadcast the per-column scales over all rows and dequantize
        # the whole matrix in one elementwise multiply.
        scales_broadcast = scales.view(1, cols)
        dequant_matrix = dequant_matrix * scales_broadcast

        return dequant_matrix
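Concretely, dequantization scales each column independently. A tiny hand-checked example:

import torch

q = torch.tensor([[100, -50],
                  [ 25,  75]], dtype=torch.int8)
s = torch.tensor([0.02, 0.04])            # per-column scales
print(q.to(torch.float32) * s.view(1, 2))
# tensor([[ 2.0000, -2.0000],
#         [ 0.5000,  3.0000]])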

    def _quantize_weight(self, matrix, bits=8):
        """
        Quantize a floating-point matrix to lower precision (Q8 or Q4).

        Args:
            matrix (torch.Tensor): Input matrix in floating-point format
            bits (int): Quantization bits, either 8 or 4

        Returns:
            tuple: (quantized int matrix, scale factors for each column)
        """
        if not isinstance(matrix, torch.Tensor):
            matrix = torch.tensor(matrix, dtype=torch.float32)

        # Convert to float32 if needed
        if matrix.dtype != torch.float32:
            matrix = matrix.to(torch.float32)

        # Get matrix shape
        rows, cols = matrix.shape

        # Determine quantization parameters based on bit width
        if bits == 8:
            max_int = 127
            qtype = torch.int8
        elif bits == 4:
            max_int = 7
            qtype = torch.int8  # store in int8 limited to the 4-bit range until native int4 is supported
        else:
            raise ValueError("Quantization bits must be either 8 or 4")

        # Calculate the max absolute value of each column
        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)

        # Handle all-zero columns (avoid division by zero)
        zero_cols = max_abs_vals == 0
        max_abs_vals[zero_cols] = 1.0

        # One symmetric scale factor per column
        scales = max_abs_vals / max_int

        # Prepare the scales for broadcasting: [1, cols]
        scales_broadcast = scales.view(1, cols)

        # Quantize the entire matrix at once
        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)

        # For Q4, clamp values to ensure they stay within the 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -max_int, max_int)

        return q_matrix, scales
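Together, `_quantize_weight` and `_dequantize_weight` form a round trip whose per-element error is at most half a quantization step (the column's scale divided by two). A self-contained check of the same per-column symmetric scheme:

import torch

W = torch.randn(256, 512)                 # (out_features, in_features)
scales = W.abs().amax(dim=0) / 127.0      # symmetric Q8: one scale per column
scales[scales == 0] = 1.0 / 127.0         # guard all-zero columns, as above
Wq = torch.round(W / scales).to(torch.int8)
W_hat = Wq.to(torch.float32) * scales
assert (W - W_hat).abs().max() <= scales.max() / 2 + 1e-6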

    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
        if self.loaded: return
        if device is None: device = self.device
        if w is None: w = self.load_weight(device=device)

        if isinstance(w, nn.Parameter):
            try:
                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except Exception:
                weight = w.to(dtype=self.compute_dtype)
            self.has_bias = False
        elif isinstance(w, tuple):
            try:
                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except Exception:
                weight = w[0].to(dtype=self.compute_dtype)
            self.bias = w[1].to(dtype=self.compute_dtype).to(device)
            self.has_bias = True
        else:
            raise ValueError("Invalid weight type")

        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)

        self.weight = self.weight.to(device)
        self.weight_scale = self.weight_scale.to(device)

        if self.has_bias:
            self.bias = self.bias.to(device)

        self.loaded = True

    def unload(self):
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self._orig_weight = None

        if self.has_bias:
            self.bias = None

        self.loaded = False
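One plausible payoff of quantizing at load time is memory: an int8 matrix plus one float32 scale per column is roughly a quarter of the float32 footprint (though the forward pass above still materializes a dequantized copy temporarily). Back-of-envelope for a hypothetical 4096x4096 projection:

out_f, in_f = 4096, 4096
fp32_bytes = out_f * in_f * 4             # fp32 weight
q8_bytes = out_f * in_f + in_f * 4        # int8 weight + fp32 per-column scales
print(fp32_bytes / q8_bytes)              # ~3.996x smaller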


class KLinearFP8(KLinearBase):
    # this kernel requires special handling for weight
    # Please load the weight file downloaded from KVCache.AI
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
    sort_indices: torch.Tensor
    has_bias: bool
    weight: torch.Tensor
    scale_w: torch.Tensor
    bias: torch.Tensor
    def __init__(
        self,
@ -468,6 +636,7 @@ LINEAR_MAP = {
    "KLinearTorch": KLinearTorch,
    "KLinearCPUInfer": KLinearCPUInfer,
    "KLinearFP8": KLinearFP8,
    "KLinearQ8": KLinearQ8,
}
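With this entry, KLinearQ8 becomes selectable by its string name through the registry; a hypothetical direct lookup (in practice the backend is normally chosen by the injection machinery rather than by hand):

linear_cls = LINEAR_MAP["KLinearQ8"]      # -> KLinearQ8
# linear = linear_cls(key, gguf_loader, config, orig_module, device="cuda")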

class KTransformersLinear(BaseInjectedModule, KLinearBase):