[ADD] support multi-gpu qlen>1 q5_k

This commit is contained in:
chenxl 2024-08-12 11:17:29 +00:00
parent f293803156
commit f5f79f5c0e
63 changed files with 3271 additions and 1285 deletions

View file

@ -176,7 +176,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
self.act_order = act_order
self.is_k_full = is_k_full
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"):
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
if w is None: w = self.load_weight(device=device)
@ -200,7 +200,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
@ -247,7 +247,6 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
LINEAR_MAP = {
"QuantizedLinearMarlin": QuantizedLinearMarlin,
"QuantizedLinearTorch": QuantizedLinearTorch,
"QuantizedLinearTorch": QuantizedLinearTorch,
}
class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
@ -257,15 +256,15 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
# device: str = "cuda",
generate_device: str = "cuda",
generate_op: str| None = "QuantizedLinearMarlin",
prefill_device: str = "cuda",
prefill_op: str| None = "QuantizedLinearTorch",
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
# build all the linear operators
if prefill_op is not None:
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
@ -289,7 +288,6 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
else:
self.generate_linear = None
self.device = device
self.mode = InferenceState.UNLOAD
def forward(self, x):