Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-05 12:09:48 +00:00

Merge pull request #333 from kvcache-ai/feat_experts_gpu

Toy support for experts on GPU, no CUDA Graph

Commit: c5f036e8a4

7 changed files with 202 additions and 66 deletions
@@ -25,7 +25,7 @@ from https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
1. local_chat.py: You can increase the context window size by setting `--max_new_tokens` to a larger value.
2. server: Increase `--cache_lens` to a larger value.
2. Move more weights to the GPU.
   Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml
   Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
   ```yaml
   - match:
       name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$" # inject experts in layers 4~10 as marlin experts
   ```

@@ -39,6 +39,8 @@ from https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
You can change the layer range as you like, e.g. from `name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$"` to `name: "^model\\.layers\\.([4-12])\\.mlp\\.experts$"` to move more weights to the GPU.

> Note: The first matched rule in the yaml will be applied. For example, if two rules match the same layer, only the first rule's replacement takes effect.
> Note: Currently, executing experts on the GPU conflicts with CUDA Graph, and without CUDA Graph there is a significant slowdown. Therefore, unless you have a substantial amount of VRAM (placing a single layer of experts for DeepSeek-V3/R1 on the GPU requires at least 5.6 GB of VRAM), we do not recommend enabling this feature. We are actively working on optimization.
> Note: KExpertsTorch is untested.


### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?
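For reference, a complete rule of the kind the answer above describes might look like the sketch below. It is adapted from the commented-out example this commit adds to DeepSeek-V3-Chat-multi-gpu-4.yaml (see the YAML hunks further down); the layer range and device are placeholders to adjust to your own VRAM budget.

```yaml
# Sketch only: place the experts of layers 3-4 on cuda:0 as marlin experts.
# Remember to disable CUDA Graph, and budget roughly 6 GB of VRAM per layer of experts.
- match:
    name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts
    kwargs:
      generate_device: "cuda:0"
      generate_op: "KExpertsMarlin"
  recursive: False
```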
@@ -17,8 +17,8 @@
#include <c10/cuda/CUDAGuard.h>

__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
        for(int i=0;i<blk_size;i++){
            float scale = scales[block_id];
            output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];

@@ -37,8 +37,8 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
}

__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);

        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));

@@ -72,10 +72,10 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size

__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {

    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;
    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);

        uint32_t aux[4];

@@ -128,8 +128,8 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size


__global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
        // const uint8_t * q = data[i].qs;
        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);

@@ -152,8 +152,8 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
}

__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);

        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));

@@ -181,8 +181,8 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
}

__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));

@@ -215,8 +215,8 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
        float* __restrict__ output_blk = (float*)(output + block_id * 256);
        const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
        const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
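The change to the dequantization kernels above is uniform: the grid-stride loop index moves from int to long long, presumably so that flattened offsets such as block_id * blk_size or block_id * 256 cannot overflow 32-bit arithmetic on very large expert tensors. A minimal sketch of the resulting indexing pattern, modeled on the q8_0 kernel above rather than taken verbatim from the file:

```cuda
// Sketch of the 64-bit grid-stride indexing now used by the dequant kernels.
// With 32-bit indices, block_id * blk_size wraps once the output exceeds ~2^31 elements.
__global__ void dequant_grid_stride_sketch(float* output, const float* scales,
                                           const int8_t* qs, int num_blocks, int blk_size) {
    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
    long long stride = (long long)blockDim.x * gridDim.x;
    for (long long block_id = global_idx; block_id < num_blocks; block_id += stride) {
        float scale = scales[block_id];  // one scale per quant block
        for (int i = 0; i < blk_size; i++) {
            output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
        }
    }
}
```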
@@ -18,6 +18,7 @@ import torch.nn.functional as F
import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))

@@ -225,6 +226,7 @@ class KExpertsCPU(KExpertsBase):
        return

    def load_weights(self, override_key: str | None = None, device: str = "cpu"):
        # TODO: support Bias
        res = {}
        if override_key is not None:
            keys = override_key
@@ -288,6 +290,8 @@ class KExpertsMarlin(KExpertsBase):
        self.act_fn = ACT2FN[config.hidden_act]
        assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
        self.device = device
        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size

        # create empty marlin experts according to the number of experts per token
        # up
        self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]

@@ -299,17 +303,34 @@ class KExpertsMarlin(KExpertsBase):
    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
        if device is None: device = self.device
        assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
        if w is None: w = self.load_weights()[self.key]
        if w is None:
            w = self.load_weights()
            load_by_experts = True

        if isinstance(w, dict):
            self.gate = w["gate"]
            self.up = (w["up"])
            self.down = (w["down"])
            for i in range(self.expert_num):
                self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
                self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
                self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
                self.loaded_experts_idx.append(i)
        if load_by_experts:
            if isinstance(w, dict):
                self.gate = w["gate"]
                self.up = (w["up"])
                self.down = (w["down"])
                for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)

                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)
                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)
                    self.loaded_experts_idx.append(i)
        else:
            if isinstance(w, dict):
                self.gate = w["gate"]
                self.up = (w["up"])
                self.down = (w["down"])
                for i in range(self.expert_num):
                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
                    self.loaded_experts_idx.append(i)
        return

    def unload(self):
@@ -329,20 +350,13 @@ class KExpertsMarlin(KExpertsBase):
        gate = None
        up = None
        down = None
        gate_type = None
        up_type = None
        down_type = None

        for key in keys:
            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
                up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
                down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
                gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
                up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
                down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
                # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
                res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
        res = {"gate": gate, "up": up, "down": down}
        return res

    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:

@@ -381,6 +395,7 @@ class KExpertsMarlin(KExpertsBase):

        return final_hidden_states.to(dtype=org_dtype, device=org_device)

# untested, CUDA OOM
class KExpertsTorch(KExpertsBase):
    expert_num: int
    loaded_experts_idx: list[int]
@@ -402,19 +417,39 @@ class KExpertsTorch(KExpertsBase):
        # self.loaded_experts_idx = []
        self.act_fn = ACT2FN[config.hidden_act]
        self.device = device
        self.gate = None
        self.up = None
        self.donw = None
        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
        self.gate = [None for _ in range(self.expert_num)]
        self.up = [None for _ in range(self.expert_num)]
        self.down = [None for _ in range(self.expert_num)]
        self.dtype = torch.get_default_dtype()

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
        if device is None: device = self.device
        if w is None: w = self.load_weights(device=device)[self.key]
        if w is None:
            w = self.load_weights()
            load_by_experts = True

        if isinstance(w, dict):
            self.gate = w["gate"].to(device=device, dtype=self.dtype)
            self.up = w["up"].to(device=device, dtype=self.dtype)
            self.down = w["down"].to(device=device, dtype=self.dtype)
        if load_by_experts:
            if isinstance(w, dict):
                for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)

                    self.up[i] = up_weights
                    self.gate[i] = gate_weights
                    self.down[i] = down_weights
        else:
            if isinstance(w, dict):
                for i in range(self.expert_num):
                    self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
                    self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
                    self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)

        self.up = torch.cat(self.gate, dim=0)
        self.gate = torch.cat(self.gate, dim=0)
        self.down = torch.cat(self.gate, dim=0)
        return

    def unload(self):
        if self.gate is not None:
@@ -422,6 +457,25 @@ class KExpertsTorch(KExpertsBase):
            self.up = None
            self.down = None

    def load_weights(self, override_key: str | None = None):
        res = {}
        if override_key is not None:
            keys = override_key
        else:
            keys = [self.key]

        gate = None
        up = None
        down = None

        for key in keys:
            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
        res = {"gate": gate, "up": up, "down": down}
        return res

    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:

        org_device = hidden_states_cpu.device
@@ -582,7 +636,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):

        if isinstance(self.experts, KExpertsBase):
            y = (
                self.moe_on_cpuinfer(
                self.moe_kexperts(
                    hidden_states_expert, selected_experts_expert, routing_weights_expert
                )
                .view(*orig_shape)

@@ -601,8 +655,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
        return y, router_logits

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = torch.empty_like(x)
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight)
        return outs

@@ -672,7 +725,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
            y_ = self.shared_experts(identity).squeeze(0)

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO may bugs here
            y = (

@@ -692,8 +745,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = torch.empty_like(x)
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight)
        return outs

@@ -773,7 +825,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
            y_ = self.shared_experts(identity).squeeze(0)

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO may bugs here
            y = (

@@ -793,8 +845,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = torch.empty_like(x)
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight)
        return outs

@@ -881,7 +932,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):

        if isinstance(self.experts, KExpertsBase):
            y = (
                self.moe_on_cpuinfer(
                self.moe_kexperts(
                    hidden_states_expert, selected_experts_expert, routing_weights_expert
                )
                .view(*orig_shape)

@@ -900,8 +951,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
        return y, router_logits

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = torch.empty_like(x)
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight)
        return outs

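All four MoE wrapper classes above receive the same mechanical edit: the CPU-only moe_on_cpuinfer helper, which pre-allocated an output buffer with torch.empty_like, is renamed to moe_kexperts and now simply forwards to whichever KExpertsBase backend was injected (KExpertsCPU, KExpertsMarlin, or KExpertsTorch). A condensed sketch of that shared pattern, with a hypothetical wrapper class standing in for the four real ones:

```python
import torch

class MoEWrapperSketch:  # hypothetical stand-in for KQwen2MoeSparseMoeBlock, KDeepseekV2MoE, etc.
    def __init__(self, experts):
        self.experts = experts  # any injected KExpertsBase backend

    @torch.no_grad()
    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        # No pre-allocated output buffer: the backend returns the combined expert output directly.
        return self.experts(x, topk_ids, topk_weight)
```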
@@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.dtype = torch.get_default_dtype()
        self.w = None
        self.weight = None
        self.has_bias = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
        out_device = x.device
        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
        x = x.to(device=self.device, dtype=self.dtype)
        x = x @ self.w
        x = x @ self.weight
        if self.has_bias:
            x = x + self.bias
        x = x.to(dtype=dtype, device=out_device)
@@ -140,27 +140,27 @@

        if isinstance(w, nn.Parameter):
            try:
                self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
            except:
                self.w = w.to(dtype=self.dtype).T
                self.weight = w.to(dtype=self.dtype).T
            self.has_bias = False
        elif isinstance(w, tuple):
            try:
                self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
            except:
                self.w = w[0].to(dtype=self.dtype).T
                self.weight = w[0].to(dtype=self.dtype).T
            self.bias = w[1].to(dtype=self.dtype)
            self.has_bias = True
        else:
            raise ValueError("Invalid weight type")
        # self.linear = self.linear.to(device)
        self.w = self.w.to(device)
        self.weight = self.weight.to(device)
        if self.has_bias:
            self.bias = self.bias.to(device)

    def unload(self):
        if self.w is not None:
            self.w = None
        if self.weight is not None:
            self.weight = None
        if self.has_bias:
            self.bias = None

@@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
        self.workspace = MarlinWorkspace(
            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
        )
        self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
        self.marlin_q_w = marlin_q_w
        self.marlin_s = marlin_s
        self.g_idx = g_idx

@@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
        if mode == InferenceState.PREFILL:
            self.generate_linear.unload()
            self.prefill_linear.load(w=w)
            self.device = self.prefill_linear.device
            self.device = self.prefill_linear.device
            self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
        elif mode == InferenceState.GENERATE:
            self.prefill_linear.unload()
            self.generate_linear.load(w=w)
            self.device = self.generate_linear.device
            self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
        elif mode == InferenceState.UNLOAD:
            self.prefill_linear.unload()
            self.generate_linear.unload()
@@ -182,6 +182,53 @@
      generate_device: "cuda:3"
      prefill_device: "cuda:3"

# === MLP Experts Replacement ===
# Replace with marlin experts. Uncomment and modify the layer numbers as needed.
# Each layer of marlin experts takes about 6GB of GPU memory.
# !!!Do remember to disable ('close') CUDA Graph if you are using marlin experts.!!!
# !!!KExpertsTorch is untested; we don't have enough VRAM.!!!

# # GPU 0: layers 3–4
# - match:
#     name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:0"
#       generate_op: "KExpertsMarlin"
#   recursive: False

# # GPU 1: layers 15–17
# - match:
#     name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:1"
#       generate_op: "KExpertsMarlin"
#   recursive: False

# # GPU 2: layers 30–32
# - match:
#     name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:2"
#       generate_op: "KExpertsMarlin"
#   recursive: False

# # GPU 3: layers 45–46
# - match:
#     name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
#   replace:
#     class: ktransformers.operators.experts.KTransformersExperts
#     kwargs:
#       generate_device: "cuda:3"
#       generate_op: "KExpertsMarlin"
#   recursive: False


# === MLP Experts Replacement ===

# GPU 0: layers 0–14
@@ -316,6 +363,8 @@
      generate_device: "cuda:2"
      prefill_device: "cuda:2"

# don't inject lm_head if already injecting marlin experts

# For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config)
- match:
    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
@@ -713,6 +713,8 @@
      generate_device: "cuda:7"
      prefill_device: "cuda:7"

# don't inject lm_head if already injecting marlin experts

# For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config)
- match:
    name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
@@ -282,8 +282,38 @@ class GGUFLoader:
        itemsize = int(np.empty([], dtype = item_type).itemsize)
        return mmap_data[offset : offset + itemsize * item_count]

    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
            print(f"loading expert {expert_id} of {name} with CPU")
        shape = t["shape"]
        ggml_type = t["ggml_type"]
        if ggml_type not in GGML_NAMES:
            raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
        ggml_name = GGML_NAMES[ggml_type]

        # TODO: experts may fused in quant block, split it
        assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"

        blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
        block_size = GGML_BLOCK_SIZES[ggml_name]
        offset = expert_id * block_size * blocks_per_experts
        data = data[offset: offset + block_size * blocks_per_experts]

        if "cuda" in device.lower():
            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
        else:
            values = GGML_DEQUANTIZE[ggml_name](data)
            values = torch.from_numpy(values)

        values = values.view(shape[-2::-1])

        return values

    def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":
            print(f"loading {name} with CPU")

        shape = t["shape"]
        ggml_type = t["ggml_type"]
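Putting the loader and expert-operator changes together, the intended per-expert loading flow looks roughly like the sketch below. The GGUF path, tensor name, and elements_per_expert value are illustrative assumptions (in the operators it is computed as config.moe_intermediate_size * config.hidden_size), and the GGUFLoader import path is assumed rather than taken from this diff.

```python
# Sketch only: dequantize a single expert's slice of a fused GGUF expert tensor.
from ktransformers.util.custom_gguf import GGUFLoader  # assumed import path

loader = GGUFLoader("/path/to/DeepSeek-V3-GGUF")   # hypothetical GGUF location
name = "blk.4.ffn_up_exps.weight"                  # hypothetical fused expert tensor name
elements_per_expert = 2048 * 7168                  # moe_intermediate_size * hidden_size (assumed values)

raw = loader.get_mmap_tensor(name)                 # mmap'ed quantized bytes; nothing is dequantized yet
expert0 = loader.load_expert_tensor(               # dequantize only expert 0's block-aligned slice,
    name, raw, 0, elements_per_expert,             # on the GPU via GGML_DEQUANTIZE_GPU
    device="cuda:0",
)
```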