merge main; Add torch q8 linear

Azure-Tang 2025-03-14 05:52:07 -04:00
parent 6c4ed59175
commit ed8437413b
27 changed files with 1561 additions and 114 deletions

@@ -35,6 +35,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
from typing import Dict, Tuple, Optional, Union
import numpy as np
#class KLinearBase(BaseInjectedModule, ABC):
class KLinearBase(ABC):
@@ -176,6 +178,195 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None

class KLinearQ8(KLinearBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cuda",
        group_size: int = 128,  # larger group size to reduce quantization noise
        percentile: float = 99.99,  # new: percentile used to clip outliers
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.compute_dtype = torch.float32
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self.bias = None
        self.loaded = False
        self.group_size = group_size
        self.percentile = percentile

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        out_device = x.device
        x = x.to(device=self.device, dtype=self.compute_dtype)
        # Dequantize the weight and run the matmul in fp32,
        # mimicking the behavior of the original full-precision weight.
        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)
        out = x @ weight_dequant.T
        if self.has_bias:
            out = out + self.bias
        return out.to(dtype=orig_dtype, device=out_device)

    def _dequantize_weight(self, q_matrix, scales, bits=8):
        """
        Dequantize a low-precision matrix back to floating point.

        Args:
            q_matrix (torch.Tensor): Quantized int matrix
            scales (torch.Tensor): Scale factors for each column
            bits (int): Quantization bits used (8 or 4)

        Returns:
            torch.Tensor: Dequantized floating-point matrix
        """
        # Ensure inputs are torch tensors
        if not isinstance(q_matrix, torch.Tensor):
            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)
        if not isinstance(scales, torch.Tensor):
            scales = torch.tensor(scales, dtype=torch.float32)
        # Convert to the expected dtypes if needed
        if q_matrix.dtype != torch.int8:
            q_matrix = q_matrix.to(torch.int8)
        if scales.dtype != torch.float32:
            scales = scales.to(torch.float32)
        # For Q4, ensure the values stay within the 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -7, 7)
        # Get matrix shape
        rows, cols = q_matrix.shape
        # Convert to float32
        dequant_matrix = q_matrix.to(torch.float32)
        # Reshape scales to [1, cols] for broadcasting
        scales_broadcast = scales.view(1, cols)
        # Apply the per-column scales to the whole matrix at once
        dequant_matrix = dequant_matrix * scales_broadcast
        return dequant_matrix

    def _quantize_weight(self, matrix, bits=8):
        """
        Quantize a floating-point matrix to lower precision (Q8 or Q4).

        Args:
            matrix (torch.Tensor): Input matrix in floating-point format
            bits (int): Quantization bits, either 8 or 4

        Returns:
            tuple: (quantized int matrix, scale factors for each column)
        """
        if not isinstance(matrix, torch.Tensor):
            matrix = torch.tensor(matrix, dtype=torch.float32)
        # Convert to float32 if needed
        if matrix.dtype != torch.float32:
            matrix = matrix.to(torch.float32)
        # Get matrix shape
        rows, cols = matrix.shape
        # Determine quantization parameters based on bits
        if bits == 8:
            # Q8: range is -127 to 127
            max_int = 127
            qtype = torch.int8
        elif bits == 4:
            # Q4: range is -7 to 7 (signed 4-bit values stored in int8)
            max_int = 7
            qtype = torch.int8
        else:
            raise ValueError("Quantization bits must be either 8 or 4")
        # Initialize the result and the per-column scale factors
        q_matrix = torch.zeros_like(matrix, dtype=qtype)
        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
        # Calculate the max absolute value of each column
        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)
        # Handle all-zero columns (avoid division by zero)
        zero_cols = max_abs_vals == 0
        max_abs_vals[zero_cols] = 1.0
        # One scale factor per column
        scales = max_abs_vals / max_int
        # Reshape scales to [1, cols] for broadcasting
        scales_broadcast = scales.view(1, cols)
        # Quantize the entire matrix at once
        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)
        # For Q4, clamp values to keep them within the 4-bit range
        if bits == 4:
            q_matrix = torch.clamp(q_matrix, -max_int, max_int)
        return q_matrix, scales

    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):
        if self.loaded: return
        if device is None: device = self.device
        if w is None: w = self.load_weight(device=device)
        if isinstance(w, nn.Parameter):
            try:
                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except:
                weight = w.to(dtype=self.compute_dtype)
            self.has_bias = False
        elif isinstance(w, tuple):
            try:
                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)
            except:
                weight = w[0].to(dtype=self.compute_dtype)
            self.bias = w[1].to(dtype=self.compute_dtype).to(device)
            self.has_bias = True
        else:
            raise ValueError("Invalid weight type")
        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)
        self.weight = self.weight.to(device)
        self.weight_scale = self.weight_scale.to(device)
        if self.has_bias:
            self.bias = self.bias.to(device)
        self.loaded = True

    def unload(self):
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None
        self._orig_weight = None
        if self.has_bias:
            self.bias = None
        self.loaded = False

class KLinearFP8(KLinearBase):
    # this kernel requires special handling for weight
    # Please load the weight file downloaded from KVCache.AI
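
For reference, KLinearQ8 above uses plain symmetric per-column quantization: each weight column gets one fp32 scale, max|col| / 127, weights are rounded to int8, and the matrix is dequantized back to fp32 right before the matmul. A minimal, self-contained sketch of that round-trip (illustrative only, not part of this commit; all names below are made up):

import torch

# Toy weight [out_features, in_features] and activations [batch, in_features].
torch.manual_seed(0)
w = torch.randn(64, 32)
x = torch.randn(4, 32)

# Per-column symmetric Q8: one fp32 scale per column, values in [-127, 127].
max_abs = w.abs().amax(dim=0)
max_abs[max_abs == 0] = 1.0                      # avoid division by zero
scales = max_abs / 127.0
q = torch.round(w / scales).to(torch.int8)

# Dequantize and run the matmul in fp32, as KLinearQ8.forward does.
w_dq = q.to(torch.float32) * scales
out_q8 = x @ w_dq.T
out_fp = x @ w.T
print("relative error:", ((out_q8 - out_fp).norm() / out_fp.norm()).item())

The printed relative error is typically on the order of 1e-3, which is the approximation this operator trades for the smaller int8 weight footprint.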
@@ -468,6 +659,7 @@ LINEAR_MAP = {
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
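
The new LINEAR_MAP entry makes the operator selectable by the string "KLinearQ8". A hypothetical wiring sketch, assuming this hunk lives in ktransformers.operators.linear and that a GGUFLoader, model config, and the original nn.Linear are already in scope (none of the placeholder names below are part of the commit):

import torch
from ktransformers.operators.linear import LINEAR_MAP   # assumed module path

linear_cls = LINEAR_MAP["KLinearQ8"]          # resolves to the new KLinearQ8
op = linear_cls(
    key="blk.0.ffn_down",                     # hypothetical tensor key
    gguf_loader=gguf_loader,                  # placeholder: an already-built GGUFLoader
    config=model_config,                      # placeholder: the model's PretrainedConfig
    orig_module=orig_linear,                  # the nn.Linear being replaced
    device="cuda",
)
op.load(w=orig_linear.weight)                 # quantize the fp weight to int8 + per-column scales
y = op.forward(torch.randn(1, orig_linear.in_features))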