mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
Update readme; Format code; Add example yaml.
This commit is contained in:
parent
c38e77de6b
commit
e5b001d76f
8 changed files with 182 additions and 30 deletions
|
@ -187,8 +187,6 @@ class KLinearQ8(KLinearBase):
|
|||
config: PretrainedConfig,
|
||||
orig_module: nn.Module = None,
|
||||
device: str = "cuda",
|
||||
group_size: int = 128, # 增大分组大小,减少量化噪声
|
||||
percentile: float = 99.99, # 新增:对异常值进行截断的百分位数
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
|
@ -199,8 +197,6 @@ class KLinearQ8(KLinearBase):
|
|||
self.weight_zero_point = None
|
||||
self.bias = None
|
||||
self.loaded = False
|
||||
self.group_size = group_size
|
||||
self.percentile = percentile
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
|
@ -246,16 +242,9 @@ class KLinearQ8(KLinearBase):
|
|||
# For Q4, ensure the values stay within 4-bit range
|
||||
if bits == 4:
|
||||
q_matrix = torch.clamp(q_matrix, -7, 7)
|
||||
|
||||
# Get matrix shape
|
||||
rows, cols = q_matrix.shape
|
||||
|
||||
# Convert to float32
|
||||
dequant_matrix = q_matrix.to(torch.float32)
|
||||
|
||||
# Create broadcasted scales: reshape scales to [1, cols] for broadcasting
|
||||
scales_broadcast = scales.view(1, cols)
|
||||
|
||||
# Apply dequantization to all columns at once using matrix multiplication
|
||||
dequant_matrix = dequant_matrix * scales_broadcast
|
||||
|
||||
|
@ -285,21 +274,14 @@ class KLinearQ8(KLinearBase):
|
|||
|
||||
# Determine quantization parameters based on bits
|
||||
if bits == 8:
|
||||
# Q8: range is -127 to 127
|
||||
max_int = 127
|
||||
qtype = torch.int8
|
||||
elif bits == 4:
|
||||
# Q4: range is -7 to 7 (using 4-bit signed integers)
|
||||
max_int = 7
|
||||
qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range
|
||||
qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range, wait for native support
|
||||
else:
|
||||
raise ValueError("Quantization bits must be either 8 or 4")
|
||||
|
||||
# Initialize results and scale factors
|
||||
q_matrix = torch.zeros_like(matrix, dtype=qtype)
|
||||
scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
|
||||
|
||||
# Initialize scale factors
|
||||
|
||||
scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)
|
||||
|
||||
# Calculate max absolute value for each column
|
||||
|
@ -370,13 +352,8 @@ class KLinearQ8(KLinearBase):
|
|||
class KLinearFP8(KLinearBase):
|
||||
# this kernel requires special handling for weight
|
||||
# Please load the weight file downloaded from KVCache.AI
|
||||
marlin_q_w: torch.Tensor
|
||||
marlin_s: torch.Tensor
|
||||
g_idx: torch.Tensor
|
||||
sort_indices: torch.Tensor
|
||||
has_bias: bool
|
||||
weight: torch.Tensor
|
||||
scale_w: torch.Tensor
|
||||
bias: torch.Tensor
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue