mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 04:30:03 +00:00
support KExpertsMarlin backend
This commit is contained in:
parent
0262f954c7
commit
c4d9bc6670
5 changed files with 214 additions and 46 deletions
|
@ -302,13 +302,13 @@ class KExpertsMarlin(KExpertsBase):
|
|||
if w is None: w = self.load_weights()[self.key]
|
||||
|
||||
if isinstance(w, dict):
|
||||
self.gate = nn.Parameter(torch.from_numpy(w["gate"]))
|
||||
self.up = nn.Parameter(torch.from_numpy(w["up"]))
|
||||
self.down = nn.Parameter(torch.from_numpy(w["down"]))
|
||||
self.gate = w["gate"]
|
||||
self.up = (w["up"])
|
||||
self.down = (w["down"])
|
||||
for i in range(self.expert_num):
|
||||
self.up_projs[i].load(self.up[i,...], device=device)
|
||||
self.gate_projs[i].load(self.gate[i,...], device=device)
|
||||
self.down_projs[i].load(self.down[i,...], device=device)
|
||||
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
|
||||
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
|
||||
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
|
||||
self.loaded_experts_idx.append(i)
|
||||
return
|
||||
|
||||
|
@ -342,22 +342,44 @@ class KExpertsMarlin(KExpertsBase):
|
|||
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
|
||||
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
|
||||
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
|
||||
res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
|
||||
res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
|
||||
return res
|
||||
|
||||
def forward(self, input_tensor:torch.Tensor, expert_ids, weights):
|
||||
# forward
|
||||
device = input_tensor.device
|
||||
input_tensor = input_tensor.to("cuda")
|
||||
outs = torch.zeros_like(input_tensor)
|
||||
for expert_idx in range(expert_ids.size(0)):
|
||||
down_proj = self.down_projs[expert_idx]
|
||||
gate_proj = self.gate_projs[expert_idx]
|
||||
up_proj = self.up_projs[expert_idx]
|
||||
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
|
||||
org_dtype = hidden_states_cpu.dtype
|
||||
org_device = hidden_states_cpu.device
|
||||
hidden_states_cpu = hidden_states_cpu.to(self.device)
|
||||
selected_experts_cpu = selected_experts_cpu.to(self.device)
|
||||
routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype)
|
||||
|
||||
outs += down_proj(self.act_fn(gate_proj(input_tensor)) * up_proj(input_tensor)) * weights[expert_idx]
|
||||
outs = outs.to(device)
|
||||
return outs
|
||||
batch_sequence_length, hidden_dim = hidden_states_cpu.size()
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
|
||||
)
|
||||
# One hot encode the selected experts to create an expert mask
|
||||
# this will be used to easily index which expert is going to be sollicitated
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)
|
||||
|
||||
# Loop over all available experts in the model and perform the computation on each expert
|
||||
for expert_idx in range(self.expert_num):
|
||||
if not expert_mask[expert_idx].any():
|
||||
continue
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
# Index the correct hidden states and compute the expert hidden state for
|
||||
# the current expert. We need to make sure to multiply the output hidden
|
||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||
current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
|
||||
G = self.gate_projs[expert_idx].forward(current_state)
|
||||
A = self.act_fn(G)
|
||||
U = self.up_projs[expert_idx].forward(current_state)
|
||||
H = A * U # Element-wise multiplication
|
||||
current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None]
|
||||
# However `index_add_` only support torch tensors for indexing so we'll use
|
||||
# the `top_x` tensor here.
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states)
|
||||
|
||||
return final_hidden_states.to(dtype=org_dtype, device=org_device)
|
||||
|
||||
class KExpertsTorch(KExpertsBase):
|
||||
expert_num: int
|
||||
|
|
|
@ -138,10 +138,10 @@ class KLinearTorch(KLinearBase):
|
|||
if w is None: w = self.load_weight(device=device)
|
||||
|
||||
if isinstance(w, nn.Parameter):
|
||||
self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
|
||||
self.w = w.to(dtype=self.dtype).T
|
||||
self.has_bias = False
|
||||
elif isinstance(w, tuple):
|
||||
self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
|
||||
self.w = w[0].to(dtype=self.dtype).T
|
||||
self.bias = w[1].to(dtype=self.dtype)
|
||||
self.has_bias = True
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||
kwargs:
|
||||
prefill_device: "cuda:0"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda:0"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||
kwargs:
|
||||
prefill_device: "cuda:1"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda:1"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||
transfer_map:
|
||||
30: "cuda:1"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
- match:
|
||||
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
|
@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
|
|||
self.args = args
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
torch.set_grad_enabled(False)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
config._attn_implementation = "flash_attention_2"
|
||||
|
||||
|
@ -71,10 +71,11 @@ class KTransformersInterface(TransformersInterface):
|
|||
self.streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
def decode_one_tokens(self):
|
||||
if not hasattr(self, "cuda_graph_runner"):
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device("blk.0.self_attn", device_map)
|
||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||
if self.args.use_cuda_graph:
|
||||
if not hasattr(self, "cuda_graph_runner"):
|
||||
self.cuda_graph_runner = CUDAGraphRunner()
|
||||
self.cuda_graph_runner.capture(
|
||||
self.model,
|
||||
|
|
|
@ -93,6 +93,8 @@ class Config(metaclass=Singleton):
|
|||
self.model_name: str = self.model.get("name", "")
|
||||
self.model_device: str = self.model.get("device", "cuda:0")
|
||||
self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
|
||||
self.use_cuda_graph = self.model.get("use_cuda_graph", True)
|
||||
self.trust_remote_code = self.model.get("trust_remote_code", True)
|
||||
# self.model_cache_lens = self.model.get("cache_lens")
|
||||
self.optimize_config_path: Optional[str] = self.model.get(
|
||||
"optimize_config_path", None
|
||||
|
|
Loading…
Add table
Reference in a new issue