Merge pull request #333 from kvcache-ai/feat_experts_gpu

toy support for experts on GPU, no CUDA Graph
This commit is contained in:
Atream 2025-02-15 23:30:24 +08:00 committed by GitHub
commit c5f036e8a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 202 additions and 66 deletions


@ -25,7 +25,7 @@ from https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
1. local_chat.py: You can increase the context window size by setting `--max_new_tokens` to a larger value. 1. local_chat.py: You can increase the context window size by setting `--max_new_tokens` to a larger value.
2. server: Increase the `--cache_lens` to a larger value. 2. server: Increase the `--cache_lens` to a larger value.
2. Move more weights to the GPU. 2. Move more weights to the GPU.
Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
```yaml ```yaml
- match: - match:
name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$" # inject experts in layer 4~10 as marlin expert name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$" # inject experts in layer 4~10 as marlin expert
@ -39,6 +39,8 @@ from https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
You can modify layer as you want, eg. `name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$"` to `name: "^model\\.layers\\.([4-12])\\.mlp\\.experts$"` to move more weights to the GPU. You can modify layer as you want, eg. `name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$"` to `name: "^model\\.layers\\.([4-12])\\.mlp\\.experts$"` to move more weights to the GPU.
> Note: The first matched rule in yaml will be applied. For example, if you have two rules that match the same layer, only the first rule's replacement will be valid. > Note: The first matched rule in yaml will be applied. For example, if you have two rules that match the same layer, only the first rule's replacement will be valid.
> Note: Currently, executing experts on the GPU conflicts with CUDA Graph. Without CUDA Graph, there will be a significant slowdown. Therefore, unless you have a substantial amount of VRAM (placing a single layer of experts for DeepSeek-V3/R1 on the GPU requires at least 5.6GB of VRAM), we do not recommend enabling this feature. We are actively working on optimization.
> Note: KExpertsTorch is untested.
### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them? ### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?
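The "at least 5.6GB of VRAM" figure in the note above can be sanity-checked with a quick back-of-the-envelope estimate. The sketch below is illustrative only; it assumes DeepSeek-V3/R1's published MoE shapes (256 routed experts per layer, hidden_size 7168, moe_intermediate_size 2048) and 4-bit Marlin weights:

```python
# Rough per-layer VRAM estimate for routed experts placed on GPU.
# Assumed DeepSeek-V3/R1 shapes; adjust for other models or quant formats.
n_routed_experts = 256
hidden_size = 7168
moe_intermediate_size = 2048
bits_per_weight = 4  # Marlin 4-bit quantized weights

# Each expert has three projections: gate and up (hidden -> intermediate) and down (intermediate -> hidden).
params_per_expert = 3 * hidden_size * moe_intermediate_size
params_per_layer = n_routed_experts * params_per_expert
bytes_per_layer = params_per_layer * bits_per_weight // 8

print(f"{bytes_per_layer / 1e9:.1f} GB per layer")  # ~5.6 GB, excluding scales and activation buffers
```

Scales and activation buffers add to this, so moving more than a few expert layers per GPU adds up quickly.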


@ -17,8 +17,8 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) { __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
for(int i=0;i<blk_size;i++){ for(int i=0;i<blk_size;i++){
float scale = scales[block_id]; float scale = scales[block_id];
output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i]; output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
@ -37,8 +37,8 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
} }
__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80))); const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
@ -72,10 +72,10 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303; const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f; const uint32_t kmask2 = 0x0f0f0f0f;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * 256);
uint32_t aux[4]; uint32_t aux[4];
@ -128,8 +128,8 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
__global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * 256);
// const uint8_t * q = data[i].qs; // const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16); const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
@ -152,8 +152,8 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
} }
__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0))); const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
@ -181,8 +181,8 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
} }
__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208))); const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));
@ -215,8 +215,8 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) { __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) { for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size))); const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2)); const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
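The only change in these kernels is widening the grid-stride index from `int` to `long long`. For the fused expert tensors this PR dequantizes, the flat output index `block_id * blk_size + i` can exceed INT32_MAX, so the old 32-bit product would overflow. A rough check, with the tensor size assumed from DeepSeek-V3's per-layer expert shape:

```python
# Why the dequant kernels need 64-bit indices (assumed DeepSeek-V3 expert tensor shape).
INT32_MAX = 2**31 - 1

n_routed_experts = 256
hidden_size = 7168
moe_intermediate_size = 2048

# One fused ffn_up_exps / ffn_gate_exps / ffn_down_exps tensor covers all experts of a layer.
elements = n_routed_experts * hidden_size * moe_intermediate_size  # 3,758,096,384
blocks = elements // 256                                           # k-quant blocks of 256 elements

print(elements > INT32_MAX)  # True: a 32-bit flat index block_id * 256 + i would wrap around
print(blocks > INT32_MAX)    # False: the block count alone still fits in 32 bits,
                             # but products like block_id * blk_size do not
```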


@ -18,6 +18,7 @@ import torch.nn.functional as F
import torch import torch
import sys, os import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
from tqdm import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
@ -225,6 +226,7 @@ class KExpertsCPU(KExpertsBase):
return return
def load_weights(self, override_key: str | None = None, device: str = "cpu"): def load_weights(self, override_key: str | None = None, device: str = "cpu"):
# TODO: support Bias
res = {} res = {}
if override_key is not None: if override_key is not None:
keys = override_key keys = override_key
@ -288,6 +290,8 @@ class KExpertsMarlin(KExpertsBase):
self.act_fn = ACT2FN[config.hidden_act] self.act_fn = ACT2FN[config.hidden_act]
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU" assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
self.device = device self.device = device
self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
# create empty marlin experts according to the number of experts per token # create empty marlin experts according to the number of experts per token
# up # up
self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)] self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
@ -299,17 +303,34 @@ class KExpertsMarlin(KExpertsBase):
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device if device is None: device = self.device
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU" assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
if w is None: w = self.load_weights()[self.key] if w is None:
w = self.load_weights()
load_by_experts = True
if isinstance(w, dict): if load_by_experts:
self.gate = w["gate"] if isinstance(w, dict):
self.up = (w["up"]) self.gate = w["gate"]
self.down = (w["down"]) self.up = (w["up"])
for i in range(self.expert_num): self.down = (w["down"])
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device) for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device) up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device) gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
self.loaded_experts_idx.append(i) down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
self.up_projs[i].load(nn.Parameter(up_weights), device=device)
self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
self.down_projs[i].load(nn.Parameter(down_weights), device=device)
self.loaded_experts_idx.append(i)
else:
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in range(self.expert_num):
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i)
return return
def unload(self): def unload(self):
@ -329,20 +350,13 @@ class KExpertsMarlin(KExpertsBase):
gate = None gate = None
up = None up = None
down = None down = None
gate_type = None
up_type = None
down_type = None
for key in keys: for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight") gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] res = {"gate": gate, "up": up, "down": down}
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
return res return res
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
@ -381,6 +395,7 @@ class KExpertsMarlin(KExpertsBase):
return final_hidden_states.to(dtype=org_dtype, device=org_device) return final_hidden_states.to(dtype=org_dtype, device=org_device)
# untested, CUDA OOM
class KExpertsTorch(KExpertsBase): class KExpertsTorch(KExpertsBase):
expert_num: int expert_num: int
loaded_experts_idx: list[int] loaded_experts_idx: list[int]
@ -402,19 +417,39 @@ class KExpertsTorch(KExpertsBase):
# self.loaded_experts_idx = [] # self.loaded_experts_idx = []
self.act_fn = ACT2FN[config.hidden_act] self.act_fn = ACT2FN[config.hidden_act]
self.device = device self.device = device
self.gate = None self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
self.up = None self.gate = [None for _ in range(self.expert_num)]
self.donw = None self.up = [None for _ in range(self.expert_num)]
self.down = [None for _ in range(self.expert_num)]
self.dtype = torch.get_default_dtype() self.dtype = torch.get_default_dtype()
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device if device is None: device = self.device
if w is None: w = self.load_weights(device=device)[self.key] if w is None:
w = self.load_weights()
load_by_experts = True
if isinstance(w, dict): if load_by_experts:
self.gate = w["gate"].to(device=device, dtype=self.dtype) if isinstance(w, dict):
self.up = w["up"].to(device=device, dtype=self.dtype) for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
self.down = w["down"].to(device=device, dtype=self.dtype) up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
self.up[i] = up_weights
self.gate[i] = gate_weights
self.down[i] = down_weights
else:
if isinstance(w, dict):
for i in range(self.expert_num):
self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
self.up = torch.cat(self.up, dim=0)
self.gate = torch.cat(self.gate, dim=0)
self.down = torch.cat(self.down, dim=0)
return
def unload(self): def unload(self):
if self.gate is not None: if self.gate is not None:
@ -422,6 +457,25 @@ class KExpertsTorch(KExpertsBase):
self.up = None self.up = None
self.down = None self.down = None
def load_weights(self, override_key: str | None = None):
res = {}
if override_key is not None:
keys = override_key
else:
keys = [self.key]
gate = None
up = None
down = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
res = {"gate": gate, "up": up, "down": down}
return res
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
org_device = hidden_states_cpu.device org_device = hidden_states_cpu.device
@ -582,7 +636,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
if isinstance(self.experts, KExpertsBase): if isinstance(self.experts, KExpertsBase):
y = ( y = (
self.moe_on_cpuinfer( self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert hidden_states_expert, selected_experts_expert, routing_weights_expert
) )
.view(*orig_shape) .view(*orig_shape)
@ -601,8 +655,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
return y, router_logits return y, router_logits
@torch.no_grad() @torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight) outs = self.experts(x, topk_ids, topk_weight)
return outs return outs
@ -672,7 +725,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
y_ = self.shared_experts(identity).squeeze(0) y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase): if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10: elif hidden_states.size(0) > 10:
# TODO may bugs here # TODO may bugs here
y = ( y = (
@ -692,8 +745,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
return y return y
@torch.no_grad() @torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight) outs = self.experts(x, topk_ids, topk_weight)
return outs return outs
@ -773,7 +825,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y_ = self.shared_experts(identity).squeeze(0) y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase): if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10: elif hidden_states.size(0) > 10:
# TODO may bugs here # TODO may bugs here
y = ( y = (
@ -793,8 +845,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
return y return y
@torch.no_grad() @torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight) outs = self.experts(x, topk_ids, topk_weight)
return outs return outs
@ -881,7 +932,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
if isinstance(self.experts, KExpertsBase): if isinstance(self.experts, KExpertsBase):
y = ( y = (
self.moe_on_cpuinfer( self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert hidden_states_expert, selected_experts_expert, routing_weights_expert
) )
.view(*orig_shape) .view(*orig_shape)
@ -900,8 +951,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
return y, router_logits return y, router_logits
@torch.no_grad() @torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight) outs = self.experts(x, topk_ids, topk_weight)
return outs return outs


@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False self.has_bias = False
self.dtype = torch.get_default_dtype() self.dtype = torch.get_default_dtype()
self.w = None self.weight = None
self.has_bias = False self.has_bias = False
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
out_device = x.device out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended. # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
x = x.to(device=self.device, dtype=self.dtype) x = x.to(device=self.device, dtype=self.dtype)
x = x @ self.w x = x @ self.weight
if self.has_bias: if self.has_bias:
x = x + self.bias x = x + self.bias
x = x.to(dtype=dtype, device=out_device) x = x.to(dtype=dtype, device=out_device)
@ -140,27 +140,27 @@ class KLinearTorch(KLinearBase):
if isinstance(w, nn.Parameter): if isinstance(w, nn.Parameter):
try: try:
self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
except: except:
self.w = w.to(dtype=self.dtype).T self.weight = w.to(dtype=self.dtype).T
self.has_bias = False self.has_bias = False
elif isinstance(w, tuple): elif isinstance(w, tuple):
try: try:
self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
except: except:
self.w = w[0].to(dtype=self.dtype).T self.weight = w[0].to(dtype=self.dtype).T
self.bias = w[1].to(dtype=self.dtype) self.bias = w[1].to(dtype=self.dtype)
self.has_bias = True self.has_bias = True
else: else:
raise ValueError("Invalid weight type") raise ValueError("Invalid weight type")
# self.linear = self.linear.to(device) # self.linear = self.linear.to(device)
self.w = self.w.to(device) self.weight = self.weight.to(device)
if self.has_bias: if self.has_bias:
self.bias = self.bias.to(device) self.bias = self.bias.to(device)
def unload(self): def unload(self):
if self.w is not None: if self.weight is not None:
self.w = None self.weight = None
if self.has_bias: if self.has_bias:
self.bias = None self.bias = None
@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
self.workspace = MarlinWorkspace( self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
) )
self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
self.marlin_q_w = marlin_q_w self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s self.marlin_s = marlin_s
self.g_idx = g_idx self.g_idx = g_idx
@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
if mode == InferenceState.PREFILL: if mode == InferenceState.PREFILL:
self.generate_linear.unload() self.generate_linear.unload()
self.prefill_linear.load(w=w) self.prefill_linear.load(w=w)
self.device = self.prefill_linear.device self.device = self.prefill_linear.device
self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.GENERATE: elif mode == InferenceState.GENERATE:
self.prefill_linear.unload() self.prefill_linear.unload()
self.generate_linear.load(w=w) self.generate_linear.load(w=w)
self.device = self.generate_linear.device self.device = self.generate_linear.device
self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
elif mode == InferenceState.UNLOAD: elif mode == InferenceState.UNLOAD:
self.prefill_linear.unload() self.prefill_linear.unload()
self.generate_linear.unload() self.generate_linear.unload()


@ -182,6 +182,53 @@
generate_device: "cuda:3" generate_device: "cuda:3"
prefill_device: "cuda:3" prefill_device: "cuda:3"
# === MLP Experts Replacement ===
# replace with marlin experts. Uncomment and modify the layer numbers as needed.
# Each layer of marlin experts takes about 6GB of GPU memory.
# !!! Remember to disable CUDA Graph if you are using marlin experts !!!
# !!! KExpertsTorch is untested; we don't have enough VRAM !!!
# # GPU 0: layers 3~4
# - match:
# name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:0"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 1: layers 15~17
# - match:
# name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:1"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 2: layers 30~32
# - match:
# name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:2"
# generate_op: "KExpertsMarlin"
# recursive: False
# # GPU 3: layers 45~46
# - match:
# name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
# replace:
# class: ktransformers.operators.experts.KTransformersExperts
# kwargs:
# generate_device: "cuda:3"
# generate_op: "KExpertsMarlin"
# recursive: False
# === MLP Experts Replacement === # === MLP Experts Replacement ===
# GPU 0: layers 0~14 # GPU 0: layers 0~14
@ -316,6 +363,8 @@
generate_device: "cuda:2" generate_device: "cuda:2"
prefill_device: "cuda:2" prefill_device: "cuda:2"
# don't inject lm_head if marlin experts are already injected
# For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config) # For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config)
- match: - match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"


@ -713,6 +713,8 @@
generate_device: "cuda:7" generate_device: "cuda:7"
prefill_device: "cuda:7" prefill_device: "cuda:7"
# don't inject lm_head if marlin experts are already injected
# For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config) # For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config)
- match: - match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"


@ -282,8 +282,38 @@ class GGUFLoader:
itemsize = int(np.empty([], dtype = item_type).itemsize) itemsize = int(np.empty([], dtype = item_type).itemsize)
return mmap_data[offset : offset + itemsize * item_count] return mmap_data[offset : offset + itemsize * item_count]
def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
t = self.tensor_info[name]
if device.lower() == "cpu":
print(f"loading expert {expert_id} of {name} with CPU")
shape = t["shape"]
ggml_type = t["ggml_type"]
if ggml_type not in GGML_NAMES:
raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
ggml_name = GGML_NAMES[ggml_type]
# TODO: experts may be fused within a quant block; split it
assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"
blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
block_size = GGML_BLOCK_SIZES[ggml_name]
offset = expert_id * block_size * blocks_per_experts
data = data[offset: offset + block_size * blocks_per_experts]
if "cuda" in device.lower():
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values)
values = values.view(shape[-2::-1])
return values
def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor: def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
t = self.tensor_info[name] t = self.tensor_info[name]
if device.lower() == "cpu":
print(f"loading {name} with CPU")
shape = t["shape"] shape = t["shape"]
ggml_type = t["ggml_type"] ggml_type = t["ggml_type"]
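The new `load_expert_tensor` added above is what lets `KExpertsMarlin` and `KExpertsTorch` pull a single expert out of the fused, mmap'ed GGUF tensor instead of dequantizing all experts at once. Below is a small worked example of the slice arithmetic, with illustrative numbers only: a DeepSeek-V3-sized expert (2048×7168) stored as Q4_K, i.e. 256 elements per 144-byte block (the same constant the dequant kernels above use).

```python
# Illustrative per-expert slice arithmetic for load_expert_tensor (assumed shapes, Q4_K layout).
elements_per_expert = 2048 * 7168      # moe_intermediate_size * hidden_size
elements_per_block = 256               # GGML_ELEMENTS_PER_BLOCK["Q4_K"]
bytes_per_block = 144                  # GGML_BLOCK_SIZES["Q4_K"]

# Experts must land on block boundaries, otherwise one block would span two experts.
assert elements_per_expert % elements_per_block == 0
blocks_per_expert = elements_per_expert // elements_per_block  # 57,344
bytes_per_expert = blocks_per_expert * bytes_per_block         # 8,257,536 bytes (~7.9 MiB)

expert_id = 5
offset = expert_id * bytes_per_expert
# data[offset : offset + bytes_per_expert] is the raw Q4_K payload for just this expert,
# which is then dequantized (on CPU or GPU) without touching the other 255 experts.
print(blocks_per_expert, bytes_per_expert, offset)
```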