mirror of https://github.com/kvcache-ai/ktransformers.git
Merge remote-tracking branch 'upstream/develop-0.2.2' into support-fp8
commit ca7366d2db
41 changed files with 1223 additions and 314 deletions
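The hunks below are the linear-operator changes. Two themes run through them: a `loaded` flag that makes the `load()` methods idempotent (so lm_head weights can be pre-loaded), and shape padding in KLinearMarlin so that layers whose in/out features are not multiples of the Marlin minimum thread sizes no longer need a KLinearTorch fallback.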
@@ -21,6 +21,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
     MarlinWorkspace,
     marlin_quantize,
     GPTQ_MARLIN_MIN_THREAD_N,
+    GPTQ_MARLIN_MIN_THREAD_K,
     GPTQ_MARLIN_MAX_PARALLEL,
 )
 from ktransformers.operators.base_operator import BaseInjectedModule
@@ -65,6 +66,8 @@ class KLinearBase(ABC):
             self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
             self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
 
+        self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill.
+
     @abstractmethod
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         pass
@@ -141,6 +144,7 @@ class KLinearTorch(KLinearBase):
         return x
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
         if device is None: device = self.device
         if w is None: w = self.load_weight(device=device)
         # else: self.out_features = w.shape[0], self.in_features = w.shape[1]
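The `loaded` guard is the standard idempotent-load pattern: the first call does the work, later calls return immediately. A minimal sketch of the pattern (class and method bodies here are illustrative, not from the repo):

    class LazyWeights:
        def __init__(self):
            self.loaded = False  # flipped once load() has run

        def load(self):
            if self.loaded:      # repeated calls (e.g. lm_head pre-load) are no-ops
                return
            # ... move the real weights onto the device here ...
            self.loaded = True

        def unload(self):
            # ... release the weights here ...
            self.loaded = False  # permits a later reload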
@@ -164,6 +168,7 @@ class KLinearTorch(KLinearBase):
         self.weight = self.weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
+        self.loaded = True
 
     def unload(self):
         if self.weight is not None:
@@ -251,20 +256,36 @@ class KLinearMarlin(KLinearBase):
         self.group_size = group_size
         self.act_order = act_order
         self.is_k_full = is_k_full
+        self.padding = False
+        self.orin_in_features = self.in_features
+        self.orin_out_features = self.out_features
+        if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
+            #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
+            self.padding = True
+            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
+            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
+            #print(f"After padding: in_features={in_features}, out_features={out_features}")
+
+        self.k = self.in_features
+        self.n = self.out_features
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if self.loaded: return
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
+
+        #if self.in_features * self.out_features:
         if w is None:
             w = self.load_weight(device=device)
 
         if isinstance(w, nn.Parameter):
-            weight = w.view(self.out_features, self.in_features).T
+            # pad weight
+            weight = w.view(self.orin_out_features, self.orin_in_features).T
             self.has_bias = False
         elif isinstance(w, tuple):
             w = list(w)
-            weight = w[0].view(self.out_features, self.in_features).T
-            self.bias = w[1].view(self.orin_out_features)
+            weight = w[0].view(self.orin_out_features, self.orin_in_features).T
+            self.bias = w[1]
             self.has_bias = True
         else:
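The padded sizes come from the usual integer round-up idiom: `(x + k - 1) // k * k` is the smallest multiple of `k` greater than or equal to `x`. A quick check (the constant 128 is for illustration only; the real values are whatever the Marlin kernel exports):

    def round_up(x: int, k: int) -> int:
        """Smallest multiple of k that is >= x."""
        return (x + k - 1) // k * k

    K = 128                           # illustrative stand-in for GPTQ_MARLIN_MIN_THREAD_K
    assert round_up(5120, K) == 5120  # already aligned: unchanged
    assert round_up(5121, K) == 5248  # 41 * 128, the next multiple up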
@@ -272,8 +293,14 @@ class KLinearMarlin(KLinearBase):
         weight = weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
+
+        if self.padding:
+            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
+            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
+            weight = padded_weight
+
         # Pack Marlin linear
-        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
             weight, self.num_bits, self.group_size, self.act_order
         )
         self.workspace = MarlinWorkspace(
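Zero-padding the weight to the aligned shape before `marlin_quantize` keeps the extra rows and columns mathematically inert: zero weight rows contribute nothing to the GEMM output. A sketch assuming the (in_features, out_features) layout used above; the hunk's `torch.zeros(...)` leaves dtype to the global default, so the sketch passes it explicitly:

    import torch

    def zero_pad_weight(weight: torch.Tensor, in_pad: int, out_pad: int) -> torch.Tensor:
        """Embed an (in, out) weight into the top-left of an aligned zero block."""
        padded = torch.zeros(in_pad, out_pad, dtype=weight.dtype, device=weight.device)
        padded[:weight.shape[0], :weight.shape[1]] = weight
        return padded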
@@ -286,6 +313,7 @@ class KLinearMarlin(KLinearBase):
         self.sort_indices = sort_indices
         self.k = weight.shape[0]
         self.n = weight.shape[1]
+        self.loaded = True
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Only support input x as BF16 and FP16
@@ -293,6 +321,11 @@ class KLinearMarlin(KLinearBase):
         orig_shape = list(x.shape)
         orig_dtype = x.dtype
-        x = x.reshape(-1, orig_shape[-1])
+        x = x.reshape(-1, x.shape[-1])
+        if self.padding:
+            padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)
+            padding_input[:,:self.orin_in_features] = x
+            x = padding_input
         marlin_s = self.marlin_s.to(x.dtype)
         x = KTransformersOps.gptq_marlin_gemm(
             x,
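The forward path mirrors the weight padding on the activations: allocate a (rows, padded_in) buffer, copy the real features into the left block, and run the GEMM at the padded width. The hunk uses `torch.empty`, so the pad columns are uninitialized; that is only safe because the matching weight rows are zero, and stays safe as long as the garbage values are finite (NaN times zero is still NaN). A conservative sketch that zero-fills instead (helper name is ours):

    import torch

    def pad_activations(x: torch.Tensor, in_features: int, orin_in_features: int) -> torch.Tensor:
        """Copy x of shape (rows, orin_in_features) into a zeroed (rows, in_features) buffer."""
        buf = torch.zeros(x.shape[0], in_features, device=x.device, dtype=x.dtype)
        buf[:, :orin_in_features] = x
        return buf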
@@ -307,6 +340,11 @@ class KLinearMarlin(KLinearBase):
             x.shape[-1],
             self.is_k_full,
         )
+        if self.padding:
+            x = x[:,:self.orin_out_features]
+            orig_shape[-1] = self.orin_out_features
+        else:
+            orig_shape[-1] = self.out_features
         if self.has_bias:
             x = x + self.bias
         orig_shape[-1] = self.n
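End to end, pad-multiply-slice is exact: padding adds zero columns to the input and zero rows/columns to the weight, and slicing drops the untouched pad outputs. A round-trip check with a plain matmul standing in for the Marlin GEMM (sizes are arbitrary):

    import torch

    rows, k, n = 3, 100, 70      # unaligned sizes
    k_pad, n_pad = 128, 128      # aligned sizes, as computed by the round-up above

    x = torch.randn(rows, k)
    w = torch.randn(k, n)

    x_pad = torch.zeros(rows, k_pad); x_pad[:, :k] = x
    w_pad = torch.zeros(k_pad, n_pad); w_pad[:k, :n] = w

    y = (x_pad @ w_pad)[:, :n]   # pad, multiply, slice the pad columns back off
    assert torch.allclose(y, x @ w, atol=1e-5)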
@@ -450,24 +488,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         # build all the linear operators
         if prefill_op is not None:
             assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
-            if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-                print(f"module info: key:{key} orig_module:{orig_module}")
-                self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
-            else:
-                self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
         else:
             self.prefill_linear = None
 
         if generate_op is not None:
             assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
-            if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
-                print(f"module info: key:{key} orig_module:{orig_module}")
-                self.generate_op = "KLinearTorch"
-                self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
-            else:
-                self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
+            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
         else:
             self.generate_linear = None
         self.mode = InferenceState.UNLOAD
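Because KLinearMarlin now pads internally, the divisibility check and the KLinearTorch fallback in KTransformersLinear's constructor become dead weight, and this hunk deletes them: every configured prefill_op/generate_op is constructed straight from LINEAR_MAP.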