[fix] format classes and files name

TangJingqi 2024-08-15 10:44:59 +08:00
parent 1db4a67dca
commit 67043b4b5c
15 changed files with 212 additions and 212 deletions

View file

@@ -276,11 +276,11 @@ Below is an example of a YAML template for replacing all original Linear modules
     name: "^model\\.layers\\..*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     device: "cpu" # which devices to load this module when initializing
     kwargs:
       generate_device: "cuda"
-      generate_linear_type: "QuantizedLinearMarlin"
+      generate_linear_type: "KLinearMarlin"
 ```
 Each rule in the YAML file has two parts: `match` and `replace`. The `match` part specifies which module should be replaced, and the `replace` part specifies the module to be injected into the model along with the initialization keywords.
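To make the match semantics concrete, below is a minimal sketch of how such a rule could be evaluated against a module path and instance. This is not the actual KTransformers implementation; the helpers `resolve_class` and `rule_matches` are hypothetical names, and only the behavior described above (the `name` regex and the `class` must hold simultaneously) is taken from the text.

```python
import importlib
import re

import torch.nn as nn


def resolve_class(dotted_path: str) -> type:
    # Hypothetical helper: turn "torch.nn.Linear" into the class object.
    module_name, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


def rule_matches(module_path: str, module: nn.Module, match: dict) -> bool:
    # A module is replaced only if it satisfies *both* keys that are present.
    if "name" in match and re.match(match["name"], module_path) is None:
        return False
    if "class" in match and not isinstance(module, resolve_class(match["class"])):
        return False
    return True


rule = {"name": r"^model\.layers\..*$", "class": "torch.nn.Linear"}
print(rule_matches("model.layers.0.mlp.gate_proj", nn.Linear(8, 8), rule))  # True
print(rule_matches("lm_head", nn.Linear(8, 8), rule))                       # False
```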

View file

@@ -90,7 +90,7 @@ The YAML rule is listed below.
 - match:
     name: "^model\\.layers\\..*\\.self_attn$" # regular expression
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
 ```
 As we can see, each rule in the YAML file has two parts: `match` and `replace`.
@@ -98,9 +98,9 @@ The match part specifies which module should be replaced, and the replace part s
 <h3 id="experts">Routed Experts</h3>
-For routed experts, the module we inject is a wrapper of CPUInfer, KTransformersMLPExpert. There are several implementations within a wrapper, and we need to specify keywords to tell the wrapper which implementation we want to use and how we intend to use it.
+For routed experts, the module we inject is a wrapper of CPUInfer, KTransformersExperts. There are several implementations within a wrapper, and we need to specify keywords to tell the wrapper which implementation we want to use and how we intend to use it.
-In KTransformers, some models exhibit different behaviors during prefilling and generation for better performance. KTransformersMLPExpert is one of them. All these special modules have a `device` keyword describing which device the module should be initialized on. Other keywords specify the behaviors during prefilling and generation and may differ when using different injection modules. Here, we specify which implementation on which device we want to use during prefilling and generation, and which device the output should be on.
+In KTransformers, some models exhibit different behaviors during prefilling and generation for better performance. KTransformersExperts is one of them. All these special modules have a `device` keyword describing which device the module should be initialized on. Other keywords specify the behaviors during prefilling and generation and may differ when using different injection modules. Here, we specify which implementation on which device we want to use during prefilling and generation, and which device the output should be on.
 Note that we only use these parameters when layer-wise prefilling is enabled; otherwise, prefilling is conducted with the same configuration as generation.
 In the original implementation of Transformers, MoE is implemented using `nn.ModuleList`. We don't want KTransformers to iterate through all the sub-modules in the list, so we set `recursive: False` in this rule to prevent recursive injection into submodules of the current module. Here is the YAML rule:
@@ -109,13 +109,13 @@ In the original implementation of Transformers, MoE is implemented using `nn.Mod
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     device: "cpu" # device to load this module on initialization
     kwargs:
       prefill_device: "cuda"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda"
   recursive: False # don't recursively inject submodules of this module
 ```
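As a mental model of what this wrapper does, the sketch below shows a stripped-down dispatcher that holds one backend per phase and routes calls by inference state. It is an illustration only: `ExpertsDispatcher` and its toy backends are hypothetical, and while `InferenceState` is a real KTransformers name (it appears later in this diff), the `PREFILL`/`GENERATE` members here are assumed. The real `KTransformersExperts` additionally handles weight loading, devices, and CPUInfer scheduling.

```python
from enum import Enum, auto
from typing import Callable, Dict


class InferenceState(Enum):
    PREFILL = auto()
    GENERATE = auto()


class ExpertsDispatcher:
    """Hypothetical sketch: pick a backend ("op") per inference state."""

    def __init__(self, backends: Dict[str, Callable], prefill_op: str, generate_op: str):
        self.prefill_experts = backends[prefill_op]    # e.g. a torch implementation on GPU
        self.generate_experts = backends[generate_op]  # e.g. a CPUInfer implementation

    def forward(self, hidden_states, expert_ids, weights, state: InferenceState):
        experts = (self.prefill_experts if state is InferenceState.PREFILL
                   else self.generate_experts)
        return experts(hidden_states, expert_ids, weights)


# Toy backends standing in for KExpertsTorch / KExpertsCPU.
backends = {
    "KExpertsTorch": lambda h, e, w: f"torch experts over {len(e)} selected experts",
    "KExpertsCPU": lambda h, e, w: f"cpuinfer experts over {len(e)} selected experts",
}
moe = ExpertsDispatcher(backends, prefill_op="KExpertsTorch", generate_op="KExpertsCPU")
print(moe.forward(None, [0, 3], [0.6, 0.4], InferenceState.PREFILL))
print(moe.forward(None, [0, 3], [0.6, 0.4], InferenceState.GENERATE))
```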
@@ -126,7 +126,7 @@ If we inject the expert list as a custom module, we can't use the interface in `
 - match:
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # MLP module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # MLP module with custom forward function
 ```
 <h3 id="linear">Other Linear Modules</h3>
@@ -140,12 +140,12 @@ We also need to transfer some keywords similar to the injection of experts. Here
     name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 ```
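The `generate_op` / `prefill_op` keywords select an entry from a registry of linear backends. A minimal sketch of that selection, including the dimension-divisibility fallback the implementation applies for the Marlin kernel (visible later in this diff), might look like the following. The `pick_linear_op` helper and the registry contents are illustrative, and the value of `GPTQ_MARLIN_MIN_THREAD_N` is an assumption for the example; only the fallback condition itself is taken from the code.

```python
# Illustrative sketch of backend selection with the Marlin divisibility fallback.
GPTQ_MARLIN_MIN_THREAD_N = 64  # assumed value, for illustration only

LINEAR_MAP = {
    "KLinearMarlin": "marlin kernel",
    "KLinearTorch": "torch kernel",
}


def pick_linear_op(op: str, in_features: int, out_features: int) -> str:
    assert op in LINEAR_MAP, f"linear_type {op} not supported"
    if op == "KLinearMarlin" and (in_features % GPTQ_MARLIN_MIN_THREAD_N != 0
                                  or out_features % GPTQ_MARLIN_MIN_THREAD_N != 0):
        # Marlin needs both dimensions divisible by the minimum thread tile;
        # otherwise fall back to the plain torch implementation.
        return "KLinearTorch"
    return op


print(pick_linear_op("KLinearMarlin", 5120, 1536))  # KLinearMarlin
print(pick_linear_op("KLinearMarlin", 5000, 1536))  # KLinearTorch (fallback)
```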
<h3 id="Pre-compute Buffers">Pre-compute Buffers </h3> <h3 id="Pre-compute Buffers">Pre-compute Buffers </h3>

View file

@@ -15,7 +15,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-class DeepseekV2AttentionInjected(BaseInjectedModule, DeepseekV2Attention):
+class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
     def __init__(self,

View file

@@ -5,8 +5,8 @@ Description :
 Author : Azure-Tang, Boxin Zhang, chenht2022
 Date : 2024-07-25 11:25:24
 Version : 0.1.0
-LastEditors : kkk1nak0
-LastEditTime : 2024-08-11 12:14:39
+LastEditors : Azure
+LastEditTime : 2024-08-15 02:36:29
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
@@ -31,13 +31,13 @@ from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from abc import ABC, abstractmethod
-from ktransformers.operators.linear import QuantizedLinearMarlin, QuantizedLinearTorch, KTransformerLinear
+from ktransformers.operators.linear import KLinearMarlin, KLinearTorch, KTransformersLinear
 import time
 from ktransformers.operators.cpuinfer import CPUInfer
 # class Base(BaseInjectedModule, ABC):
-class MLPExpertsBase(ABC):
+class KExpertsBase(ABC):
     def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
         # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
         self.key = key
@@ -111,7 +111,7 @@ class MLPExpertsBase(ABC):
             tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
         return tensors
-class MLPCPUExperts(MLPExpertsBase):
+class KExpertsCPU(KExpertsBase):
     input_tensor_cpu:Tensor = None
     expert_ids_cpu:Tensor = None
     weights_cpu:Tensor = None
@@ -131,13 +131,13 @@ class MLPCPUExperts(MLPExpertsBase):
                  **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
-        assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU"
+        assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device
     def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
         if device:
-            assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
+            assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU, Parameter \"device\" can be cpu or None."
         if w is None: w = self.load_weights()[self.key]
         self.gate = w["gate"]
         self.up = w["up"]
@@ -176,28 +176,28 @@ class MLPCPUExperts(MLPExpertsBase):
         # print(n_routed_experts, hidden_size, moe_intermediate_size)
         num_experts_per_tok = self.config.num_experts_per_tok
         self.moe = MOE(moe_config)
-        self.cpu_infer = MLPCPUExperts.CPU_INFER
+        self.cpu_infer = KExpertsCPU.CPU_INFER
         if warmup:
             self.cpu_infer.submit(self.moe.warm_up())
             self.cpu_infer.sync()
-        if self.out_device not in MLPCPUExperts.output_gpu_map:
-            MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
-        if MLPCPUExperts.input_tensor_cpu == None:
-            MLPCPUExperts.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
-            MLPCPUExperts.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
-            MLPCPUExperts.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
-            MLPCPUExperts.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
+        if self.out_device not in KExpertsCPU.output_gpu_map:
+            KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
+        if KExpertsCPU.input_tensor_cpu == None:
+            KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
+            KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
+            KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
+            KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
     def submit_for_one_decode(self, input_tensor, expert_ids, weights):
-        MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
-        MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
-        MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
-        self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))
+        KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
+        KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
+        KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
+        self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
     def sync_for_one_decode(self):
         self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
-        MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
-        return MLPCPUExperts.output_gpu_map[self.out_device]
+        KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+        return KExpertsCPU.output_gpu_map[self.out_device]
     def forward(self, input_tensor, expert_ids, weights):
         # generate, capture and run cuda graph
@@ -205,13 +205,13 @@ class MLPCPUExperts(MLPExpertsBase):
         if input_tensor.size(0)==1:
            # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
            #print("capturing experts")
-            MLPCPUExperts.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
-            MLPCPUExperts.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
-            MLPCPUExperts.weights_cpu.copy_(weights, non_blocking=True)
-            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), MLPCPUExperts.expert_ids_cpu.data_ptr(), MLPCPUExperts.weights_cpu.data_ptr(), MLPCPUExperts.input_tensor_cpu.data_ptr(), MLPCPUExperts.output_cpu.data_ptr()))
+            KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
+            KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
+            KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
+            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
             self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
-            MLPCPUExperts.output_gpu_map[self.out_device].copy_(MLPCPUExperts.output_cpu, non_blocking=True)
-            return MLPCPUExperts.output_gpu_map[self.out_device]
+            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
+            return KExpertsCPU.output_gpu_map[self.out_device]
         else:
             input_tensor = input_tensor.contiguous().cpu()
             expert_ids = expert_ids.contiguous().cpu()
@@ -269,7 +269,7 @@ class MLPCPUExperts(MLPExpertsBase):
         res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
         return res
-class MLPExpertsMarlin(MLPExpertsBase):
+class KExpertsMarlin(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
     def __init__(
@@ -290,11 +290,11 @@ class MLPExpertsMarlin(MLPExpertsBase):
         self.device = device
         # create empty marlin experts according to the number of experts per token
         # up
-        self.up_projs = [QuantizedLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
+        self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
         # gate
-        self.gate_projs = [QuantizedLinearMarlin(key+ "." + "ffn_gate_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
+        self.gate_projs = [KLinearMarlin(key+ "." + "ffn_gate_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
         # down
-        self.down_projs = [QuantizedLinearMarlin(key+ "." + "ffn_down_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
+        self.down_projs = [KLinearMarlin(key+ "." + "ffn_down_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
@@ -359,7 +359,7 @@ class MLPExpertsMarlin(MLPExpertsBase):
         outs = outs.to(device)
         return outs
-class MLPExpertsTorch(MLPExpertsBase):
+class KExpertsTorch(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
     gate: torch.Tensor
@@ -439,12 +439,12 @@ class MLPExpertsTorch(MLPExpertsBase):
         return final_hidden_states.to(org_dtype, device=org_device)
 EXPERTS_MAP = {
-    "MLPCPUExperts": MLPCPUExperts,
-    "MLPExpertsTorch": MLPExpertsTorch,
-    "MLPExpertsMarlin": MLPExpertsMarlin,
+    "KExpertsCPU": KExpertsCPU,
+    "KExpertsTorch": KExpertsTorch,
+    "KExpertsMarlin": KExpertsMarlin,
 }
-class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):
+class KTransformersExperts(BaseInjectedModule, KExpertsBase):
     def __init__(self,
                  key: str,
                  gguf_loader: GGUFLoader,
@@ -452,22 +452,22 @@ class KTransformersMLPExpert(BaseInjectedModule, MLPExpertsBase):
                  orig_module: nn.Module,
                  # device: str = "cuda",
                  prefill_device:str = "cuda",
-                 prefill_mlp_type: str | None = "MLPExpertsTorch",
+                 prefill_op: str | None = "KExpertsTorch",
                  generate_device: str = "cpu",
-                 generate_mlp_type: str | None = "MLPCPUExperts",
+                 generate_op: str | None = "KExpertsCPU",
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
-        MLPExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
-        if generate_mlp_type is not None:
-            self.generate_experts = EXPERTS_MAP[generate_mlp_type](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        if generate_op is not None:
+            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
         else:
             self.generate_experts = None
-        if prefill_mlp_type is not None:
-            self.prefill_experts = EXPERTS_MAP[prefill_mlp_type](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
+        if prefill_op is not None:
+            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
         else:
             self.prefill_experts = None
-        self.gpu_mlp_type = prefill_mlp_type
-        self.cpu_mlp_type = generate_mlp_type
+        self.gpu_mlp_type = prefill_op
+        self.cpu_mlp_type = generate_op
         self.mode = InferenceState.UNLOAD
     def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
@@ -523,7 +523,7 @@ from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
-class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
+class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         orig_shape = hidden_states.shape
@@ -548,16 +548,16 @@ class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock)
             y.resize_(*orig_shape)
             return y, router_logits
-        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else hidden_states_expert.cpu()
-        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else selected_experts_expert.cpu()
-        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else routing_weights_expert.cpu()
+        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states_expert.cpu()
+        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts_expert.cpu()
+        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights_expert.cpu()
         shared_expert_output = self.shared_expert(hidden_states)
         shared_expert_output = (
             F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
         )
-        if isinstance(self.experts, MLPExpertsBase):
+        if isinstance(self.experts, KExpertsBase):
             y = (
                 self.moe_on_cpuinfer(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
@@ -628,7 +628,7 @@ class Qwen2MoeSparseMoeBlockInjected(BaseInjectedModule, Qwen2MoeSparseMoeBlock)
         return final_hidden_states
-class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
+class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
     def forward(self, hidden_states):
         identity = hidden_states
         orig_shape = hidden_states.shape
@@ -648,7 +648,7 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)
-        if isinstance(self.experts, MLPExpertsBase):
+        if isinstance(self.experts, KExpertsBase):
             y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
@@ -727,7 +727,7 @@ class DeepseekV2MoEInjected(BaseInjectedModule, DeepseekV2MoE):
         )
         return final_out
-class MisrtalSparseMoEBlockInjected(BaseInjectedModule, MixtralSparseMoeBlock):
+class KMisrtalSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
@@ -751,11 +751,11 @@ class MisrtalSparseMoEBlockInjected(BaseInjectedModule, MixtralSparseMoeBlock):
             y.resize_(*orig_shape)
             return y, router_logits
-        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else hidden_states_expert.cpu()
-        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else selected_experts_expert.cpu()
-        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, MLPExpertsBase) else routing_weights_expert.cpu()
-        if isinstance(self.experts, MLPExpertsBase):
+        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states_expert.cpu()
+        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts_expert.cpu()
+        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights_expert.cpu()
+        if isinstance(self.experts, KExpertsBase):
             y = (
                 self.moe_on_cpuinfer(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert

View file

@@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang
 Date : 2024-07-25 11:25:24
 Version : 0.1.0
 LastEditors : Azure
-LastEditTime : 2024-07-26 09:27:53
+LastEditTime : 2024-08-14 14:57:04
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
@@ -34,8 +34,8 @@ import cpuinfer_ext
 from ktransformers.operators.cpuinfer import CPUInfer
 from ktransformers.server.config.config import Config
-#class QuantizedLinearBase(BaseInjectedModule, ABC):
-class QuantizedLinearBase(ABC):
+#class KLinearBase(BaseInjectedModule, ABC):
+class KLinearBase(ABC):
     def __init__(
         self,
         key: str,
@@ -106,7 +106,7 @@ class QuantizedLinearBase(ABC):
         pass
-class QuantizedLinearTorch(QuantizedLinearBase):
+class KLinearTorch(KLinearBase):
     def __init__(
         self,
         key: str,
@@ -158,7 +158,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
         self.bias = None
-class QuantizedLinearMarlin(QuantizedLinearBase):
+class KLinearMarlin(KLinearBase):
     marlin_q_w: torch.Tensor
     marlin_s: torch.Tensor
     g_idx: torch.Tensor
@@ -252,7 +252,7 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
         self.sort_indices = None
         self.workspace = None
-class QuantizedLinearCPUInfer(QuantizedLinearBase):
+class KLinearCPUInfer(KLinearBase):
     CPU_INFER = CPUInfer(Config().cpu_infer)
     def __init__(
         self,
@@ -281,7 +281,7 @@ class QuantizedLinearCPUInfer(QuantizedLinearBase):
             out_device = x.device
             self.input_tensor_cpu.copy_(x, non_blocking=True)
             qlen = origin_shape[1]
-            QuantizedLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(
+            KLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(
                 torch.cuda.current_stream().cuda_stream,
                 self.linear.forward(
                     qlen,
@@ -289,7 +289,7 @@ class QuantizedLinearCPUInfer(QuantizedLinearBase):
                     self.output_cpu.data_ptr()
                 )
             )
-            QuantizedLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
+            KLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
             self.output_gpu.copy_(self.output_cpu, non_blocking=True)
             if self.has_bias:
                 self.output_gpu += self.bias
@@ -301,14 +301,14 @@ class QuantizedLinearCPUInfer(QuantizedLinearBase):
             qlen = origin_shape[1]
             output_shape = (*origin_shape[:-1], self.out_features)
             output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
-            QuantizedLinearCPUInfer.CPU_INFER.submit(
+            KLinearCPUInfer.CPU_INFER.submit(
                 self.linear.forward(
                     qlen,
                     x.data_ptr(),
                    output.data_ptr()
                 )
             )
-            QuantizedLinearCPUInfer.CPU_INFER.sync()
+            KLinearCPUInfer.CPU_INFER.sync()
             if self.has_bias:
                 output = output + self.bias
             output = output.to(dtype=dtype, device=out_device)
@@ -329,8 +329,8 @@ class QuantizedLinearCPUInfer(QuantizedLinearBase):
         self.linear = cpuinfer_ext.linear.Linear(config)
         if warmup:
-            QuantizedLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())
-            QuantizedLinearCPUInfer.CPU_INFER.sync()
+            KLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())
+            KLinearCPUInfer.CPU_INFER.sync()
         self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True)
         self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16)
         self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)
@@ -355,12 +355,12 @@ class QuantizedLinearCPUInfer(QuantizedLinearBase):
         self.bias = None
 LINEAR_MAP = {
-    "QuantizedLinearMarlin": QuantizedLinearMarlin,
-    "QuantizedLinearTorch": QuantizedLinearTorch,
-    "QuantizedLinearCPUInfer": QuantizedLinearCPUInfer
+    "KLinearMarlin": KLinearMarlin,
+    "KLinearTorch": KLinearTorch,
+    "KLinearCPUInfer": KLinearCPUInfer
 }
-class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
+class KTransformersLinear(BaseInjectedModule, KLinearBase):
     def __init__(
         self,
         key: str,
@@ -369,20 +369,20 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
         orig_module: nn.Module,
         # device: str = "cuda",
         generate_device: str = "cuda",
-        generate_op: str| None = "QuantizedLinearMarlin",
+        generate_op: str| None = "KLinearMarlin",
         prefill_device: str = "cuda",
-        prefill_op: str| None = "QuantizedLinearTorch",
+        prefill_op: str| None = "KLinearTorch",
         **kwargs,
     ):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
-        QuantizedLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         # build all the linear operators
         if prefill_op is not None:
             assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
-            if prefill_op == "QuantizedLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using QuantizedLinearTorch instead.")
+            if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
+                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
                 print(f"module info: key:{key} orig_module:{orig_module}")
-                self.prefill_linear = QuantizedLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+                self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
             else:
                 self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
         else:
@@ -390,11 +390,11 @@ class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
         if generate_op is not None:
             assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
-            if generate_op == "QuantizedLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
-                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using QuantizedLinearTorch instead.")
+            if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0):
+                print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
                 print(f"module info: key:{key} orig_module:{orig_module}")
-                self.generate_op = "QuantizedLinearTorch"
-                self.generate_linear = QuantizedLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
+                self.generate_op = "KLinearTorch"
+                self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
             else:
                 self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
         else:

View file

@@ -6,7 +6,7 @@ Author : Azure-Tang
 Date : 2024-07-25 11:25:24
 Version : 1.0.0
 LastEditors : Azure
-LastEditTime : 2024-08-08 10:09:14
+LastEditTime : 2024-08-14 14:53:05
 Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 '''
@@ -155,7 +155,7 @@ QWEN2MOE_INPUTS_DOCSTRING = r"""
     "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
     QWEN2MOE_START_DOCSTRING,
 )
-class Qwen2MoeModelKTransformers(BaseInjectedModule):
+class KQwen2MoeModel(BaseInjectedModule):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`]
@@ -451,7 +451,7 @@ DeepseekV2_INPUTS_DOCSTRING = r"""
 """
-class DeepseekV2ModelKTransformers(BaseInjectedModule):
+class KDeepseekV2Model(BaseInjectedModule):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]

View file

@@ -43,48 +43,48 @@
     name: "^model\\.layers\\.([0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([2][0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
@@ -92,7 +92,7 @@
     name: "^model\\.layers\\.([1][0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
@@ -100,7 +100,7 @@
     name: "^model\\.layers\\.([2][0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
@@ -108,7 +108,7 @@
     name: "^model\\.layers\\.([345][0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
@@ -116,73 +116,73 @@
 - match:
     name: "^model\\.layers\\.([0-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:0"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.([1][0-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:1"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:1"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.([2][0-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:2"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:2"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:3"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:3"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.([0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([1][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
 - match:
     name: "^model\\.layers\\.([2][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
 - match:
     name: "^model\\.layers\\.([345][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
@@ -190,7 +190,7 @@
 - match:
     name: "^model$"
   replace:
-    class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
+    class: "ktransformers.operators.layer_wise_prefill.KDeepseekV2Model"
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
       transfer_map:

View file

@@ -27,29 +27,29 @@
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
@@ -57,7 +57,7 @@
     name: "^model\\.layers\\.([345][0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
@@ -65,45 +65,45 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:0"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:0"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:1"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:1"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([345][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
 - match:
     name: "^model$"
   replace:
-    class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
+    class: "ktransformers.operators.layer_wise_prefill.KDeepseekV2Model"
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
       transfer_map:

View file

@@ -9,53 +9,53 @@
 #    name: "^model\\.layers\\.([1-5][0-9])\\.mlp\\.shared_experts.*$" # regular expression
 #    class: torch.nn.Linear # only match modules matching name and class simultaneously
 #  replace:
-#    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+#    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
 #    kwargs:
 #      generate_device: "cpu"
 #      prefill_device: "cuda"
-#      generate_op: "QuantizedLinearCPUInfer"
-#      prefill_op: "QuantizedLinearTorch"
+#      generate_op: "KLinearCPUInfer"
+#      prefill_op: "KLinearTorch"
 #      out_device: "cuda"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model$"
   replace:
-    class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
+    class: "ktransformers.operators.layer_wise_prefill.KDeepseekV2Model"
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"


@@ -27,29 +27,29 @@
     name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
@@ -57,7 +57,7 @@
     name: "^model\\.layers\\.([12][0-9])\\.mlp$"
     class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
   replace:
-    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
@@ -65,45 +65,45 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:0"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:0"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda:1"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:1"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
 - match:
     name: "^model$"
   replace:
-    class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
+    class: "ktransformers.operators.layer_wise_prefill.KDeepseekV2Model"
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
       transfer_map:
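The capture groups do the multi-GPU placement in this file: `(0|[1-9])` pins the single-digit layers 0–9 to cuda:0, while `([12][0-9])` pins layers 10–29 to cuda:1. A small sanity check of that partition (assuming a 30-layer checkpoint purely for illustration):

```python
import re

on_gpu0 = re.compile(r"^model\.layers\.(0|[1-9])\.")
on_gpu1 = re.compile(r"^model\.layers\.([12][0-9])\.")

gpu0 = [i for i in range(30) if on_gpu0.match(f"model.layers.{i}.mlp")]
gpu1 = [i for i in range(30) if on_gpu1.match(f"model.layers.{i}.mlp")]
assert gpu0 == list(range(10))       # layers 0-9  -> cuda:0
assert gpu1 == list(range(10, 30))   # layers 10-29 -> cuda:1
```

Note that the trailing `\.` in each pattern is what stops layer 1 from also swallowing layers 10–19.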


@@ -9,26 +9,26 @@
     name: "^model\\.layers\\..*$"
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe$"
     class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
   replace:
-    class: ktransformers.operators.experts.MisrtalSparseMoEBlockInjected
+    class: ktransformers.operators.experts.KMisrtalSparseMoEBlock
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert
+    class: ktransformers.operators.experts.KTransformersExperts
     kwargs:
       prefill_device: "cuda"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda"
   recursive: False # don't recursively inject submodules of this module
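Once a rule file like this Mixtral one is in place, it is applied at load time. The sketch below mirrors the loading flow the tutorial describes; the paths are placeholders, and the exact import path and signature of `optimize_and_load_gguf` are assumptions that may drift between releases:

```python
# Hedged sketch of applying an optimize-rule YAML, not a definitive recipe.
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf  # assumed path

model_path = "mistralai/Mixtral-8x7B"                 # placeholder
gguf_path = "/path/to/mixtral-gguf"                   # placeholder
rule_path = "/path/to/Mixtral.yaml"                   # placeholder

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
with torch.device("meta"):  # build the skeleton; weights arrive during injection
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimize_and_load_gguf(model, rule_path, gguf_path, config)  # inject + load weights
```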


@@ -10,27 +10,27 @@
     name: "^model\\.layers\\.([012])$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([012])\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
   replace:
-    class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock # mlp module with custom forward function
 - match:
     name: "^model\\.layers\\.([012])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     # device: "cpu" # which devices to load this module when initializing
     kwargs:
       prefill_device: "cuda:0"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:0"
   recursive: False # don't recursively inject submodules of this module
@@ -46,27 +46,27 @@
     name: "^model\\.layers\\.([12][0-9]|[3-9])$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
   replace:
-    class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock # mlp module with custom forward function
 - match:
     name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     # device: "cpu" # which devices to load this module when initializing
     kwargs:
       prefill_device: "cuda:1"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda:1"
   recursive: False # don't recursively inject submodules of this module
@@ -89,7 +89,7 @@
 - match:
     name: "^model$"
   replace:
-    class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelKTransformers"
+    class: "ktransformers.operators.layer_wise_prefill.KQwen2MoeModel"
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
       transfer_map:
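`per_layer_prefill_intput_threshold` (the `intput` spelling is the actual key name, so it is left untouched) gates layer-wise prefill by prompt length, with 0 disabling the feature. A plausible reading of that gate, stated as an assumption rather than the real implementation:

```python
def use_layer_wise_prefill(input_len: int, threshold: int) -> bool:
    # Assumed semantics: 0 turns the feature off entirely; otherwise the
    # prompt must be long enough that streaming layers through GPU memory
    # one at a time pays for itself.
    return threshold > 0 and input_len >= threshold
```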


@@ -9,36 +9,36 @@
     name: "^model\\.layers\\..*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
   replace:
-    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
-      generate_op: "QuantizedLinearMarlin"
-      prefill_op: "QuantizedLinearTorch"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
   replace:
-    class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
+    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock # mlp module with custom forward function
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersMLPExpert # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
     # device: "cpu" # which devices to load this module when initializing
     kwargs:
       prefill_device: "cuda"
-      prefill_mlp_type: "MLPExpertsTorch"
+      prefill_op: "KExpertsTorch"
       generate_device: "cpu"
-      generate_mlp_type: "MLPCPUExperts"
+      generate_op: "KExpertsCPU"
       out_device: "cuda"
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model$"
   replace:
-    class: "ktransformers.operators.layer_wise_prefill.Qwen2MoeModelKTransformers"
+    class: "ktransformers.operators.layer_wise_prefill.KQwen2MoeModel"
     kwargs:
       per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
 - match:


@@ -5,8 +5,8 @@ import sys
 current_path = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(current_path+"/../..")
 import numpy as np
-# from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
-# from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
+# from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
+# from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
 from ktransformers.util.custom_gguf import GGUFLoader
 import torch
 import KTransformersOps


@@ -7,8 +7,8 @@ import pycuda.autoinit
 import pycuda.driver as cuda
 from pycuda.compiler import SourceModule
 import numpy as np
-from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
-from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
+from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
+from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
 from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
 import torch
 import KTransformersOps
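Because these last two hunks rename public classes, any downstream script that still imports the old names breaks on upgrade. A purely illustrative shim that tolerates both sides of this commit (new code should simply import the post-rename names directly):

```python
# Illustrative compatibility layer; all names below come from this commit's diff.
try:  # post-rename API (this commit and later)
    from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
    from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
except ImportError:  # pre-rename installs
    from ktransformers.operators.linear import (
        KTransformerLinear as KTransformersLinear,
        QuantizedLinearMarlin as KLinearMarlin,
    )
    from ktransformers.operators.experts import (
        KTransformersMLPExpert as KTransformersExperts,
        MLPExpertsTorch as KExpertsTorch,
    )
```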