add balance-serve, support concurrency

Author: Atream
Date: 2025-03-31 22:55:32 +08:00
parent 8d0292aa44
commit 25cee5810e
196 changed files with 22077 additions and 565 deletions


@@ -15,14 +15,16 @@ import ctypes
import torch
from torch import Tensor, nn
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
vllm_marlin_quantize
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
@@ -84,8 +86,10 @@ class KLinearBase(ABC):
if self.gguf_loader.safetensor_loader is not None:
# using safetensor_loader
tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
-weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
-return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
+    weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
+    return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+return nn.Parameter(tensor)
elif key + ".weight" in self.gguf_loader.tensor_file_map:
if key + ".bias" in self.gguf_loader.tensor_file_map:
@@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase):
self.weight = None
self.has_bias = False
-def forward(self, x: torch.Tensor) -> torch.Tensor:
+def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
dtype = x.dtype
out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
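KLinearTorch.forward gains **kwargs only so every backend can be called with one argument list; the plain torch matmul path simply ignores the bsz_tensor that the Marlin and FP8 paths accept. A small sketch of that uniform calling convention (call_linear is a hypothetical helper, not part of the diff):

import torch

def call_linear(linear, x: torch.Tensor, bsz_tensor: torch.Tensor | None = None):
    # KLinearTorch swallows bsz_tensor via **kwargs; VLinearMarlin / KLinearFP8 use it.
    return linear.forward(x, bsz_tensor=bsz_tensor)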
@@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase):
if self.has_bias:
self.bias = None
class KLinearQ8(KLinearBase):
def __init__(
self,
@@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase):
self.dtype = torch.get_default_dtype()
self.block_size = block_size
-def forward(self, x: torch.Tensor) -> torch.Tensor:
+def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
orig_dtype = x.dtype
x_quantized, scale_x = act_quant(x, self.block_size)
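act_quant gives KLinearFP8 a per-block scale for the activations before the FP8 GEMM; bsz_tensor is added to the signature so the shared calling convention works here too. A rough sketch of block-wise activation quantization of that flavor (illustrative only; the real act_quant kernel may differ in block layout, dtype and scale convention):

import torch

def block_act_quant(x: torch.Tensor, block_size: int = 128):
    # Absmax quantization of the last dim in chunks of block_size to float8_e4m3fn.
    assert x.shape[-1] % block_size == 0
    blocks = x.reshape(*x.shape[:-1], -1, block_size)                          # (..., n_blocks, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0   # 448 ~= e4m3 max
    x_q = (blocks / scale).to(torch.float8_e4m3fn)
    return x_q.reshape_as(x), scale.squeeze(-1)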
@@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase):
self.weight = None
if self.has_bias:
self.bias = None
# TODO: merge the two Marlin linear classes
class VLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
sort_indices: torch.Tensor
has_bias: bool
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
num_bits: int = 4, # 4-bit/8-bit is supported
group_size: int = 64, # -1, 32, 64, 128
act_order: bool = False,
is_k_full=True,
**kwargs,
):
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.num_bits = num_bits
self.group_size = group_size
self.act_order = act_order
self.is_k_full = is_k_full
self.padding = False
self.orin_in_features = self.in_features
self.orin_out_features = self.out_features
if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
#print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
self.padding = True
self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
#print(f"After padding: in_features={in_features}, out_features={out_features}")
self.k = self.in_features
self.n = self.out_features
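# Worked example of the ceil-to-multiple padding above (illustrative numbers; assumes
# GPTQ_MARLIN_MIN_THREAD_K == 128 and GPTQ_MARLIN_MIN_THREAD_N == 64):
#   in_features  = 1000 -> (1000 + 127) // 128 * 128 = 1024
#   out_features = 300  -> (300 + 63) // 64 * 64 = 320
# load() below zero-pads the dense weight into the (in_features, out_features) buffer
# before handing it to marlin_quantize, so the kernel always sees tile-aligned shapes.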
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
#if self.in_features * self.out_features:
if w is None:
w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
# pad weight
weight = w.view(self.orin_out_features, self.orin_in_features).T
self.has_bias = False
elif isinstance(w, tuple):
w = list(w)
weight = w[0].view(self.orin_out_features, self.orin_in_features).T
self.bias = w[1].view(self.orin_out_features)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
weight = weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
if self.padding:
padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
weight = padded_weight
# Pack Marlin linear
marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.weight = marlin_q_w
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
self.g_idx = g_idx
self.sort_indices = sort_indices
self.k = weight.shape[0]
self.n = weight.shape[1]
# self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)
self.loaded = True
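# Reading the packed state above: marlin_quantize returns the repacked quantized weight
# (marlin_q_w), its group scales (marlin_s) and the act-order / sorting index tensors
# (g_idx, sort_indices), while MarlinWorkspace pre-allocates the scratch buffer that
# gptq_marlin_gemm uses for its global reduction. Everything lives on the target GPU,
# so forward() below is one kernel launch plus an optional bias add.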
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
if bsz_tensor is None:
bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)
# Only supports BF16 and FP16 input for x
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1])
marlin_s = self.marlin_s.to(x.dtype)
sms = -1
x = vLLMMarlin.gptq_marlin_gemm(
x,
self.marlin_q_w,
marlin_s,
self.g_idx,
self.sort_indices,
self.workspace.scratch,
self.num_bits,
bsz_tensor,
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
x.shape[0],
self.n,
x.shape[-1],
sms,
self.is_k_full,
)
# x = KTransformersOps.gptq_marlin_gemm(
# x,
# self.marlin_q_w,
# marlin_s,
# self.g_idx,
# self.sort_indices,
# self.workspace.scratch,
# self.num_bits,
# x.shape[0],
# self.n,
# x.shape[-1],
# self.is_k_full,
# )
if self.has_bias:
x = x + self.bias
orig_shape[-1] = self.n
return x.reshape(orig_shape).to(orig_dtype)
def unload(self):
if self.has_bias:
self.bias = None
self.marlin_q_w = None
self.marlin_s = None
self.g_idx = None
self.sort_indices = None
self.workspace = None
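Functionally, VLinearMarlin is KLinearMarlin plus the bsz_tensor argument: a device-side int32 row count that vLLMMarlin.gptq_marlin_gemm reads at run time instead of a host-side Python batch size, presumably so the balance-serve scheduler can push a varying number of concurrent tokens through one launch path. A usage sketch, assuming the injection framework has already built and loaded an instance named linear on CUDA:

import torch

x = torch.randn(8, linear.orin_in_features, dtype=torch.float16, device="cuda")
bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device="cuda")
y = linear.forward(x, bsz_tensor)   # equivalent to linear.forward(x); forward() builds
                                    # the same tensor itself when bsz_tensor is None
assert y.shape == (8, linear.out_features)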
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
@@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase):
self.n = weight.shape[1]
self.loaded = True
-def forward(self, x: torch.Tensor) -> torch.Tensor:
+def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
# Only supports BF16 and FP16 input for x
x = x.to(self.device)
orig_shape = list(x.shape)
@@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase):
if self.w is not None:
self.w = None
if self.has_bias:
self.bias = None
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer,
"VLinearMarlin": VLinearMarlin,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
}
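LINEAR_MAP maps the backend names used by the injection rules (typically the generate_op / prefill_op strings passed to KTransformersLinear) to their classes, so registering VLinearMarlin here is what makes the new kernel selectable by name. A lookup sketch; constructor arguments are elided since they follow the KLinearBase signature shown above:

backend_cls = LINEAR_MAP["VLinearMarlin"]          # -> VLinearMarlin
# backend = backend_cls(key, gguf_loader, config, orig_module, device="cuda")
# backend.load(); y = backend.forward(x, bsz_tensor)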
@@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
self.generate_linear = None
self.mode = InferenceState.UNLOAD
-def forward(self, x):
+def forward(self, x, bsz_tensor=None):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
-y = self.prefill_linear.forward(x)
+y = self.prefill_linear.forward(x, bsz_tensor)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
-y = self.generate_linear.forward(x)
+y = self.generate_linear.forward(x, bsz_tensor)
return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
@@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")