Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-05 20:19:51 +00:00)
[feature] experts can be injected using CPUInfer
[fix] fix the ktransformers interface when using the new CUDAGraphRunner
[fix] fix YAML and optimize logic: the top rule has the highest priority

parent 80815dbc50
commit 412055d450
13 changed files with 318 additions and 158 deletions
ktransformers/operators/cpuinfer.py (new file, 18 additions)
@@ -0,0 +1,18 @@
+import sys, os
+from typing import Any
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
+import cpuinfer_ext
+from ktransformers.server.config.config import Config
+class CPUInfer:
+    cpu_infer = None
+    def __init__(self, cpu_infer:int = Config().cpu_infer):
+        if CPUInfer.cpu_infer is None:
+            CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)
+
+    def __getattribute__(self, __name: str) -> Any:
+        return CPUInfer.cpu_infer.__getattribute__(__name)
+
+    def __setattr__(self, __name: str, __value: Any) -> None:
+        return CPUInfer.cpu_infer.__setattr__(__name, __value)
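For context, the wrapper above is a process-wide singleton: the first construction creates one cpuinfer_ext.CPUInfer backend (sized by Config().cpu_infer), later constructions reuse it, and attribute access is forwarded to that shared backend. Below is a small self-contained sketch of the same pattern, using a stand-in backend class instead of the compiled cpuinfer_ext extension; it is illustrative only, not part of this commit.

from typing import Any

class _Backend:
    # Stand-in for cpuinfer_ext.CPUInfer, used only for this illustration.
    def __init__(self, threads: int) -> None:
        self.threads = threads
    def submit(self, task: Any) -> None: ...
    def sync(self) -> None: ...

class SharedBackend:
    # Same shape as the new CPUInfer wrapper: one shared backend per process,
    # with attribute reads forwarded to it.
    backend = None
    def __init__(self, threads: int = 4) -> None:
        if SharedBackend.backend is None:
            SharedBackend.backend = _Backend(threads)
    def __getattribute__(self, name: str) -> Any:
        return SharedBackend.backend.__getattribute__(name)

a = SharedBackend(48)   # first construction creates the backend with 48 threads
b = SharedBackend(4)    # later constructions reuse it; the argument is ignored
assert a.threads == b.threads == 48

This is why MLPCPUExperts and QuantizedLinearCPUInfer further down in this diff can each declare CPU_INFER = CPUInfer(Config().cpu_infer) without spawning two separate thread pools.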
@@ -33,6 +33,7 @@ from transformers.configuration_utils import PretrainedConfig
 from abc import ABC, abstractmethod
 from ktransformers.operators.linear import QuantizedLinearMarlin, QuantizedLinearTorch, KTransformerLinear
 import time
+from ktransformers.operators.cpuinfer import CPUInfer


 # class Base(BaseInjectedModule, ABC):
@@ -117,7 +118,7 @@ class MLPCPUExperts(MLPExpertsBase):
     output_cpu:Tensor = None
     output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
     #stream_map:dict = {} # Manage cuda stream on different gpu
-    CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer)
+    CPU_INFER = CPUInfer(Config().cpu_infer)
     def __init__(
         self,
         key: str,
@@ -126,7 +127,7 @@ class MLPCPUExperts(MLPExpertsBase):
         n_routed_experts: int,
         orig_module: nn.Module = None,
         device: str = "cpu",
-        out_device: str = "cuda", # this device mean which device the output should on
+        out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu.
         **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
@@ -135,51 +136,50 @@ class MLPCPUExperts(MLPExpertsBase):
         self.out_device = out_device

     def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
-        with torch.device(self.out_device):
         if device:
             assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
         if w is None: w = self.load_weights()[self.key]
         self.gate = w["gate"]
         self.up = w["up"]
         self.down = w["down"]
         self.gate_type = w["gate_type"]
         self.up_type = w["up_type"]
         self.down_type = w["down_type"]
         gate_ptr = ctypes.addressof(
             ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
         up_ptr = ctypes.addressof(
             ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
         down_ptr = ctypes.addressof(
             ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
         )
         # print(self.gate_qtype, self.up_qtype, self.down_qtype)
         n_routed_experts = self.n_routed_experts
         # n_routed_experts = len(self.orig_module)
         moe_config = MOEConfig(
             n_routed_experts,
             self.config.num_experts_per_tok,
             self.config.hidden_size,
             self.config.moe_intermediate_size,
             64,
             10,
             1024,
             gate_ptr,
             up_ptr,
             down_ptr,
             self.gate_type,
             self.up_type,
             self.down_type,
             30, # TODO: get from model.dtype
         )
         # print(n_routed_experts, hidden_size, moe_intermediate_size)
         num_experts_per_tok = self.config.num_experts_per_tok
         self.moe = MOE(moe_config)
         self.cpu_infer = MLPCPUExperts.CPU_INFER
         if warmup:
             self.cpu_infer.submit(self.moe.warm_up())
             self.cpu_infer.sync()

         if self.out_device not in MLPCPUExperts.output_gpu_map:
             MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
         if MLPCPUExperts.input_tensor_cpu == None:
@@ -11,8 +11,9 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''


+import ctypes
 import torch
-from torch import nn
+from torch import Tensor, nn
 import KTransformersOps
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.util.utils import InferenceState
@@ -25,7 +26,13 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
 from ktransformers.operators.base_operator import BaseInjectedModule
 from transformers.configuration_utils import PretrainedConfig
 from abc import ABC, abstractmethod
+import sys, os
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
+import cpuinfer_ext
+from ktransformers.operators.cpuinfer import CPUInfer
+from ktransformers.server.config.config import Config

 #class QuantizedLinearBase(BaseInjectedModule, ABC):
 class QuantizedLinearBase(ABC):
@@ -118,6 +125,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         dtype = x.dtype
         out_device = x.device
+        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
         x = x.to(device=self.device, dtype=self.dtype)
         x = x @ self.w
         if self.has_bias:
@@ -128,7 +136,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
         if w is None: w = self.load_weight(device=device)

         if isinstance(w, nn.Parameter):
             self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
             self.has_bias = False
@@ -243,10 +251,113 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
         self.g_idx = None
         self.sort_indices = None
         self.workspace = None

+class QuantizedLinearCPUInfer(QuantizedLinearBase):
+    CPU_INFER = CPUInfer(Config().cpu_infer)
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module = None,
+        device: str = "cpu",
+        out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu.
+        stride = 16,
+        group_max_len = 1024,
+        **kwargs,
+    ):
+        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        self.has_bias = False
+        self.dtype = torch.get_default_dtype()
+        self.w = None
+        self.has_bias = False
+        self.stride = stride
+        self.group_max_len = group_max_len
+        self.out_device = out_device
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        origin_shape = x.shape # [batch_size, q_len, hidden_size]
+        if origin_shape[1] == 1:
+            out_device = x.device
+            self.input_tensor_cpu.copy_(x, non_blocking=True)
+            qlen = origin_shape[1]
+            QuantizedLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(
+                torch.cuda.current_stream().cuda_stream,
+                self.linear.forward(
+                    qlen,
+                    self.input_tensor_cpu.data_ptr(),
+                    self.output_cpu.data_ptr()
+                )
+            )
+            QuantizedLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
+            self.output_gpu.copy_(self.output_cpu, non_blocking=True)
+            if self.has_bias:
+                self.output_gpu += self.bias
+            return self.output_gpu
+        else:
+            dtype = x.dtype
+            out_device = x.device
+            x = x.to(device=self.device)
+            qlen = origin_shape[1]
+            output_shape = (*origin_shape[:-1], self.out_features)
+            output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
+            QuantizedLinearCPUInfer.CPU_INFER.submit(
+                self.linear.forward(
+                    qlen,
+                    x.data_ptr(),
+                    output.data_ptr()
+                )
+            )
+            QuantizedLinearCPUInfer.CPU_INFER.sync()
+            if self.has_bias:
+                output = output + self.bias
+            output = output.to(dtype=dtype, device=out_device)
+            return output
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True):
+        print(f"loading {self.key} to {self.device} using CPUInfer")
+        if device is None: device = self.device
+        self.load_weights(w=w, device=device)
+        if self.bias is not None:
+            self.has_bias = True
+            self.bias = self.bias.to(device)
+
+        weight_ptr = ctypes.addressof(
+            ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
+        )
+        config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30)
+        self.linear = cpuinfer_ext.linear.Linear(config)
+
+        if warmup:
+            QuantizedLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())
+            QuantizedLinearCPUInfer.CPU_INFER.sync()
+        self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True)
+        self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+        self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)
+
+    def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"):
+        if self.key + ".weight" in self.gguf_loader.tensor_info:
+            if self.key + ".bias" in self.gguf_loader.tensor_file_map:
+                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
+                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
+                self.bias = self.gguf_loader.load_gguf_tensor(self.key + ".bias", device=device)
+            else:
+                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
+                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
+                self.bias = None
+        else:
+            raise ValueError(f"Linear {self.key} not found in gguf_loader")
+
+    def unload(self):
+        if self.w is not None:
+            self.w = None
+        if self.has_bias:
+            self.bias = None
+
 LINEAR_MAP = {
     "QuantizedLinearMarlin": QuantizedLinearMarlin,
     "QuantizedLinearTorch": QuantizedLinearTorch,
+    "QuantizedLinearCPUInfer": QuantizedLinearCPUInfer
 }

 class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
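In the single-token decode branch above, QuantizedLinearCPUInfer stages activations in pinned host buffers, runs the matmul on the CPU, and copies the result back while coordinating with the current CUDA stream. The snippet below is only a generic PyTorch sketch of that buffer-and-copy pattern; it substitutes a plain CPU matmul and a full stream synchronize for the cpuinfer_ext submit_with_cuda_stream/sync_with_cuda_stream calls, so it is not the code this commit adds.

# Sketch only: pinned host staging buffers plus non_blocking copies, the same
# transfer pattern the new decode path relies on (cpuinfer_ext replaced here by
# torch.nn.functional.linear so the example stays self-contained).
import torch

in_features, out_features = 1024, 4096
weight = torch.randn(out_features, in_features)            # CPU weight stand-in

input_cpu = torch.zeros(1, 1, in_features, pin_memory=True)
output_cpu = torch.zeros(1, 1, out_features, pin_memory=True)

def cpu_linear_decode(x_gpu: torch.Tensor) -> torch.Tensor:
    # stage the activation on the host without blocking the GPU stream
    input_cpu.copy_(x_gpu, non_blocking=True)
    torch.cuda.current_stream().synchronize()               # make the copy visible to the CPU
    output_cpu.copy_(torch.nn.functional.linear(input_cpu, weight))
    return output_cpu.to(x_gpu.device, non_blocking=True)   # back onto the GPU

if torch.cuda.is_available():
    y = cpu_linear_decode(torch.randn(1, 1, in_features, device="cuda"))
    print(y.shape)  # torch.Size([1, 1, 4096])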
@@ -58,7 +58,6 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
     #print("gen_optimize_config", prefix, module_name, translated_name)
     recursive = True
     for rule in rule_list:
-        #print(rule)
         match_meta = rule["match"]
         if "class" not in match_meta and "name" not in match_meta:
             raise Exception("match must have at least one of \"class\" and \"name\"")
@@ -87,6 +86,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
         out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict())
         if "recursive" in rule:
             recursive = bool(rule["recursive"])
+        break

     if module_name not in out_data:
         out_data[module_name]= {
@@ -127,5 +127,6 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
     with torch.device("meta"):
         inject(module, optimize_config, model_config, gguf_loader)
     load_weights(module, gguf_loader)
-    model_config.gguf_loader = gguf_loader
+    module.gguf_loader = gguf_loader
     del_meta(module)
+    torch.cuda.empty_cache()
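The break added in gen_optimize_config above is what enforces the commit message's rule that the top rule has the highest priority: once a rule matches a module, later rules are never consulted, which is why the catch-all "default" device rules in the YAML files below move from the top of each file to the bottom. A minimal standalone sketch of that first-match-wins lookup follows (hypothetical rule data, not the real gen_optimize_config).

# Illustration of first-match-wins rule resolution: the earliest matching rule
# decides the replacement, so catch-all "default" rules must be listed last.
import re

rules = [
    {"match": {"name": r"^model\.layers\..*\.mlp\.experts$"}, "replace": "MLPCPUExperts"},
    {"match": {"name": r"^model\.layers\..*"}, "replace": "default"},   # catch-all, listed last
]

def resolve(module_name: str) -> str:
    for rule in rules:
        if re.search(rule["match"]["name"], module_name):
            return rule["replace"]   # stop at the first match, as the new break does
    return "default"

print(resolve("model.layers.3.mlp.experts"))  # MLPCPUExperts
print(resolve("model.layers.3.self_attn"))    # default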
@@ -1,32 +1,3 @@
-- match:
-    name: "^model\\.layers\\.([0-9])\\."
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:0"
-      prefill_device: "cuda:0"
-- match:
-    name: "(^model\\.layers\\.([1][0-9])\\.)"
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:1"
-      prefill_device: "cuda:1"
-- match:
-    name: "(^model\\.layers\\.([2][0-9])\\.)"
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:2"
-      prefill_device: "cuda:2"
-- match:
-    name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:3"
-      prefill_device: "cuda:3"
-
 - match:
     name: "^model.embed_tokens"
   replace:
@@ -69,7 +40,7 @@
       prefill_device: "cuda:3"

 - match:
-    name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression
+    name: "^model\\.layers\\.([0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
     class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
@@ -225,4 +196,33 @@
     transfer_map:
       10: "cuda:1"
       20: "cuda:2"
       30: "cuda:3"
+
+- match:
+    name: "^model\\.layers\\.([0-9])\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "(^model\\.layers\\.([1][0-9])\\.)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+- match:
+    name: "(^model\\.layers\\.([2][0-9])\\.)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:2"
+      prefill_device: "cuda:2"
+- match:
+    name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:3"
+      prefill_device: "cuda:3"
@@ -1,19 +1,3 @@
-- match:
-    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:0"
-      prefill_device: "cuda:0"
-
-- match:
-    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:1"
-      prefill_device: "cuda:1"
-
 - match:
     name: "^model.embed_tokens"
   replace:
@@ -123,4 +107,20 @@
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
     transfer_map:
       30: "cuda:1"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+
+- match:
+    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
@@ -1,14 +1,21 @@
-- match:
-    name: "^model\\.layers\\..*\\.|^lm_head"
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda"
-      prefill_device: "cuda"
 - match:
     class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+#- match:
+#    name: "^model\\.layers\\.([1-5][0-9])\\.mlp\\.shared_experts.*$" # regular expression
+#    class: torch.nn.Linear # only match modules matching name and class simultaneously
+#  replace:
+#    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+#    kwargs:
+#      generate_device: "cpu"
+#      prefill_device: "cuda"
+#      generate_op: "QuantizedLinearCPUInfer"
+#      prefill_op: "QuantizedLinearTorch"
+#      out_device: "cuda"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -24,6 +31,9 @@
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
     class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
@@ -39,16 +49,21 @@
     name: "^model\\.layers\\..*\\.self_attn$"
   replace:
     class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
 - match:
     name: "^model$"
   replace:
     class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
     kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
       per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
 - match:
     name: "^model.embed_tokens"
   replace:
     class: "default"
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
@@ -1,19 +1,3 @@
-- match:
-    name: "^model\\.layers\\.(0|[1-9])\\."
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:0"
-      prefill_device: "cuda:0"
-
-- match:
-    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:1"
-      prefill_device: "cuda:1"
-
 - match:
     name: "^model.embed_tokens"
   replace:
@@ -123,4 +107,20 @@
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
     transfer_map:
       10: "cuda:1"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9])\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+
+- match:
+    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
@@ -1,14 +1,10 @@
-- match:
-    name: "^model\\.layers\\..*\\."
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda"
-      prefill_device: "cuda"
 - match:
     class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.RotaryEmbedding
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*$"
     class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -43,3 +39,11 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
+
+- match:
+    name: "^model\\.layers\\..*\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
@@ -1,10 +1,3 @@
-- match:
-    name: "^model\\.layers\\.([012])\\."
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:0"
-      prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([012])\\."
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
@@ -41,13 +34,6 @@
       out_device: "cuda:0"
       recursive: False # don't recursively inject submodules of this module
-
-- match:
-    name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda:1"
-      prefill_device: "cuda:1"
 - match:
     name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
@@ -109,3 +95,18 @@
     transfer_map:
       3: "cuda:1"
+
+- match:
+    name: "^model\\.layers\\.([012])\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+
+- match:
+    name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
@@ -1,14 +1,10 @@
-- match:
-    name: "^model\\.layers\\..*\\."
-  replace:
-    class: "default"
-    kwargs:
-      generate_device: "cuda"
-      prefill_device: "cuda"
 - match:
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
   replace:
     class: ktransformers.operators.RoPE.RotaryEmbedding
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -24,6 +20,9 @@
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
   replace:
     class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
|
||||||
class: "default"
|
class: "default"
|
||||||
kwargs:
|
kwargs:
|
||||||
generate_device: "cpu"
|
generate_device: "cpu"
|
||||||
prefill_device: "cpu"
|
prefill_device: "cpu"
|
||||||
|
- match:
|
||||||
|
name: "^model\\.layers\\..*\\."
|
||||||
|
replace:
|
||||||
|
class: "default"
|
||||||
|
kwargs:
|
||||||
|
generate_device: "cuda"
|
||||||
|
prefill_device: "cuda"
|
|
@@ -6,6 +6,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
+from ktransformers.util.utils import get_device


 class KTransformersThreadContext(TransformersThreadContext):
@@ -48,8 +49,11 @@ class KTransformersInterface(TransformersInterface):

     def decode_one_tokens(self):
         if not hasattr(self, "cuda_graph_runner"):
+            device_map = self.model.gguf_loader.tensor_device_map
+            torch_device = get_device('blk.0.self_attn', device_map)
+            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
             self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, return_dict=False, use_cache=True)
+            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)

         if hasattr(self, "cuda_graph_runner"):
             logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
@@ -89,7 +89,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
     batch_size, seq_length = inputs.shape
-    device_map = model.config.gguf_loader.tensor_device_map
+    device_map = model.gguf_loader.tensor_device_map
     torch_device = get_device('blk.0.self_attn', device_map)
     torch_device = "cuda:0" if torch_device == "cuda" else torch_device
     inputs = inputs.to(torch_device)