diff --git a/doc/en/FAQ.md b/doc/en/FAQ.md
index 75e5e10..e738a29 100644
--- a/doc/en/FAQ.md
+++ b/doc/en/FAQ.md
@@ -25,7 +25,7 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
    1. local_chat.py: You can increase the context window size by setting `--max_new_tokens` to a larger value.
    2. server: Increase the `--cache_lens' to a larger value.
 2. Move more weights to the GPU.
-   Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml
+   Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
    ```yaml
    - match:
        name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$" # inject experts in layer 4~10 as marlin expert
@@ -39,6 +39,8 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
    You can modify layer as you want, eg. `name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$"` to `name: "^model\\.layers\\.([4-12])\\.mlp\\.experts$"` to move more weights to the GPU.
    > Note: The first matched rule in yaml will be applied. For example, if you have two rules that match the same layer, only the first rule's replacement will be valid.
+   > Note: Currently, executing experts on the GPU conflicts with CUDA Graph, and running without CUDA Graph causes a significant slowdown. Therefore, unless you have a substantial amount of VRAM (placing a single layer of experts for DeepSeek-V3/R1 on the GPU requires at least 5.6GB of VRAM), we do not recommend enabling this feature. We are actively working on optimizing it.
+   > Note: KExpertsTorch is untested.
 
 ### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?
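To make the VRAM note above concrete, here is a quick back-of-the-envelope check. The 5.6 GB-per-layer figure is the one quoted in the note; the free-VRAM number is only an illustrative assumption:

```python
# How many layers of DeepSeek-V3/R1 experts fit in the VRAM you have left over?
GIB = 1024**3
per_expert_layer = 5.6 * GIB        # figure quoted in the FAQ note above
free_vram = 6 * GIB                 # example: ~6 GiB left after weights + KV cache

print(int(free_vram // per_expert_layer))   # -> 1 layer
```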
diff --git a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
index 0c49fa7..d5b4a2c 100644
--- a/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
+++ b/ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
@@ -17,8 +17,8 @@
 #include 
 
 __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id(data + block_id * blk_size + 80)));
@@ -72,10 +72,10 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
 
 __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const uint32_t kmask1 = 0x03030303;
     const uint32_t kmask2 = 0x0f0f0f0f;
-    for (auto block_id=global_idx; block_id(data + block_id * blk_size + 0)));
@@ -181,8 +181,8 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
 }
 
 __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id(data + block_id * blk_size + 208)));
@@ -215,8 +215,8 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
 
 static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
 __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id(data + block_id * blk_size)));
         const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2));
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index 274a3ca..32675dc 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -18,6 +18,7 @@ import torch.nn.functional as F
 import torch
 import sys, os
 from ktransformers.operators.base_operator import BaseInjectedModule
+from tqdm import tqdm
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
@@ -225,6 +226,7 @@ class KExpertsCPU(KExpertsBase):
         return
 
     def load_weights(self, override_key: str | None = None, device: str = "cpu"):
+        # TODO: support Bias
         res = {}
         if override_key is not None:
             keys = override_key
@@ -288,6 +290,8 @@ class KExpertsMarlin(KExpertsBase):
         self.act_fn = ACT2FN[config.hidden_act]
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
         self.device = device
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+
         # create empty marlin experts according to the number of experts per token
         # up
         self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
@@ -299,17 +303,34 @@ class KExpertsMarlin(KExpertsBase):
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
-        if w is None: w = self.load_weights()[self.key]
+        load_by_experts = w is None
+        if load_by_experts:
+            w = self.load_weights()
 
-        if isinstance(w, dict):
-            self.gate = w["gate"]
-            self.up = (w["up"])
-            self.down = (w["down"])
-            for i in range(self.expert_num):
-                self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
-                self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
-                self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
-                self.loaded_experts_idx.append(i)
+        if load_by_experts:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
+
+                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)
+                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
+                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)
+                    self.loaded_experts_idx.append(i)
+        else:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in range(self.expert_num):
+                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
+                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
+                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
+                    self.loaded_experts_idx.append(i)
         return
 
     def unload(self):
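The new load path above dequantizes one expert at a time out of the memory-mapped GGUF data instead of materializing the whole fused tensor first, which keeps peak host/GPU memory at one expert's worth of weights. A minimal, generic sketch of that idea (the `dequantize_blocks` callable and all names here are hypothetical illustrations, not the ktransformers API):

```python
import numpy as np

def load_one_expert(raw: np.ndarray, expert_id: int, elements_per_expert: int,
                    elements_per_block: int, bytes_per_block: int, dequantize_blocks):
    """Slice out and dequantize a single expert from a fused quantized tensor.

    `raw` holds the quantized blocks of all experts back to back, so only one
    expert's worth of dequantized weights exists at any time.
    """
    assert elements_per_expert % elements_per_block == 0
    blocks = elements_per_expert // elements_per_block
    start = expert_id * blocks * bytes_per_block
    return dequantize_blocks(raw[start:start + blocks * bytes_per_block])
```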
@@ -329,20 +350,13 @@ class KExpertsMarlin(KExpertsBase):
         gate = None
         up = None
         down = None
-        gate_type = None
-        up_type = None
-        down_type = None
 
         for key in keys:
             if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
-                gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
-                up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
-                down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
-                gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
-                up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
-                down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
-                # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
-        res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+        res = {"gate": gate, "up": up, "down": down}
         return res
 
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
@@ -381,6 +395,7 @@ class KExpertsMarlin(KExpertsBase):
 
         return final_hidden_states.to(dtype=org_dtype, device=org_device)
 
+# untested, CUDA OOM
 class KExpertsTorch(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
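The `# untested, CUDA OOM` note is easy to sanity-check: keeping every routed expert of a single DeepSeek-V3 MoE layer resident on the GPU in bf16 already costs tens of gigabytes. A rough calculation using DeepSeek-V3's published config values (256 routed experts, hidden_size 7168, moe_intermediate_size 2048):

```python
# Rough VRAM cost of keeping one layer's dequantized experts resident in bf16.
num_experts = 256            # DeepSeek-V3 routed experts per MoE layer
hidden = 7168                # hidden_size
inter = 2048                 # moe_intermediate_size
bytes_per_elem = 2           # bf16

params_per_expert = 3 * hidden * inter          # gate, up, down projections
total_bytes = num_experts * params_per_expert * bytes_per_elem
print(f"{total_bytes / 1024**3:.1f} GiB per MoE layer")   # ~21.0 GiB
```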
@@ -402,19 +417,39 @@ class KExpertsTorch(KExpertsBase):
         # self.loaded_experts_idx = []
         self.act_fn = ACT2FN[config.hidden_act]
         self.device = device
-        self.gate = None
-        self.up = None
-        self.donw = None
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+        self.gate = [None for _ in range(self.expert_num)]
+        self.up = [None for _ in range(self.expert_num)]
+        self.down = [None for _ in range(self.expert_num)]
         self.dtype = torch.get_default_dtype()
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
-        if w is None: w = self.load_weights(device=device)[self.key]
+        load_by_experts = w is None
+        if load_by_experts:
+            w = self.load_weights()
 
-        if isinstance(w, dict):
-            self.gate = w["gate"].to(device=device, dtype=self.dtype)
-            self.up = w["up"].to(device=device, dtype=self.dtype)
-            self.down = w["down"].to(device=device, dtype=self.dtype)
+        if load_by_experts:
+            if isinstance(w, dict):
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
+
+                    self.up[i] = up_weights
+                    self.gate[i] = gate_weights
+                    self.down[i] = down_weights
+        else:
+            if isinstance(w, dict):
+                for i in range(self.expert_num):
+                    self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
+                    self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
+                    self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
+
+        self.up = torch.cat(self.up, dim=0)
+        self.gate = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.down, dim=0)
+        return
 
     def unload(self):
         if self.gate is not None:
@@ -422,6 +457,25 @@ class KExpertsTorch(KExpertsBase):
             self.up = None
             self.down = None
 
+    def load_weights(self, override_key: str | None = None):
+        res = {}
+        if override_key is not None:
+            keys = override_key
+        else:
+            keys = [self.key]
+
+        gate = None
+        up = None
+        down = None
+
+        for key in keys:
+            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+        res = {"gate": gate, "up": up, "down": down}
+        return res
+
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
 
         org_device = hidden_states_cpu.device
@@ -582,7 +636,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
 
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -601,8 +655,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
         return y, router_logits
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
@@ -672,7 +725,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
             y_ = self.shared_experts(identity).squeeze(0)
 
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -692,8 +745,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
         return y
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
 
@@ -773,7 +825,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y_ = self.shared_experts(identity).squeeze(0)
 
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -793,8 +845,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         return y
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
 
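For reference, the renamed `moe_kexperts` helper now simply forwards to whichever `KExpertsBase` backend was injected (CPUInfer, Marlin, or Torch), so the stale `torch.empty_like` buffer could be dropped. A minimal sketch of the computation such a backend performs per token (a generic MoE combine written for clarity, not the ktransformers implementation; `experts` is a hypothetical list of per-expert MLP callables):

```python
import torch

def moe_combine(x: torch.Tensor, topk_ids: torch.Tensor,
                topk_weight: torch.Tensor, experts) -> torch.Tensor:
    # x: [num_tokens, hidden]; topk_ids / topk_weight: [num_tokens, k]
    out = torch.zeros_like(x)
    for t in range(x.size(0)):
        for j in range(topk_ids.size(1)):
            e = int(topk_ids[t, j])
            out[t] += topk_weight[t, j] * experts[e](x[t])   # weighted sum of selected experts
    return out
```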
@@ -881,7 +932,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
 
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -900,8 +951,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
         return y, router_logits
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
 
diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py
index 305f266..df01ac9 100644
--- a/ktransformers/operators/linear.py
+++ b/ktransformers/operators/linear.py
@@ -119,7 +119,7 @@ class KLinearTorch(KLinearBase):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
         self.has_bias = False
         self.dtype = torch.get_default_dtype()
-        self.w = None
+        self.weight = None
         self.has_bias = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -127,7 +127,7 @@ class KLinearTorch(KLinearBase):
         out_device = x.device
         # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
         x = x.to(device=self.device, dtype=self.dtype)
-        x = x @ self.w
+        x = x @ self.weight
         if self.has_bias:
             x = x + self.bias
         x = x.to(dtype=dtype, device=out_device)
@@ -140,27 +140,27 @@ class KLinearTorch(KLinearBase):
 
         if isinstance(w, nn.Parameter):
             try:
-                self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w.to(dtype=self.dtype).T
+                self.weight = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
             try:
-                self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
             except:
-                self.w = w[0].to(dtype=self.dtype).T
+                self.weight = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:
             raise ValueError("Invalid weight type")
         # self.linear = self.linear.to(device)
-        self.w = self.w.to(device)
+        self.weight = self.weight.to(device)
         if self.has_bias:
             self.bias = self.bias.to(device)
 
     def unload(self):
-        if self.w is not None:
-            self.w = None
+        if self.weight is not None:
+            self.weight = None
         if self.has_bias:
             self.bias = None
@@ -218,6 +218,7 @@ class KLinearMarlin(KLinearBase):
         self.workspace = MarlinWorkspace(
             self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
         )
+        self.weight = marlin_q_w # modeling_xxx.py may use linear.weight
         self.marlin_q_w = marlin_q_w
         self.marlin_s = marlin_s
         self.g_idx = g_idx
@@ -424,11 +425,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
         if mode == InferenceState.PREFILL:
             self.generate_linear.unload()
             self.prefill_linear.load(w=w)
-            self.device = self.prefill_linear.device
+            self.device = self.prefill_linear.device
+            self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.GENERATE:
             self.prefill_linear.unload()
             self.generate_linear.load(w=w)
             self.device = self.generate_linear.device
+            self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight
         elif mode == InferenceState.UNLOAD:
             self.prefill_linear.unload()
             self.generate_linear.unload()
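The `self.weight` aliases added in linear.py exist because upstream `modeling_xxx.py` code sometimes reaches into `linear.weight` directly (for example to read its dtype or device, or to tie embeddings). A tiny illustration of the failure mode, using a hypothetical wrapper rather than the ktransformers classes:

```python
import torch
import torch.nn as nn

class InjectedLinear(nn.Module):
    """Hypothetical replacement module that keeps its packed weight under another name."""
    def __init__(self, packed: torch.Tensor):
        super().__init__()
        self.packed_w = packed
        self.weight = packed   # alias: external code that reads `.weight` keeps working

lin = InjectedLinear(torch.zeros(8, 4))
print(lin.weight.dtype)        # would raise AttributeError without the alias
```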
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
index 572f9e5..84ab801 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
@@ -182,6 +182,53 @@
     generate_device: "cuda:3"
     prefill_device: "cuda:3"
 
+# === MLP Experts Replacement ===
+# Replace with marlin experts. Uncomment and modify the layer numbers as needed.
+# Each layer of marlin experts takes about 6GB of GPU memory.
+# !!!Do remember to 'close' CUDA Graph if you are using marlin experts.!!!
+# !!!KExpertsTorch is untested; we don't have enough VRAM.!!!
+
+# # GPU 0: layers 3–4
+# - match:
+#     name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:0"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 1: layers 15–17
+# - match:
+#     name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:1"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 2: layers 30–32
+# - match:
+#     name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:2"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+# # GPU 3: layers 45–46
+# - match:
+#     name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$"
+#   replace:
+#     class: ktransformers.operators.experts.KTransformersExperts
+#     kwargs:
+#       generate_device: "cuda:3"
+#       generate_op: "KExpertsMarlin"
+#   recursive: False
+
+
 # === MLP Experts Replacement ===
 # GPU 0: layers 0–14
@@ -316,6 +363,8 @@
     generate_device: "cuda:2"
     prefill_device: "cuda:2"
 
+# don't inject lm_head here if marlin experts are already injected
+
 # For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config)
 - match:
     name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml
index 907f5d3..a10b57f 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml
@@ -713,6 +713,8 @@
     generate_device: "cuda:7"
     prefill_device: "cuda:7"
 
+# don't inject lm_head here if marlin experts are already injected
+
 # For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config)
 - match:
     name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)"
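The commented-out rules above select modules purely by a regex on the module name, so it is worth double-checking which layer indices a pattern will actually match before uncommenting it. A quick check, using the pattern from the "GPU 1" rule and iterating over DeepSeek-V3's 61 layers (indices 0–60, as in the lm_head rule below):

```python
import re

# Which layer indices does the "GPU 1" rule above actually match?
pattern = re.compile(r"^model\.layers\.(1[5-7])\.mlp\.experts$")
layers = [i for i in range(61) if pattern.match(f"model.layers.{i}.mlp.experts")]
print(layers)   # -> [15, 16, 17]
```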
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index cdb5855..62059a7 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -282,8 +282,38 @@ class GGUFLoader:
         itemsize = int(np.empty([], dtype = item_type).itemsize)
         return mmap_data[offset : offset + itemsize * item_count]
 
+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu") -> torch.Tensor:
+        t = self.tensor_info[name]
+        if device.lower() == "cpu":
+            print(f"loading expert {expert_id} of {name} with CPU")
+        shape = t["shape"]
+        ggml_type = t["ggml_type"]
+        if ggml_type not in GGML_NAMES:
+            raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
+        ggml_name = GGML_NAMES[ggml_type]
+
+        # TODO: experts may be fused in one quant block, split it
+        assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may be fused in one quant block, please use CPU dequant"
+
+        blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
+        block_size = GGML_BLOCK_SIZES[ggml_name]
+        offset = expert_id * block_size * blocks_per_experts
+        data = data[offset: offset + block_size * blocks_per_experts]
+
+        if "cuda" in device.lower():
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+        else:
+            values = GGML_DEQUANTIZE[ggml_name](data)
+            values = torch.from_numpy(values)
+
+        values = values.view(shape[-2::-1])
+
+        return values
+
     def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
         t = self.tensor_info[name]
+        if device.lower() == "cpu":
+            print(f"loading {name} with CPU")
         shape = t["shape"]
         ggml_type = t["ggml_type"]
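As a concrete check of the slicing arithmetic in `load_expert_tensor`, here is a small stand-alone sketch using Q4_K as the example quant type (256 elements per block, 144 bytes per block) and DeepSeek-V3's expert shape (moe_intermediate_size 2048 × hidden_size 7168). The numbers are illustrative constants, not values read from a file:

```python
# Stand-alone sketch of the per-expert offset math used in load_expert_tensor.
ELEMENTS_PER_BLOCK = 256      # Q4_K super-block size
BLOCK_SIZE_BYTES = 144        # bytes per Q4_K super-block
elements_per_expert = 2048 * 7168   # moe_intermediate_size * hidden_size

assert elements_per_expert % ELEMENTS_PER_BLOCK == 0
blocks_per_expert = elements_per_expert // ELEMENTS_PER_BLOCK   # 57344 blocks
bytes_per_expert = blocks_per_expert * BLOCK_SIZE_BYTES          # ~8.26 MB of quantized data

expert_id = 5
offset = expert_id * bytes_per_expert   # byte offset of this expert inside the fused tensor
print(blocks_per_expert, bytes_per_expert, offset)
```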