mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
[ADD] support multi-gpu qlen>1 q5_k
This commit is contained in:
parent
f293803156
commit
f5f79f5c0e
63 changed files with 3271 additions and 1285 deletions
|
@ -176,7 +176,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
|
|||
self.act_order = act_order
|
||||
self.is_k_full = is_k_full
|
||||
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"):
|
||||
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
|
||||
if device is None: device = self.device
|
||||
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
|
||||
if w is None: w = self.load_weight(device=device)
|
||||
|
@ -200,7 +200,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
|
|||
weight, self.num_bits, self.group_size, self.act_order
|
||||
)
|
||||
self.workspace = MarlinWorkspace(
|
||||
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
|
||||
)
|
||||
self.marlin_q_w = marlin_q_w
|
||||
self.marlin_s = marlin_s
|
||||
|
@ -247,7 +247,6 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
|
|||
LINEAR_MAP = {
|
||||
"QuantizedLinearMarlin": QuantizedLinearMarlin,
|
||||
"QuantizedLinearTorch": QuantizedLinearTorch,
|
||||
"QuantizedLinearTorch": QuantizedLinearTorch,
|
||||
}
|
||||
|
||||
class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
|
||||
|
@ -257,15 +256,15 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
|
|||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
device: str = "cuda",
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
generate_op: str| None = "QuantizedLinearMarlin",
|
||||
prefill_device: str = "cuda",
|
||||
prefill_op: str| None = "QuantizedLinearTorch",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
# build all the linear operators
|
||||
if prefill_op is not None:
|
||||
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
|
||||
|
@ -289,7 +288,6 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
|
|||
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
else:
|
||||
self.generate_linear = None
|
||||
self.device = device
|
||||
self.mode = InferenceState.UNLOAD
|
||||
|
||||
def forward(self, x):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue