From 412055d450fd870d65caf6b17589f46df56b19af Mon Sep 17 00:00:00 2001 From: Atream Date: Wed, 14 Aug 2024 16:10:54 +0800 Subject: [PATCH] [feature] experts can be injected using CPUInfer [fix] fix ktransformers interface when use new CUDAGraphRunner [fix] fix YAML and optimize logic, the top rule has the highest priority --- ktransformers/operators/cpuinfer.py | 18 +++ ktransformers/operators/experts.py | 94 +++++++------- ktransformers/operators/linear.py | 119 +++++++++++++++++- ktransformers/optimize/optimize.py | 5 +- .../DeepSeek-V2-Chat-multi-gpu-4.yaml | 62 ++++----- .../DeepSeek-V2-Chat-multi-gpu.yaml | 34 ++--- .../optimize_rules/DeepSeek-V2-Chat.yaml | 33 +++-- .../DeepSeek-V2-Lite-Chat-multi-gpu.yaml | 34 ++--- .../optimize/optimize_rules/Mixtral.yaml | 18 +-- .../Qwen2-57B-A14B-Instruct-multi-gpu.yaml | 29 ++--- .../Qwen2-57B-A14B-Instruct.yaml | 22 ++-- .../backend/interfaces/ktransformers.py | 6 +- ktransformers/util/utils.py | 2 +- 13 files changed, 318 insertions(+), 158 deletions(-) create mode 100644 ktransformers/operators/cpuinfer.py diff --git a/ktransformers/operators/cpuinfer.py b/ktransformers/operators/cpuinfer.py new file mode 100644 index 0000000..027cc8b --- /dev/null +++ b/ktransformers/operators/cpuinfer.py @@ -0,0 +1,18 @@ +import sys, os +from typing import Any +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) +import cpuinfer_ext +from ktransformers.server.config.config import Config +class CPUInfer: + cpu_infer = None + def __init__(self, cpu_infer:int = Config().cpu_infer): + if CPUInfer.cpu_infer is None: + CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer) + + def __getattribute__(self, __name: str) -> Any: + return CPUInfer.cpu_infer.__getattribute__(__name) + + def __setattr__(self, __name: str, __value: Any) -> None: + return CPUInfer.cpu_infer.__setattr__(__name, __value) \ No newline at end of file diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 7028c74..75fb729 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -33,6 +33,7 @@ from transformers.configuration_utils import PretrainedConfig from abc import ABC, abstractmethod from ktransformers.operators.linear import QuantizedLinearMarlin, QuantizedLinearTorch, KTransformerLinear import time +from ktransformers.operators.cpuinfer import CPUInfer # class Base(BaseInjectedModule, ABC): @@ -117,7 +118,7 @@ class MLPCPUExperts(MLPExpertsBase): output_cpu:Tensor = None output_gpu_map:dict = {} # Manage output tensor buffer on different gpu #stream_map:dict = {} # Manage cuda stream on different gpu - CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer) + CPU_INFER = CPUInfer(Config().cpu_infer) def __init__( self, key: str, @@ -126,7 +127,7 @@ class MLPCPUExperts(MLPExpertsBase): n_routed_experts: int, orig_module: nn.Module = None, device: str = "cpu", - out_device: str = "cuda", # this device mean which device the output should on + out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu. 
**kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) @@ -135,51 +136,50 @@ class MLPCPUExperts(MLPExpertsBase): self.out_device = out_device def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False): - with torch.device(self.out_device): - if device: - assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None." - if w is None: w = self.load_weights()[self.key] - self.gate = w["gate"] - self.up = w["up"] - self.down = w["down"] - self.gate_type = w["gate_type"] - self.up_type = w["up_type"] - self.down_type = w["down_type"] - gate_ptr = ctypes.addressof( - ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents - ) - up_ptr = ctypes.addressof( - ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents - ) - down_ptr = ctypes.addressof( - ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents - ) - # print(self.gate_qtype, self.up_qtype, self.down_qtype) - n_routed_experts = self.n_routed_experts - # n_routed_experts = len(self.orig_module) - moe_config = MOEConfig( - n_routed_experts, - self.config.num_experts_per_tok, - self.config.hidden_size, - self.config.moe_intermediate_size, - 64, - 10, - 1024, - gate_ptr, - up_ptr, - down_ptr, - self.gate_type, - self.up_type, - self.down_type, - 30, # TODO: get from model.dtype - ) - # print(n_routed_experts, hidden_size, moe_intermediate_size) - num_experts_per_tok = self.config.num_experts_per_tok - self.moe = MOE(moe_config) - self.cpu_infer = MLPCPUExperts.CPU_INFER - if warmup: - self.cpu_infer.submit(self.moe.warm_up()) - self.cpu_infer.sync() + if device: + assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None." 
+ if w is None: w = self.load_weights()[self.key] + self.gate = w["gate"] + self.up = w["up"] + self.down = w["down"] + self.gate_type = w["gate_type"] + self.up_type = w["up_type"] + self.down_type = w["down_type"] + gate_ptr = ctypes.addressof( + ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + up_ptr = ctypes.addressof( + ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + down_ptr = ctypes.addressof( + ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + # print(self.gate_qtype, self.up_qtype, self.down_qtype) + n_routed_experts = self.n_routed_experts + # n_routed_experts = len(self.orig_module) + moe_config = MOEConfig( + n_routed_experts, + self.config.num_experts_per_tok, + self.config.hidden_size, + self.config.moe_intermediate_size, + 64, + 10, + 1024, + gate_ptr, + up_ptr, + down_ptr, + self.gate_type, + self.up_type, + self.down_type, + 30, # TODO: get from model.dtype + ) + # print(n_routed_experts, hidden_size, moe_intermediate_size) + num_experts_per_tok = self.config.num_experts_per_tok + self.moe = MOE(moe_config) + self.cpu_infer = MLPCPUExperts.CPU_INFER + if warmup: + self.cpu_infer.submit(self.moe.warm_up()) + self.cpu_infer.sync() if self.out_device not in MLPCPUExperts.output_gpu_map: MLPCPUExperts.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device) if MLPCPUExperts.input_tensor_cpu == None: diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 90b5506..e984a90 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -11,8 +11,9 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' +import ctypes import torch -from torch import nn +from torch import Tensor, nn import KTransformersOps from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.utils import InferenceState @@ -25,7 +26,13 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl from ktransformers.operators.base_operator import BaseInjectedModule from transformers.configuration_utils import PretrainedConfig from abc import ABC, abstractmethod - +import sys, os +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) +import cpuinfer_ext +from ktransformers.operators.cpuinfer import CPUInfer +from ktransformers.server.config.config import Config #class QuantizedLinearBase(BaseInjectedModule, ABC): class QuantizedLinearBase(ABC): @@ -118,6 +125,7 @@ class QuantizedLinearTorch(QuantizedLinearBase): def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = x.dtype out_device = x.device + # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended. 
x = x.to(device=self.device, dtype=self.dtype) x = x @ self.w if self.has_bias: @@ -128,7 +136,7 @@ class QuantizedLinearTorch(QuantizedLinearBase): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device if w is None: w = self.load_weight(device=device) - + if isinstance(w, nn.Parameter): self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T self.has_bias = False @@ -243,10 +251,113 @@ class QuantizedLinearMarlin(QuantizedLinearBase): self.g_idx = None self.sort_indices = None self.workspace = None - + +class QuantizedLinearCPUInfer(QuantizedLinearBase): + CPU_INFER = CPUInfer(Config().cpu_infer) + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + device: str = "cpu", + out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu. + stride = 16, + group_max_len = 1024, + **kwargs, + ): + super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) + self.has_bias = False + self.dtype = torch.get_default_dtype() + self.w = None + self.has_bias = False + self.stride = stride + self.group_max_len = group_max_len + self.out_device = out_device + + def forward(self, x: torch.Tensor) -> torch.Tensor: + origin_shape = x.shape # [batch_size, q_len, hidden_size] + if origin_shape[1] == 1: + out_device = x.device + self.input_tensor_cpu.copy_(x, non_blocking=True) + qlen = origin_shape[1] + QuantizedLinearCPUInfer.CPU_INFER.submit_with_cuda_stream( + torch.cuda.current_stream().cuda_stream, + self.linear.forward( + qlen, + self.input_tensor_cpu.data_ptr(), + self.output_cpu.data_ptr() + ) + ) + QuantizedLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) + self.output_gpu.copy_(self.output_cpu, non_blocking=True) + if self.has_bias: + self.output_gpu += self.bias + return self.output_gpu + else: + dtype = x.dtype + out_device = x.device + x = x.to(device=self.device) + qlen = origin_shape[1] + output_shape = (*origin_shape[:-1], self.out_features) + output = torch.empty(output_shape, device=x.device, dtype=x.dtype) + QuantizedLinearCPUInfer.CPU_INFER.submit( + self.linear.forward( + qlen, + x.data_ptr(), + output.data_ptr() + ) + ) + QuantizedLinearCPUInfer.CPU_INFER.sync() + if self.has_bias: + output = output + self.bias + output = output.to(dtype=dtype, device=out_device) + return output + + def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True): + print(f"loading {self.key} to {self.device} using CPUInfer") + if device is None: device = self.device + self.load_weights(w=w, device=device) + if self.bias is not None: + self.has_bias = True + self.bias = self.bias.to(device) + + weight_ptr = ctypes.addressof( + ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents + ) + config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30) + self.linear = cpuinfer_ext.linear.Linear(config) + + if warmup: + QuantizedLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up()) + QuantizedLinearCPUInfer.CPU_INFER.sync() + self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True) + self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16) + self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device) + + 
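[editor note] For reference, the decode path of the new QuantizedLinearCPUInfer above follows a fixed pattern: stage the single-token activation into a pinned CPU buffer, enqueue the CPU kernel behind the current CUDA stream with submit_with_cuda_stream, synchronize, then copy the result into a pre-allocated GPU buffer so nothing is allocated per token. Below is a minimal standalone sketch of that pattern, not the patch's exact code; it assumes cpuinfer_ext is built and importable, and weight_ptr / weight_type are placeholders for the mmap'd GGUF tensor address and its ggml type (obtained as in load_weights() below); the sizes are illustrative only.

    import torch
    import cpuinfer_ext  # built from ktransformers_ext, as added to sys.path above
    from ktransformers.operators.cpuinfer import CPUInfer
    from ktransformers.server.config.config import Config

    cpu_infer = CPUInfer(Config().cpu_infer)  # shared cpuinfer_ext.CPUInfer instance

    # Illustrative sizes; weight_ptr / weight_type are placeholders for the
    # mmap'd GGUF weight address and its ggml type.
    in_features, out_features, stride, group_max_len = 5120, 1536, 16, 1024
    weight_ptr: int = ...
    weight_type: int = ...

    linear = cpuinfer_ext.linear.Linear(
        cpuinfer_ext.linear.LinearConfig(in_features, out_features, stride,
                                         group_max_len, weight_ptr, weight_type, 30)
    )

    # Pre-allocated, pinned staging buffers: nothing is allocated per decoded token.
    input_cpu = torch.zeros((1, 1, in_features), device="cpu", pin_memory=True)
    output_cpu = torch.zeros((1, 1, out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16)
    output_gpu = torch.zeros((1, 1, out_features), device="cuda")

    def decode_step(x: torch.Tensor) -> torch.Tensor:
        # x: [1, 1, in_features] on the GPU, single-token decode only.
        input_cpu.copy_(x, non_blocking=True)
        cpu_infer.submit_with_cuda_stream(
            torch.cuda.current_stream().cuda_stream,
            linear.forward(1, input_cpu.data_ptr(), output_cpu.data_ptr()),
        )
        cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
        output_gpu.copy_(output_cpu, non_blocking=True)
        return output_gpu

For prefill (qlen > 1) the diff instead allocates a full output tensor and uses the plain submit()/sync() pair, since the staging buffers above are sized for a single token.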
def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"): + if self.key + ".weight" in self.gguf_loader.tensor_info: + if self.key + ".bias" in self.gguf_loader.tensor_file_map: + self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight") + self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"] + self.bias = self.gguf_loader.load_gguf_tensor(self.key + ".bias", device=device) + else: + self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight") + self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"] + self.bias = None + else: + raise ValueError(f"Linear {self.key} not found in gguf_loader") + + def unload(self): + if self.w is not None: + self.w = None + if self.has_bias: + self.bias = None + LINEAR_MAP = { "QuantizedLinearMarlin": QuantizedLinearMarlin, "QuantizedLinearTorch": QuantizedLinearTorch, + "QuantizedLinearCPUInfer": QuantizedLinearCPUInfer } class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase): diff --git a/ktransformers/optimize/optimize.py b/ktransformers/optimize/optimize.py index 36ab62d..32eab01 100644 --- a/ktransformers/optimize/optimize.py +++ b/ktransformers/optimize/optimize.py @@ -58,7 +58,6 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p #print("gen_optimize_config", prefix, module_name, translated_name) recursive = True for rule in rule_list: - #print(rule) match_meta = rule["match"] if "class" not in match_meta and "name" not in match_meta: raise Exception("match must have at least one of \"class\" and \"name\"") @@ -87,6 +86,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()) if "recursive" in rule: recursive = bool(rule["recursive"]) + break if module_name not in out_data: out_data[module_name]= { @@ -127,5 +127,6 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo with torch.device("meta"): inject(module, optimize_config, model_config, gguf_loader) load_weights(module, gguf_loader) - model_config.gguf_loader = gguf_loader + module.gguf_loader = gguf_loader del_meta(module) + torch.cuda.empty_cache() diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml index 1d6b46f..31c5c87 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml @@ -1,32 +1,3 @@ -- match: - name: "^model\\.layers\\.([0-9])\\." 
- replace: - class: "default" - kwargs: - generate_device: "cuda:0" - prefill_device: "cuda:0" -- match: - name: "(^model\\.layers\\.([1][0-9])\\.)" - replace: - class: "default" - kwargs: - generate_device: "cuda:1" - prefill_device: "cuda:1" -- match: - name: "(^model\\.layers\\.([2][0-9])\\.)" - replace: - class: "default" - kwargs: - generate_device: "cuda:2" - prefill_device: "cuda:2" -- match: - name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)" - replace: - class: "default" - kwargs: - generate_device: "cuda:3" - prefill_device: "cuda:3" - - match: name: "^model.embed_tokens" replace: @@ -69,7 +40,7 @@ prefill_device: "cuda:3" - match: - name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression + name: "^model\\.layers\\.([0-9])\\.(?!self_attn).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types @@ -225,4 +196,33 @@ transfer_map: 10: "cuda:1" 20: "cuda:2" - 30: "cuda:3" \ No newline at end of file + 30: "cuda:3" + +- match: + name: "^model\\.layers\\.([0-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" +- match: + name: "(^model\\.layers\\.([1][0-9])\\.)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" +- match: + name: "(^model\\.layers\\.([2][0-9])\\.)" + replace: + class: "default" + kwargs: + generate_device: "cuda:2" + prefill_device: "cuda:2" +- match: + name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:3" + prefill_device: "cuda:3" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml index 45af034..15e8e10 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml @@ -1,19 +1,3 @@ -- match: - name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." - replace: - class: "default" - kwargs: - generate_device: "cuda:0" - prefill_device: "cuda:0" - -- match: - name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)" - replace: - class: "default" - kwargs: - generate_device: "cuda:1" - prefill_device: "cuda:1" - - match: name: "^model.embed_tokens" replace: @@ -123,4 +107,20 @@ kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: - 30: "cuda:1" \ No newline at end of file + 30: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." 
+ replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml index 328c9d7..47fe084 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml @@ -1,14 +1,21 @@ -- match: - name: "^model\\.layers\\..*\\.|^lm_head" - replace: - class: "default" - kwargs: - generate_device: "cuda" - prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +#- match: +# name: "^model\\.layers\\.([1-5][0-9])\\.mlp\\.shared_experts.*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cpu" +# prefill_device: "cuda" +# generate_op: "QuantizedLinearCPUInfer" +# prefill_op: "QuantizedLinearTorch" +# out_device: "cuda" - match: name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously @@ -24,6 +31,9 @@ class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: @@ -39,16 +49,21 @@ name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers" kwargs: + generate_device: "cuda" + prefill_device: "cuda" per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: - generate_device: "cpu" - prefill_device: "cpu" \ No newline at end of file + generate_device: "cpu" + prefill_device: "cpu" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml index c9c1809..e79e4fd 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml @@ -1,19 +1,3 @@ -- match: - name: "^model\\.layers\\.(0|[1-9])\\." 
- replace: - class: "default" - kwargs: - generate_device: "cuda:0" - prefill_device: "cuda:0" - -- match: - name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)" - replace: - class: "default" - kwargs: - generate_device: "cuda:1" - prefill_device: "cuda:1" - - match: name: "^model.embed_tokens" replace: @@ -123,4 +107,20 @@ kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: - 10: "cuda:1" \ No newline at end of file + 10: "cuda:1" + +- match: + name: "^model\\.layers\\.(0|[1-9])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)" + replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Mixtral.yaml b/ktransformers/optimize/optimize_rules/Mixtral.yaml index 5bd6705..21fdb72 100644 --- a/ktransformers/optimize/optimize_rules/Mixtral.yaml +++ b/ktransformers/optimize/optimize_rules/Mixtral.yaml @@ -1,14 +1,10 @@ -- match: - name: "^model\\.layers\\..*\\." - replace: - class: "default" - kwargs: - generate_device: "cuda" - prefill_device: "cuda" - match: class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" - match: name: "^model\\.layers\\..*$" class: torch.nn.Linear # only match modules matching name and class simultaneously @@ -43,3 +39,11 @@ kwargs: generate_device: "cpu" prefill_device: "cpu" + +- match: + name: "^model\\.layers\\..*\\." + replace: + class: "default" + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml index 82415aa..d48ebeb 100644 --- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml @@ -1,10 +1,3 @@ -- match: - name: "^model\\.layers\\.([012])\\." - replace: - class: "default" - kwargs: - generate_device: "cuda:0" - prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([012])\\." class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding @@ -41,13 +34,6 @@ out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module -- match: - name: "^model\\.layers\\.([12][0-9]|[3-9])\\." - replace: - class: "default" - kwargs: - generate_device: "cuda:1" - prefill_device: "cuda:1" - match: name: "^model\\.layers\\.([12][0-9]|[3-9])\\." class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding @@ -109,3 +95,18 @@ transfer_map: 3: "cuda:1" +- match: + name: "^model\\.layers\\.([012])\\." + replace: + class: "default" + kwargs: + generate_device: "cuda:0" + prefill_device: "cuda:0" + +- match: + name: "^model\\.layers\\.([12][0-9]|[3-9])\\." 
+ replace: + class: "default" + kwargs: + generate_device: "cuda:1" + prefill_device: "cuda:1" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml index 3fd59cb..a48b15a 100644 --- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml +++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml @@ -1,14 +1,10 @@ -- match: - name: "^model\\.layers\\..*\\." - replace: - class: "default" - kwargs: - generate_device: "cuda" - prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" - match: name: "^model\\.layers\\..*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously @@ -24,6 +20,9 @@ class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock replace: class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: @@ -48,4 +47,11 @@ class: "default" kwargs: generate_device: "cpu" - prefill_device: "cpu" \ No newline at end of file + prefill_device: "cpu" +- match: + name: "^model\\.layers\\..*\\." + replace: + class: "default" + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 77b0cda..8d121d5 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -6,6 +6,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.local_chat import custom_models, default_optimize_rules +from ktransformers.util.utils import get_device class KTransformersThreadContext(TransformersThreadContext): @@ -48,8 +49,11 @@ class KTransformersInterface(TransformersInterface): def decode_one_tokens(self): if not hasattr(self, "cuda_graph_runner"): + device_map = self.model.gguf_loader.tensor_device_map + torch_device = get_device('blk.0.self_attn', device_map) + torch_device = "cuda:0" if torch_device == "cuda" else torch_device self.cuda_graph_runner = CUDAGraphRunner() - self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, return_dict=False, use_cache=True) + self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True) if hasattr(self, "cuda_graph_runner"): logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position) diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py index 7993d62..8c91d47 100644 --- a/ktransformers/util/utils.py +++ b/ktransformers/util/utils.py @@ -89,7 +89,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud os.environ["TOKENIZERS_PARALLELISM"] = "false" torch._dynamo.config.suppress_errors = True batch_size, 
seq_length = inputs.shape - device_map = model.config.gguf_loader.tensor_device_map + device_map = model.gguf_loader.tensor_device_map torch_device = get_device('blk.0.self_attn', device_map) torch_device = "cuda:0" if torch_device == "cuda" else torch_device inputs = inputs.to(torch_device)
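
[editor note] The optimize.py change above adds a break after the first matching rule, which is what makes "the top rule has the highest priority" from the commit message hold, and why every YAML file in this patch now lists its catch-all "default" device-placement rules last. A simplified, self-contained sketch of that first-match-wins resolution follows; the rule contents are modeled on the Qwen2 YAML above, regex matching is shown with re.search for illustration, and the real gen_optimize_config additionally translates module names, resolves classes via import, and honors the recursive flag.

    import re

    # Two rules in the same shape as the YAML files above: the specific
    # torch.nn.Linear rule first, the catch-all "default" placement rule last.
    rules = [
        {"match": {"name": r"^model\.layers\..*$", "class": "torch.nn.Linear"},
         "replace": {"class": "ktransformers.operators.linear.KTransformerLinear"}},
        {"match": {"name": r"^model\.layers\..*\."},
         "replace": {"class": "default",
                     "kwargs": {"generate_device": "cuda", "prefill_device": "cuda"}}},
    ]

    def pick_rule(module_name: str, module_class: str):
        # First matching rule wins, mirroring the `break` added to gen_optimize_config.
        for rule in rules:
            match = rule["match"]
            if "class" not in match and "name" not in match:
                raise Exception('match must have at least one of "class" and "name"')
            if "name" in match and re.search(match["name"], module_name) is None:
                continue
            if "class" in match and match["class"] != module_class:
                continue
            return rule["replace"]
        return None

    # A layer's linear submodule hits the specific rule; other submodules of the
    # layer fall through to the trailing catch-all.
    print(pick_rule("model.layers.0.self_attn.q_proj", "torch.nn.Linear"))
    print(pick_rule("model.layers.0.input_layernorm", "DeepseekV2RMSNorm"))

Because matching now stops at the first hit, a catch-all placed at the top (as in the old YAML ordering) would shadow every specialised rule below it, which is exactly why this patch moves those blocks to the bottom of each optimize-rule file.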