mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 23:34:35 +00:00)
commit 25cee5810e: add balance-serve, support concurrence
parent 8d0292aa44

196 changed files with 22077 additions and 565 deletions
@@ -15,14 +15,16 @@ import ctypes
 import torch
 from torch import Tensor, nn
 import KTransformersOps
+import vLLMMarlin
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.util.utils import InferenceState
 from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
     MarlinWorkspace,
     marlin_quantize,
     GPTQ_MARLIN_MIN_THREAD_N,
     GPTQ_MARLIN_MIN_THREAD_K,
     GPTQ_MARLIN_MAX_PARALLEL,
+    vllm_marlin_quantize
 )
 from ktransformers.operators.base_operator import BaseInjectedModule
 from transformers.configuration_utils import PretrainedConfig
@@ -84,8 +86,10 @@ class KLinearBase(ABC):
         if self.gguf_loader.safetensor_loader is not None:
             # using safetensor_loader
             tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
-            weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
-            return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+            if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
+                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
+                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+            return nn.Parameter(tensor)
 
         elif key + ".weight" in self.gguf_loader.tensor_file_map:
             if key + ".bias" in self.gguf_loader.tensor_file_map:
@@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase):
         self.weight = None
         self.has_bias = False
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         dtype = x.dtype
         out_device = x.device
         # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
@@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase):
         if self.has_bias:
             self.bias = None
 
-
 class KLinearQ8(KLinearBase):
     def __init__(
         self,
@@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase):
         self.dtype = torch.get_default_dtype()
         self.block_size = block_size
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
         x = x.to(self.device)
         orig_dtype = x.dtype
         x_quantized, scale_x = act_quant(x, self.block_size)
@@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase):
         self.weight = None
         if self.has_bias:
             self.bias = None
 
+# TODO: merge two marlin class
+class VLinearMarlin(KLinearBase):
+    marlin_q_w: torch.Tensor
+    marlin_s: torch.Tensor
+    g_idx: torch.Tensor
+    sort_indices: torch.Tensor
+    has_bias: bool
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "cuda",
+        num_bits: int = 4,  # 4-bit/8-bit is supported
+        group_size: int = 64,  # -1, 32, 64, 128
+        act_order: bool = False,
+        is_k_full=True,
+        **kwargs,
+    ):
+        assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.num_bits = num_bits
+        self.group_size = group_size
+        self.act_order = act_order
+        self.is_k_full = is_k_full
+        self.padding = False
+        self.orin_in_features = self.in_features
+        self.orin_out_features = self.out_features
+        if self.in_features % GPTQ_MARLIN_MIN_THREAD_K != 0 or self.out_features % GPTQ_MARLIN_MIN_THREAD_K != 0:
+            #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
+            self.padding = True
+            self.in_features = (self.in_features + GPTQ_MARLIN_MIN_THREAD_K - 1) // GPTQ_MARLIN_MIN_THREAD_K * GPTQ_MARLIN_MIN_THREAD_K
+            self.out_features = (self.out_features + GPTQ_MARLIN_MIN_THREAD_N - 1) // GPTQ_MARLIN_MIN_THREAD_N * GPTQ_MARLIN_MIN_THREAD_N
+            #print(f"After padding: in_features={in_features}, out_features={out_features}")
+
+        self.k = self.in_features
+        self.n = self.out_features
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
+        if self.loaded: return
+        if device is None: device = self.device
+        assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
+
+        #if self.in_features * self.out_features:
+        if w is None:
+            w = self.load_weight(device=device)
+
+        if isinstance(w, nn.Parameter):
+            # pad weight
+            weight = w.view(self.orin_out_features, self.orin_in_features).T
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            w = list(w)
+            weight = w[0].view(self.orin_out_features, self.orin_in_features).T
+            self.bias = w[1].view(self.orin_out_features)
+            self.bias = w[1]
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+        weight = weight.to(device)
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+
+        if self.padding:
+            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
+            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
+            weight = padded_weight
+
+        # Pack Marlin linear
+        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+            weight, self.num_bits, self.group_size, self.act_order
+        )
+        self.workspace = MarlinWorkspace(
+            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device
+        )
+        self.weight = marlin_q_w
+        self.marlin_q_w = marlin_q_w
+        self.marlin_s = marlin_s
+        self.g_idx = g_idx
+        self.sort_indices = sort_indices
+        self.k = weight.shape[0]
+        self.n = weight.shape[1]
+        # self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)
+        self.loaded = True
+
+    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
+        if bsz_tensor is None:
+            bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)
+
+        # Only support input x as BF16 and FP16
+        x = x.to(self.device)
+        orig_shape = list(x.shape)
+        orig_dtype = x.dtype
+        x = x.reshape(-1, orig_shape[-1])
+        marlin_s = self.marlin_s.to(x.dtype)
+        sms = -1
+
+        x = vLLMMarlin.gptq_marlin_gemm(
+            x,
+            self.marlin_q_w,
+            marlin_s,
+            self.g_idx,
+            self.sort_indices,
+            self.workspace.scratch,
+            self.num_bits,
+            bsz_tensor,
+            # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
+            x.shape[0],
+            self.n,
+            x.shape[-1],
+            sms,
+            self.is_k_full,
+        )
+        # x = KTransformersOps.gptq_marlin_gemm(
+        #     x,
+        #     self.marlin_q_w,
+        #     marlin_s,
+        #     self.g_idx,
+        #     self.sort_indices,
+        #     self.workspace.scratch,
+        #     self.num_bits,
+        #     x.shape[0],
+        #     self.n,
+        #     x.shape[-1],
+        #     self.is_k_full,
+        # )
+        if self.has_bias:
+            x = x + self.bias
+        orig_shape[-1] = self.n
+        return x.reshape(orig_shape).to(orig_dtype)
+
+    def unload(self):
+        if self.has_bias:
+            self.bias = None
+        self.marlin_q_w = None
+        self.marlin_s = None
+        self.g_idx = None
+        self.sort_indices = None
+        self.workspace = None
+
 
 class KLinearMarlin(KLinearBase):
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
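The new VLinearMarlin rounds in_features and out_features up to the next multiple of the Marlin thread-tile constants before packing the weight. A minimal sketch of that round-up arithmetic; the constant values below are illustrative assumptions, not taken from this diff:

# Round-up padding as in VLinearMarlin.__init__; constants are assumed example values.
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MIN_THREAD_N = 64

def round_up(value: int, multiple: int) -> int:
    # (value + multiple - 1) // multiple is a ceiling division, then scale back up
    return (value + multiple - 1) // multiple * multiple

assert round_up(7168, GPTQ_MARLIN_MIN_THREAD_K) == 7168  # already aligned, unchanged
assert round_up(7000, GPTQ_MARLIN_MIN_THREAD_K) == 7040  # padded to the next multiple of 128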
@@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase):
         self.n = weight.shape[1]
         self.loaded = True
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
         # Only support input x as BF16 and FP16
         x = x.to(self.device)
         orig_shape = list(x.shape)
@@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase):
         if self.w is not None:
             self.w = None
         if self.has_bias:
             self.bias = None
 
 LINEAR_MAP = {
     "KLinearMarlin": KLinearMarlin,
     "KLinearTorch": KLinearTorch,
     "KLinearCPUInfer": KLinearCPUInfer,
+    "VLinearMarlin": VLinearMarlin,
     "KLinearFP8": KLinearFP8,
     "KLinearQ8": KLinearQ8,
 }
@@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         self.generate_linear = None
         self.mode = InferenceState.UNLOAD
 
-    def forward(self, x):
+    def forward(self, x, bsz_tensor=None):
         if self.mode == InferenceState.PREFILL:
             assert self.prefill_linear is not None, "cpu linear is not initialized"
-            y = self.prefill_linear.forward(x)
+            y = self.prefill_linear.forward(x, bsz_tensor)
         else:
             assert self.generate_linear is not None, "gpu linear is not initialized"
-            y = self.generate_linear.forward(x)
+            y = self.generate_linear.forward(x, bsz_tensor)
         return y
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
@@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
             self.unload()
         else:
             raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+
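The recurring change in this file is that forward now accepts a bsz_tensor, a device-side int32 tensor carrying the effective batch size, which KTransformersLinear threads through to its prefill or generate backend; this appears to be part of how the balance-serve path drives the same kernels for a varying number of concurrent requests. A minimal calling sketch, assuming linear is an already-loaded VLinearMarlin (or KTransformersLinear wrapping one) on a CUDA device; the variable names and setup are hypothetical:

import torch

# hypothetical setup: `linear` has already been load()-ed on "cuda"
x = torch.randn(4, linear.in_features, dtype=torch.float16, device="cuda")
# number of active requests in this step, passed as a device tensor
bsz = torch.tensor([4], dtype=torch.int32, device="cuda")
y = linear.forward(x, bsz)
# VLinearMarlin.forward falls back to x.shape[0] when bsz_tensor is omitted
y_default = linear.forward(x)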