support KExpertsMarlin backend

Azure 2025-02-07 05:57:40 +00:00
parent 0262f954c7
commit c4d9bc6670
5 changed files with 214 additions and 46 deletions


@@ -302,13 +302,13 @@ class KExpertsMarlin(KExpertsBase):
        if w is None: w = self.load_weights()[self.key]

        if isinstance(w, dict):
-            self.gate = nn.Parameter(torch.from_numpy(w["gate"]))
-            self.up = nn.Parameter(torch.from_numpy(w["up"]))
-            self.down = nn.Parameter(torch.from_numpy(w["down"]))
+            self.gate = w["gate"]
+            self.up = (w["up"])
+            self.down = (w["down"])
            for i in range(self.expert_num):
-                self.up_projs[i].load(self.up[i,...], device=device)
-                self.gate_projs[i].load(self.gate[i,...], device=device)
-                self.down_projs[i].load(self.down[i,...], device=device)
+                self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
+                self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
+                self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
                self.loaded_experts_idx.append(i)
        return
@@ -342,23 +342,45 @@ class KExpertsMarlin(KExpertsBase):
                up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
                down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
            # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
-            res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
+            res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
        return res

-    def forward(self, input_tensor:torch.Tensor, expert_ids, weights):
-        # forward
-        device = input_tensor.device
-        input_tensor = input_tensor.to("cuda")
-        outs = torch.zeros_like(input_tensor)
-        for expert_idx in range(expert_ids.size(0)):
-            down_proj = self.down_projs[expert_idx]
-            gate_proj = self.gate_projs[expert_idx]
-            up_proj = self.up_projs[expert_idx]
-            outs += down_proj(self.act_fn(gate_proj(input_tensor)) * up_proj(input_tensor)) * weights[expert_idx]
-        outs = outs.to(device)
-        return outs
+    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
+        org_dtype = hidden_states_cpu.dtype
+        org_device = hidden_states_cpu.device
+        hidden_states_cpu = hidden_states_cpu.to(self.device)
+        selected_experts_cpu = selected_experts_cpu.to(self.device)
+        routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype)
+        batch_sequence_length, hidden_dim = hidden_states_cpu.size()
+        final_hidden_states = torch.zeros(
+            (batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
+        )
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.expert_num):
+            if not expert_mask[expert_idx].any():
+                continue
+            idx, top_x = torch.where(expert_mask[expert_idx])
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
+            G = self.gate_projs[expert_idx].forward(current_state)
+            A = self.act_fn(G)
+            U = self.up_projs[expert_idx].forward(current_state)
+            H = A * U  # Element-wise multiplication
+            current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None]
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states)
+        return final_hidden_states.to(dtype=org_dtype, device=org_device)

class KExpertsTorch(KExpertsBase):
    expert_num: int
    loaded_experts_idx: list[int]

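For reference (not part of the diff), a minimal standalone sketch of the one-hot expert-mask dispatch pattern used by the new `forward` above, with plain `nn.Linear` layers standing in for the Marlin gate/up/down projections; the sizes and module names here are illustrative assumptions only:

```python
import torch
import torch.nn as nn

expert_num, hidden_dim, inter_dim, top_k, tokens = 4, 16, 32, 2, 5
act_fn = nn.SiLU()
gate_projs = [nn.Linear(hidden_dim, inter_dim, bias=False) for _ in range(expert_num)]
up_projs   = [nn.Linear(hidden_dim, inter_dim, bias=False) for _ in range(expert_num)]
down_projs = [nn.Linear(inter_dim, hidden_dim, bias=False) for _ in range(expert_num)]

hidden_states = torch.randn(tokens, hidden_dim)
routing_weights = torch.rand(tokens, top_k)                       # per-token expert weights
selected_experts = torch.randint(0, expert_num, (tokens, top_k))  # per-token expert ids

final_hidden_states = torch.zeros_like(hidden_states)
# (expert_num, top_k, tokens): expert_mask[e, k, t] == 1 iff token t picked expert e in slot k
expert_mask = nn.functional.one_hot(selected_experts, num_classes=expert_num).permute(2, 1, 0)

for expert_idx in range(expert_num):
    if not expert_mask[expert_idx].any():
        continue                                     # no token routed to this expert
    idx, top_x = torch.where(expert_mask[expert_idx])
    current_state = hidden_states[top_x]             # gather only the routed tokens
    h = act_fn(gate_projs[expert_idx](current_state)) * up_projs[expert_idx](current_state)
    h = down_projs[expert_idx](h) * routing_weights[top_x, idx, None]  # re-apply routing weight
    final_hidden_states.index_add_(0, top_x, h)      # scatter expert outputs back per token

print(final_hidden_states.shape)  # torch.Size([5, 16])
```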

@@ -138,10 +138,10 @@ class KLinearTorch(KLinearBase):
        if w is None: w = self.load_weight(device=device)

        if isinstance(w, nn.Parameter):
-            self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w.to(dtype=self.dtype).T
            self.has_bias = False
        elif isinstance(w, tuple):
-            self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w[0].to(dtype=self.dtype).T
            self.bias = w[1].to(dtype=self.dtype)
            self.has_bias = True
        else:

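A quick aside on the change above, assuming the loaded weight already arrives with shape `(out_features, in_features)` (an assumption about the loader, not verified here): the explicit `.view(...)` is then a no-op, and `.T` alone yields the `(in_features, out_features)` layout the matmul expects:

```python
import torch

out_features, in_features = 8, 4
w = torch.randn(out_features, in_features)  # weight as stored: (out, in)
x = torch.randn(2, in_features)             # a batch of inputs

# with the explicit reshape (old code) and without it (new code)
y_old = x @ w.view(out_features, in_features).T
y_new = x @ w.T
assert torch.equal(y_old, y_new) and y_new.shape == (2, out_features)
```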

@@ -0,0 +1,143 @@
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\."
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
+    class: torch.nn.Linear # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
+    class: torch.nn.Linear # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+    kwargs:
+      prefill_device: "cuda:0"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda:0"
+  recursive: False # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+    kwargs:
+      prefill_device: "cuda:1"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda:1"
+  recursive: False # don't recursively inject submodules of this module
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
+      transfer_map:
+        30: "cuda:1"
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"

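As a sanity check on the rule file above (an illustration, not part of the commit): the regex `(0|[1-9]|[12][0-9])` matches decoder layers 0-29, which these rules pin to `cuda:0`, while `([3456][0-9])` matches layers 30-69, pinned to `cuda:1`; `transfer_map` hands activations over to `cuda:1` at layer 30. Assuming DeepSeek-V3's 61 decoder layers, the split is exhaustive and non-overlapping:

```python
import re

low  = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.")  # -> cuda:0
high = re.compile(r"^model\.layers\.([3456][0-9])\.")        # -> cuda:1

for layer in range(61):  # assuming 61 decoder layers
    name = f"model.layers.{layer}.mlp.experts"
    on_low, on_high = bool(low.match(name)), bool(high.match(name))
    assert on_low != on_high, name   # exactly one device rule matches each layer
    assert on_low == (layer < 30)    # layers 0-29 on cuda:0, 30-60 on cuda:1
```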

@@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
        self.args = args
        torch.set_default_dtype(torch.bfloat16)
        torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
+        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
        if config.architectures[0] == "Qwen2MoeForCausalLM":
            config._attn_implementation = "flash_attention_2"
@@ -71,30 +71,31 @@ class KTransformersInterface(TransformersInterface):
        self.streamer = TextStreamer(self.tokenizer)

    def decode_one_tokens(self):
-        if not hasattr(self, "cuda_graph_runner"):
-            device_map = self.model.gguf_loader.tensor_device_map
-            torch_device = get_device("blk.0.self_attn", device_map)
-            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-            self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(
-                self.model,
-                self.current_ids,
-                self.active_cache_position.unsqueeze(0),
-                self.active_cache_position,
-                self.cache,
-                main_device=torch_device,
-                return_dict=False,
-                use_cache=True,
-            )
-        if hasattr(self, "cuda_graph_runner"):
-            logits = self.cuda_graph_runner(
-                self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
-            )
-            self.cache.change_seq_length(1)
-            torch.cuda.synchronize()
-            logits = logits[0, -1, :]
-            return self.logits_to_token(logits)
+        device_map = self.model.gguf_loader.tensor_device_map
+        torch_device = get_device("blk.0.self_attn", device_map)
+        torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+        if self.args.use_cuda_graph:
+            if not hasattr(self, "cuda_graph_runner"):
+                self.cuda_graph_runner = CUDAGraphRunner()
+                self.cuda_graph_runner.capture(
+                    self.model,
+                    self.current_ids,
+                    self.active_cache_position.unsqueeze(0),
+                    self.active_cache_position,
+                    self.cache,
+                    main_device=torch_device,
+                    return_dict=False,
+                    use_cache=True,
+                )
+            if hasattr(self, "cuda_graph_runner"):
+                logits = self.cuda_graph_runner(
+                    self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
+                )
+                self.cache.change_seq_length(1)
+                torch.cuda.synchronize()
+                logits = logits[0, -1, :]
+                return self.logits_to_token(logits)

        if self.use_static_cache:
            mask = torch.ones((1, self.seq_length)).to(torch_device)

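For context on the refactor above: graph capture is the expensive one-time step, and every later decode step only replays the recorded kernels, which is why capture is now gated behind `args.use_cuda_graph` and runs only when no runner exists yet. A minimal sketch of the same capture-once/replay pattern using `torch.cuda.CUDAGraph` directly on a toy module (the internals of `CUDAGraphRunner` may differ; this is an illustration, not the repo's implementation):

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(1, 16, device="cuda")

# Warm up on a side stream before capture, as CUDA graphs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):           # capture once (the expensive step)
    static_output = model(static_input)

def decode_step(new_input: torch.Tensor) -> torch.Tensor:
    static_input.copy_(new_input)       # refresh the captured input buffer in place
    graph.replay()                      # replay the recorded kernels
    return static_output.clone()

out = decode_step(torch.randn(1, 16, device="cuda"))
print(out.shape)  # torch.Size([1, 16])
```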

@@ -93,6 +93,8 @@ class Config(metaclass=Singleton):
        self.model_name: str = self.model.get("name", "")
        self.model_device: str = self.model.get("device", "cuda:0")
        self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
+        self.use_cuda_graph = self.model.get("use_cuda_graph", True)
+        self.trust_remote_code = self.model.get("trust_remote_code", True)
        # self.model_cache_lens = self.model.get("cache_lens")
        self.optimize_config_path: Optional[str] = self.model.get(
            "optimize_config_path", None