Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 23:34:35 +00:00
Enable support for Intel XPU devices; add support for DeepSeek V2/V3 first
parent 333351c7c8 · commit 142fb7ce6c
22 changed files with 673 additions and 81 deletions
@@ -14,18 +14,20 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 import ctypes
 import torch
 from torch import Tensor, nn
-import KTransformersOps
-import vLLMMarlin
+if not torch.xpu.is_available():
+    import KTransformersOps
+    import vLLMMarlin
 from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
 from ktransformers.util.utils import InferenceState
-from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
-    MarlinWorkspace,
-    marlin_quantize,
-    GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_MIN_THREAD_K,
-    GPTQ_MARLIN_MAX_PARALLEL,
-    vllm_marlin_quantize
-)
+if not torch.xpu.is_available():
+    from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
+        MarlinWorkspace,
+        marlin_quantize,
+        GPTQ_MARLIN_MIN_THREAD_N,
+        GPTQ_MARLIN_MIN_THREAD_K,
+        GPTQ_MARLIN_MAX_PARALLEL,
+        vllm_marlin_quantize
+    )
 from ktransformers.operators.base_operator import BaseInjectedModule
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
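The guarded imports above keep the CUDA-only extensions (KTransformersOps, vLLMMarlin, and the Marlin quantization utilities) from being loaded on machines where an Intel XPU is the target device. A minimal sketch of the same availability-check pattern; pick_inference_device is a hypothetical helper used only for illustration, not part of this commit:

import torch

def pick_inference_device() -> str:
    # Mirror the import guard: prefer CUDA when present, then Intel XPU, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

device = pick_inference_device()
x = torch.randn(4, 8, device=device)  # tensors land on whichever backend was found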
@@ -778,6 +780,75 @@ class KLinearCPUInfer(KLinearBase):
         if self.has_bias:
             self.bias = None
+
+class KLinearIPEXLLM(KLinearBase):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "xpu",
+        precision: str = "sym_int4",
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.dtype = torch.get_default_dtype()
+        self.weight = None
+        self.has_bias = False
+        self.precision = precision
+        self.qtype = None
+
+    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
+        dtype = x.dtype
+        out_device = x.device
+        from ipex_llm.transformers.models.common import linear_forward
+        x = linear_forward(x.half(), self.weight, self.qtype, self.out_features)
+
+        if self.has_bias:
+            x = x + self.bias
+        x = x.to(dtype=dtype, device=out_device)
+        return x
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
+        if self.loaded: return
+        if device is None: device = self.device
+        assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device"
+        if w is None: w = self.load_weight(device=device)
+
+        if isinstance(w, nn.Parameter):
+            try:
+                weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            except:
+                weight = w.to(dtype=self.dtype).T
+            self.has_bias = False
+        elif isinstance(w, tuple):
+            try:
+                weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            except:
+                weight = w[0].to(dtype=self.dtype).T
+            self.bias = w[1].to(dtype=self.dtype)
+            self.has_bias = True
+        else:
+            raise ValueError("Invalid weight type")
+        weight = weight.to("cpu").float().transpose(0, 1).contiguous()
+
+        if self.has_bias:
+            self.bias = self.bias.to(device)
+
+        # quantize linear weight
+        from ipex_llm.transformers.models.common import quantize_linear
+        paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision)
+        self.weight = paramsLowBit.to(device)
+        self.qtype = qtype
+        self.loaded = True
+
+    def unload(self):
+        if self.weight is not None:
+            self.weight = None
+        if self.has_bias:
+            self.bias = None
 
 LINEAR_MAP = {
     "KLinearMarlin": KLinearMarlin,
     "KLinearTorch": KLinearTorch,
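The load/forward pair above quantizes a full-precision weight once on the CPU and then runs the low-bit matmul through IPEX-LLM on the XPU. A rough standalone sketch of that flow, assuming an XPU-enabled torch build and ipex_llm are installed, and using the same two helpers the class calls (quantize_linear, linear_forward); the shapes and dtypes are illustrative only:

import torch
from ipex_llm.transformers.models.common import quantize_linear, linear_forward

in_features, out_features = 4096, 11008
# Full-precision weight in the standard nn.Linear layout, on CPU, as load() prepares it.
weight = torch.randn(out_features, in_features, dtype=torch.float32)

# One-time quantization on the CPU, mirroring KLinearIPEXLLM.load().
params_low_bit, qtype = quantize_linear(weight, in_features, "sym_int4")
params_low_bit = params_low_bit.to("xpu")

# Runtime path, mirroring KLinearIPEXLLM.forward(): inputs are cast to fp16
# before the low-bit kernel and restored to their original dtype afterwards.
x = torch.randn(1, 16, in_features, dtype=torch.bfloat16, device="xpu")
y = linear_forward(x.half(), params_low_bit, qtype, out_features)
y = y.to(dtype=x.dtype)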
@@ -785,6 +856,7 @@ LINEAR_MAP = {
     "VLinearMarlin": VLinearMarlin,
     "KLinearFP8": KLinearFP8,
     "KLinearQ8": KLinearQ8,
+    "KLinearIPEXLLM": KLinearIPEXLLM,
 }
 
 class KTransformersLinear(BaseInjectedModule, KLinearBase):
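Registering the class in LINEAR_MAP makes the new backend selectable by name, since KTransformersLinear resolves the operator names it is configured with through this string-to-class registry. A hedged sketch of that dispatch pattern; the registry entries and the build_linear helper below are stand-ins for illustration, not the actual KTransformersLinear code:

from typing import Type

class FakeLinearBase:
    def __init__(self, key: str, device: str = "cpu"):
        self.key, self.device = key, device

class FakeLinearTorch(FakeLinearBase): ...
class FakeLinearIPEXLLM(FakeLinearBase): ...

# Stand-in registry mirroring the LINEAR_MAP idea: operator names from the
# injection config are looked up and instantiated on demand.
FAKE_LINEAR_MAP: dict[str, Type[FakeLinearBase]] = {
    "KLinearTorch": FakeLinearTorch,
    "KLinearIPEXLLM": FakeLinearIPEXLLM,
}

def build_linear(op_name: str, key: str, device: str) -> FakeLinearBase:
    # Hypothetical helper: resolve the operator class by its registry name.
    try:
        cls = FAKE_LINEAR_MAP[op_name]
    except KeyError:
        raise ValueError(f"unknown linear operator: {op_name}") from None
    return cls(key, device=device)

op = build_linear("KLinearIPEXLLM", key="blk.0.attn_q", device="xpu")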