Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00
[feature] experts can be injected using CPUInfer
[fix] fix the ktransformers interface when using the new CUDAGraphRunner
[fix] fix YAML and optimize-rule logic; the top rule has the highest priority
parent 80815dbc50 · commit 412055d450
13 changed files with 318 additions and 158 deletions
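The optimize-rule fix makes ordering significant: when several injection rules match the same module, the one listed first wins. A small, hypothetical sketch of that first-match-wins selection (the rule format and names here are illustrative, not the project's actual YAML schema or matcher):

import re

# Illustrative rules only: earlier entries take priority, mirroring
# "the top rule has the highest priority".
rules = [
    {"match": r"\.mlp\..*_proj$", "replace": "QuantizedLinearCPUInfer"},  # specific rule first
    {"match": r".*",              "replace": "QuantizedLinearTorch"},     # catch-all last
]

def pick_backend(module_name: str) -> str:
    for rule in rules:                          # scan top-down
        if re.search(rule["match"], module_name):
            return rule["replace"]              # first match wins
    return "QuantizedLinearTorch"

print(pick_backend("model.layers.0.mlp.gate_proj"))  # -> QuantizedLinearCPUInfer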
@@ -11,8 +11,9 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''

import ctypes
import torch
from torch import nn
from torch import Tensor, nn
import KTransformersOps
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
@@ -25,7 +26,13 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod

import sys, os
# Make the compiled cpuinfer_ext extension importable from the local build trees.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config

#class QuantizedLinearBase(BaseInjectedModule, ABC):
class QuantizedLinearBase(ABC):
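The added imports pull in the compiled cpuinfer_ext extension (searched for in the build trees appended above) plus a shared CPUInfer worker pool sized from the server config. The pattern the new CPUInfer-backed linear relies on is submit-then-sync: wrap a cpuinfer_ext kernel call into a task, enqueue it, and wait before reading the output buffer. A minimal sketch of that pattern, assuming a cpuinfer_ext.linear.Linear kernel `lin` has already been built as in QuantizedLinearCPUInfer.load() further down, and that x and out are contiguous CPU tensors:

qlen = x.shape[1]                                   # number of tokens in this call
cpu_infer = CPUInfer(Config().cpu_infer)            # shared CPU worker pool
cpu_infer.submit(lin.forward(qlen, x.data_ptr(), out.data_ptr()))
cpu_infer.sync()                                    # block until `out` has been written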
@@ -118,6 +125,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        out_device = x.device
        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
        x = x.to(device=self.device, dtype=self.dtype)
        x = x @ self.w
        if self.has_bias:
@@ -128,7 +136,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None: device = self.device
        if w is None: w = self.load_weight(device=device)

        if isinstance(w, nn.Parameter):
            self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
            self.has_bias = False
@@ -243,10 +251,113 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
        self.g_idx = None
        self.sort_indices = None
        self.workspace = None

class QuantizedLinearCPUInfer(QuantizedLinearBase):
    CPU_INFER = CPUInfer(Config().cpu_infer)

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cpu",
        out_device: str = "cuda",  # which device the output should end up on. TODO: support cpu.
        stride = 16,
        group_max_len = 1024,
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.dtype = torch.get_default_dtype()
        self.w = None
        self.has_bias = False
        self.stride = stride
        self.group_max_len = group_max_len
        self.out_device = out_device
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        origin_shape = x.shape  # [batch_size, q_len, hidden_size]
        if origin_shape[1] == 1:
            # Decode path (q_len == 1): stage through the pre-allocated pinned CPU buffers
            # and submit/sync against the current CUDA stream.
            out_device = x.device
            self.input_tensor_cpu.copy_(x, non_blocking=True)
            qlen = origin_shape[1]
            QuantizedLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(
                torch.cuda.current_stream().cuda_stream,
                self.linear.forward(
                    qlen,
                    self.input_tensor_cpu.data_ptr(),
                    self.output_cpu.data_ptr()
                )
            )
            QuantizedLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
            self.output_gpu.copy_(self.output_cpu, non_blocking=True)
            if self.has_bias:
                self.output_gpu += self.bias
            return self.output_gpu
        else:
            # Prefill path: run the CPU kernel synchronously into a freshly allocated output.
            dtype = x.dtype
            out_device = x.device
            x = x.to(device=self.device)
            qlen = origin_shape[1]
            output_shape = (*origin_shape[:-1], self.out_features)
            output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
            QuantizedLinearCPUInfer.CPU_INFER.submit(
                self.linear.forward(
                    qlen,
                    x.data_ptr(),
                    output.data_ptr()
                )
            )
            QuantizedLinearCPUInfer.CPU_INFER.sync()
            if self.has_bias:
                output = output + self.bias
            output = output.to(dtype=dtype, device=out_device)
            return output
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = True):
        print(f"loading {self.key} to {self.device} using CPUInfer")
        if device is None: device = self.device
        self.load_weights(w=w, device=device)
        if self.bias is not None:
            self.has_bias = True
            self.bias = self.bias.to(device)

        weight_ptr = ctypes.addressof(
            ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30)
        self.linear = cpuinfer_ext.linear.Linear(config)

        if warmup:
            QuantizedLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())
            QuantizedLinearCPUInfer.CPU_INFER.sync()
        # Pre-allocated staging buffers for the q_len == 1 decode path in forward().
        self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True)
        self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16)
        self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)
    def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"):
        if self.key + ".weight" in self.gguf_loader.tensor_info:
            if self.key + ".bias" in self.gguf_loader.tensor_file_map:
                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
                self.bias = self.gguf_loader.load_gguf_tensor(self.key + ".bias", device=device)
            else:
                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
                self.bias = None
        else:
            raise ValueError(f"Linear {self.key} not found in gguf_loader")
    def unload(self):
        if self.w is not None:
            self.w = None
        if self.has_bias:
            self.bias = None

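Taken together, load() mmaps the GGUF weight, builds the cpuinfer_ext kernel, and pre-allocates the decode staging buffers, after which forward() behaves like a regular linear layer. A hedged construction sketch, assuming a GGUFLoader, a model config, and a valid GGUF tensor prefix are already at hand (the key shown is hypothetical, and in_features/out_features are assumed to be filled in by the base class):

import torch

# Hypothetical wiring; gguf_loader, model_config and the key are assumptions.
lin = QuantizedLinearCPUInfer(
    key="blk.0.ffn_gate",        # hypothetical GGUF prefix of one linear weight
    gguf_loader=gguf_loader,
    config=model_config,
    device="cpu",                # weights stay on the CPU side
    out_device="cuda",           # decode outputs are copied to this device
)
lin.load()                       # mmap weight, build the cpuinfer_ext kernel, warm up
x = torch.randn(1, 1, lin.in_features, device="cuda")
y = lin.forward(x)               # q_len == 1 takes the pinned-buffer decode path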
LINEAR_MAP = {
    "QuantizedLinearMarlin": QuantizedLinearMarlin,
    "QuantizedLinearTorch": QuantizedLinearTorch,
    "QuantizedLinearCPUInfer": QuantizedLinearCPUInfer
}

class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
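LINEAR_MAP is a name-keyed registry, so a backend can be selected by the string that appears in an injection rule; the KTransformerLinear wrapper that does the actual dispatch is cut off in this excerpt. A minimal, hypothetical lookup sketch (names other than LINEAR_MAP and the backend classes are assumptions):

# Hypothetical dispatch by registry name; not the real KTransformerLinear logic.
backend_name = "QuantizedLinearCPUInfer"        # e.g. taken from an optimize rule
backend_cls = LINEAR_MAP[backend_name]
backend = backend_cls(key, gguf_loader, config, orig_module, device="cpu")
backend.load()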