[feature] experts can be injected using CPUInfer

[fix] fix the ktransformers interface when using the new CUDAGraphRunner
[fix] fix YAML rules and optimize logic: the top matching rule now has the highest priority (see the sketch below)
This commit is contained in:
Atream 2024-08-14 16:10:54 +08:00
parent 80815dbc50
commit 412055d450
13 changed files with 318 additions and 158 deletions
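
For the rule-priority fix, gen_optimize_config now stops at the first matching rule (a break is added once a rule matches, see the optimize.py hunk below), so rules placed at the top of a YAML template take precedence; the per-device "default" rules are accordingly moved from the top of each template to the bottom. A minimal first-match-wins sketch with simplified stand-in rules (illustrative only, not the actual ktransformers data structures):

import re

# Stand-in rules in the same shape as the YAML templates (match/replace);
# the real matcher also supports "class" and "recursive", omitted here.
rules = [
    {"match": {"name": r"^model\.layers\..*\.mlp\.experts$"},
     "replace": {"class": "ktransformers.operators.experts.MLPCPUExperts"}},
    {"match": {"name": r"^model\.layers\..*$"},
     "replace": {"class": "default"}},
]

def resolve(module_name: str) -> str:
    for rule in rules:                                   # scanned top to bottom
        if re.search(rule["match"]["name"], module_name):
            return rule["replace"]["class"]              # first match wins (the added "break")
    return "default"                                     # nothing matched

print(resolve("model.layers.3.mlp.experts"))  # ktransformers.operators.experts.MLPCPUExperts
print(resolve("model.layers.3.self_attn"))    # default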

View file

@@ -0,0 +1,18 @@
import sys, os
from typing import Any
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config

class CPUInfer:
    cpu_infer = None
    def __init__(self, cpu_infer:int = Config().cpu_infer):
        if CPUInfer.cpu_infer is None:
            CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)
    def __getattribute__(self, __name: str) -> Any:
        return CPUInfer.cpu_infer.__getattribute__(__name)
    def __setattr__(self, __name: str, __value: Any) -> None:
        return CPUInfer.cpu_infer.__setattr__(__name, __value)
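
The wrapper above makes the CPU inference backend a process-wide singleton: the first construction creates the shared cpuinfer_ext.CPUInfer with the configured thread count, and every attribute access on any later instance delegates to that shared object, so all injected operators submit work to one CPU thread pool. A minimal usage sketch (it assumes the ktransformers_ext extension has been built; the thread count 48 is only an example value):

# Sketch only: requires the compiled cpuinfer_ext extension to be importable.
from ktransformers.operators.cpuinfer import CPUInfer

cpu_infer = CPUInfer(48)   # first construction creates the shared backend with 48 threads
CPUInfer()                 # later instances reuse the same backend and ignore the argument

# Operators queue tasks created by extension objects and then wait for them,
# e.g. (as MLPCPUExperts and QuantizedLinearCPUInfer do in the hunks below):
#   cpu_infer.submit(moe.warm_up())
#   cpu_infer.sync()

MLPCPUExperts and QuantizedLinearCPUInfer below each keep a class-level CPU_INFER = CPUInfer(Config().cpu_infer), which therefore refers to the same shared backend.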

View file

@@ -33,6 +33,7 @@ from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
from ktransformers.operators.linear import QuantizedLinearMarlin, QuantizedLinearTorch, KTransformerLinear
import time
from ktransformers.operators.cpuinfer import CPUInfer
# class Base(BaseInjectedModule, ABC):
@@ -117,7 +118,7 @@ class MLPCPUExperts(MLPExpertsBase):
    output_cpu:Tensor = None
    output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
    #stream_map:dict = {} # Manage cuda stream on different gpu
-    CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer)
+    CPU_INFER = CPUInfer(Config().cpu_infer)
    def __init__(
        self,
        key: str,
@@ -126,7 +127,7 @@ class MLPCPUExperts(MLPExpertsBase):
        n_routed_experts: int,
        orig_module: nn.Module = None,
        device: str = "cpu",
out_device: str = "cuda", # this device mean which device the output should on out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu.
        **kwargs
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
@@ -135,51 +136,50 @@ class MLPCPUExperts(MLPExpertsBase):
        self.out_device = out_device
    def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
-        with torch.device(self.out_device):  # removed: the load body below is no longer wrapped in this device context (and is dedented one level)
        if device:
            assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
        if w is None: w = self.load_weights()[self.key]
        self.gate = w["gate"]
        self.up = w["up"]
        self.down = w["down"]
        self.gate_type = w["gate_type"]
        self.up_type = w["up_type"]
        self.down_type = w["down_type"]
        gate_ptr = ctypes.addressof(
            ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        up_ptr = ctypes.addressof(
            ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        down_ptr = ctypes.addressof(
            ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        # print(self.gate_qtype, self.up_qtype, self.down_qtype)
        n_routed_experts = self.n_routed_experts
        # n_routed_experts = len(self.orig_module)
        moe_config = MOEConfig(
            n_routed_experts,
            self.config.num_experts_per_tok,
            self.config.hidden_size,
            self.config.moe_intermediate_size,
            64,
            10,
            1024,
            gate_ptr,
            up_ptr,
            down_ptr,
            self.gate_type,
            self.up_type,
            self.down_type,
            30, # TODO: get from model.dtype
        )
        # print(n_routed_experts, hidden_size, moe_intermediate_size)
        num_experts_per_tok = self.config.num_experts_per_tok
        self.moe = MOE(moe_config)
        self.cpu_infer = MLPCPUExperts.CPU_INFER
        if warmup:
            self.cpu_infer.submit(self.moe.warm_up())
            self.cpu_infer.sync()
        if self.out_device not in MLPCPUExperts.output_gpu_map:
            MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
        if MLPCPUExperts.input_tensor_cpu == None:
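
The gate/up/down pointers above are raw addresses into the memory-mapped GGUF tensors, produced by the ctypes cast/addressof idiom; the same pattern is reused for the weight pointer in QuantizedLinearCPUInfer.load below. A self-contained illustration of that idiom on an ordinary NumPy array (illustrative only, not part of the commit):

import ctypes
import numpy as np

# The cast/addressof pattern yields the buffer's base address, i.e. the same
# value as arr.ctypes.data; that integer is what gets handed to the C++ extension.
arr = np.arange(8, dtype=np.uint64)
ptr = ctypes.addressof(
    ctypes.cast(arr.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
assert ptr == arr.ctypes.data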

View file

@@ -11,8 +11,9 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import ctypes
import torch
-from torch import nn
+from torch import Tensor, nn
import KTransformersOps
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
@@ -25,7 +26,13 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
#class QuantizedLinearBase(BaseInjectedModule, ABC):
class QuantizedLinearBase(ABC):
@@ -118,6 +125,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        out_device = x.device
        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
        x = x.to(device=self.device, dtype=self.dtype)
        x = x @ self.w
        if self.has_bias:
@@ -244,9 +252,112 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
        self.sort_indices = None
        self.workspace = None
class QuantizedLinearCPUInfer(QuantizedLinearBase):
    CPU_INFER = CPUInfer(Config().cpu_infer)
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cpu",
        out_device: str = "cuda", # which device the output should be on. TODO: support cpu.
        stride = 16,
        group_max_len = 1024,
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.dtype = torch.get_default_dtype()
        self.w = None
        self.has_bias = False
        self.stride = stride
        self.group_max_len = group_max_len
        self.out_device = out_device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        origin_shape = x.shape # [batch_size, q_len, hidden_size]
        if origin_shape[1] == 1:
            # decode path: single token, reuse the pinned host buffers and overlap
            # the CPU task with the current CUDA stream
            out_device = x.device
            self.input_tensor_cpu.copy_(x, non_blocking=True)
            qlen = origin_shape[1]
            QuantizedLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(
                torch.cuda.current_stream().cuda_stream,
                self.linear.forward(
                    qlen,
                    self.input_tensor_cpu.data_ptr(),
                    self.output_cpu.data_ptr()
                )
            )
            QuantizedLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
            self.output_gpu.copy_(self.output_cpu, non_blocking=True)
            if self.has_bias:
                self.output_gpu += self.bias
            return self.output_gpu
        else:
            # prefill path: arbitrary qlen, allocate a fresh output and wait synchronously
            dtype = x.dtype
            out_device = x.device
            x = x.to(device=self.device)
            qlen = origin_shape[1]
            output_shape = (*origin_shape[:-1], self.out_features)
            output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
            QuantizedLinearCPUInfer.CPU_INFER.submit(
                self.linear.forward(
                    qlen,
                    x.data_ptr(),
                    output.data_ptr()
                )
            )
            QuantizedLinearCPUInfer.CPU_INFER.sync()
            if self.has_bias:
                output = output + self.bias
            output = output.to(dtype=dtype, device=out_device)
            return output

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True):
        print(f"loading {self.key} to {self.device} using CPUInfer")
        if device is None: device = self.device
        self.load_weights(w=w, device=device)
        if self.bias is not None:
            self.has_bias = True
            self.bias = self.bias.to(device)
        weight_ptr = ctypes.addressof(
            ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30)
        self.linear = cpuinfer_ext.linear.Linear(config)
        if warmup:
            QuantizedLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())
            QuantizedLinearCPUInfer.CPU_INFER.sync()
        self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True)
        self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16)
        self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)

    def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"):
        if self.key + ".weight" in self.gguf_loader.tensor_info:
            if self.key + ".bias" in self.gguf_loader.tensor_file_map:
                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
                self.bias = self.gguf_loader.load_gguf_tensor(self.key + ".bias", device=device)
            else:
                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
                self.bias = None
        else:
            raise ValueError(f"Linear {self.key} not found in gguf_loader")

    def unload(self):
        if self.w is not None:
            self.w = None
        if self.has_bias:
            self.bias = None
LINEAR_MAP = {
    "QuantizedLinearMarlin": QuantizedLinearMarlin,
    "QuantizedLinearTorch": QuantizedLinearTorch,
    "QuantizedLinearCPUInfer": QuantizedLinearCPUInfer
}
class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):

View file

@@ -58,7 +58,6 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
    #print("gen_optimize_config", prefix, module_name, translated_name)
    recursive = True
    for rule in rule_list:
        #print(rule)
        match_meta = rule["match"]
        if "class" not in match_meta and "name" not in match_meta:
            raise Exception("match must have at least one of \"class\" and \"name\"")
@@ -87,6 +86,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
        out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict())
        if "recursive" in rule:
            recursive = bool(rule["recursive"])
        break
    if module_name not in out_data:
        out_data[module_name]= {
@@ -127,5 +127,6 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
    with torch.device("meta"):
        inject(module, optimize_config, model_config, gguf_loader)
    load_weights(module, gguf_loader)
-    model_config.gguf_loader = gguf_loader
+    module.gguf_loader = gguf_loader
    del_meta(module)
    torch.cuda.empty_cache()

View file

@@ -1,32 +1,3 @@
- match:
name: "^model\\.layers\\.([0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([1][0-9])\\.)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "(^model\\.layers\\.([2][0-9])\\.)"
replace:
class: "default"
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- match:
name: "^model.embed_tokens"
replace:
@@ -69,7 +40,7 @@
prefill_device: "cuda:3"
- match:
-name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression
+name: "^model\\.layers\\.([0-9])\\.(?!self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
@@ -226,3 +197,32 @@
10: "cuda:1"
20: "cuda:2"
30: "cuda:3"
- match:
name: "^model\\.layers\\.([0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([1][0-9])\\.)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "(^model\\.layers\\.([2][0-9])\\.)"
replace:
class: "default"
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"

View file

@@ -1,19 +1,3 @@
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model.embed_tokens"
replace:
@@ -124,3 +108,19 @@
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"

View file

@@ -1,14 +1,21 @@
- match:
name: "^model\\.layers\\..*\\.|^lm_head"
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
#- match:
# name: "^model\\.layers\\.([1-5][0-9])\\.mlp\\.shared_experts.*$" # regular expression
# class: torch.nn.Linear # only match modules matching name and class simultaneously
# replace:
# class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
# kwargs:
# generate_device: "cpu"
# prefill_device: "cuda"
# generate_op: "QuantizedLinearCPUInfer"
# prefill_op: "QuantizedLinearTorch"
# out_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -24,6 +31,9 @@
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
replace:
class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
@@ -39,16 +49,21 @@
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"

View file

@@ -1,19 +1,3 @@
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model.embed_tokens"
replace:
@@ -124,3 +108,19 @@
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
transfer_map:
10: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"

View file

@@ -1,14 +1,10 @@
- match:
name: "^model\\.layers\\..*\\."
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*$"
class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -43,3 +39,11 @@
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\..*\\."
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"

View file

@@ -1,10 +1,3 @@
- match:
name: "^model\\.layers\\.([012])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([012])\\."
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
@@ -41,13 +34,6 @@
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
@@ -109,3 +95,18 @@
transfer_map:
3: "cuda:1"
- match:
name: "^model\\.layers\\.([012])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"

View file

@@ -1,14 +1,10 @@
- match:
name: "^model\\.layers\\..*\\."
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbedding
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -24,6 +20,9 @@
class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
replace:
class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
@@ -49,3 +48,10 @@
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\..*\\."
replace:
class: "default"
kwargs:
generate_device: "cuda"
prefill_device: "cuda"

View file

@@ -6,6 +6,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
class KTransformersThreadContext(TransformersThreadContext):
@@ -48,8 +49,11 @@ class KTransformersInterface(TransformersInterface):
    def decode_one_tokens(self):
        if not hasattr(self, "cuda_graph_runner"):
            device_map = self.model.gguf_loader.tensor_device_map
            torch_device = get_device('blk.0.self_attn', device_map)
            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
            self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, return_dict=False, use_cache=True)
+            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)
        if hasattr(self, "cuda_graph_runner"):
            logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)

View file

@@ -89,7 +89,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
    batch_size, seq_length = inputs.shape
-    device_map = model.config.gguf_loader.tensor_device_map
+    device_map = model.gguf_loader.tensor_device_map
    torch_device = get_device('blk.0.self_attn', device_map)
    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
    inputs = inputs.to(torch_device)