Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-06 12:40:02 +00:00)

Commit c4d9bc6670 (parent 0262f954c7): support KExpertsMarlin backend
5 changed files with 214 additions and 46 deletions
@@ -302,13 +302,13 @@ class KExpertsMarlin(KExpertsBase):
         if w is None: w = self.load_weights()[self.key]
         if isinstance(w, dict):
-            self.gate = nn.Parameter(torch.from_numpy(w["gate"]))
-            self.up = nn.Parameter(torch.from_numpy(w["up"]))
-            self.down = nn.Parameter(torch.from_numpy(w["down"]))
+            self.gate = w["gate"]
+            self.up = w["up"]
+            self.down = w["down"]
             for i in range(self.expert_num):
-                self.up_projs[i].load(self.up[i,...], device=device)
-                self.gate_projs[i].load(self.gate[i,...], device=device)
-                self.down_projs[i].load(self.down[i,...], device=device)
+                self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
+                self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
+                self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
                 self.loaded_experts_idx.append(i)
         return
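The load path above now keeps the stacked expert tensors as plain tensors and wraps each per-expert slice in nn.Parameter only when handing it to the projector's load() call. A minimal sketch of that slicing pattern, assuming stacked weights of shape (expert_num, out_features, in_features) and a hypothetical DummyProj standing in for the per-expert Marlin projector objects:

import torch
import torch.nn as nn

class DummyProj:
    # hypothetical stand-in for a per-expert projector exposing load(w, device=...)
    def load(self, w: nn.Parameter, device: str = "cpu"):
        self.weight = w.to(device)

expert_num, out_features, in_features = 4, 8, 16
up = torch.randn(expert_num, out_features, in_features)   # stacked expert weights
up_projs = [DummyProj() for _ in range(expert_num)]

for i in range(expert_num):
    # wrap the slice only at load time, mirroring the diff above
    up_projs[i].load(nn.Parameter(up[i, ...]), device="cpu")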
@@ -342,22 +342,44 @@ class KExpertsMarlin(KExpertsBase):
         up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
         down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
         # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
-        res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
+        res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
         return res

-    def forward(self, input_tensor:torch.Tensor, expert_ids, weights):
-        # forward
-        device = input_tensor.device
-        input_tensor = input_tensor.to("cuda")
-        outs = torch.zeros_like(input_tensor)
-        for expert_idx in range(expert_ids.size(0)):
-            down_proj = self.down_projs[expert_idx]
-            gate_proj = self.gate_projs[expert_idx]
-            up_proj = self.up_projs[expert_idx]
-
-            outs += down_proj(self.act_fn(gate_proj(input_tensor)) * up_proj(input_tensor)) * weights[expert_idx]
-        outs = outs.to(device)
-        return outs
+    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
+        org_dtype = hidden_states_cpu.dtype
+        org_device = hidden_states_cpu.device
+        hidden_states_cpu = hidden_states_cpu.to(self.device)
+        selected_experts_cpu = selected_experts_cpu.to(self.device)
+        routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype)
+
+        batch_sequence_length, hidden_dim = hidden_states_cpu.size()
+
+        final_hidden_states = torch.zeros(
+            (batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
+        )
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited.
+        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert.
+        for expert_idx in range(self.expert_num):
+            if not expert_mask[expert_idx].any():
+                continue
+            idx, top_x = torch.where(expert_mask[expert_idx])
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2).
+            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
+            G = self.gate_projs[expert_idx].forward(current_state)
+            A = self.act_fn(G)
+            U = self.up_projs[expert_idx].forward(current_state)
+            H = A * U  # element-wise multiplication
+            current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None]
+            # However, `index_add_` only supports torch tensors for indexing, so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states)
+
+        return final_hidden_states.to(dtype=org_dtype, device=org_device)

 class KExpertsTorch(KExpertsBase):
     expert_num: int
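The rewritten forward() follows the familiar Hugging Face MoE dispatch: one-hot encode the selected experts, use torch.where to find which tokens were routed to each expert, run only those tokens through that expert's gate/up/down projections, scale by the routing weights, and scatter the partial results back with index_add_. A self-contained sketch of the same pattern with plain nn.Linear experts (shapes and module names here are illustrative, not the KExpertsMarlin internals):

import torch
import torch.nn as nn

expert_num, hidden_dim, inter_dim, top_k, tokens = 4, 8, 16, 2, 5

gate_projs = [nn.Linear(hidden_dim, inter_dim, bias=False) for _ in range(expert_num)]
up_projs   = [nn.Linear(hidden_dim, inter_dim, bias=False) for _ in range(expert_num)]
down_projs = [nn.Linear(inter_dim, hidden_dim, bias=False) for _ in range(expert_num)]
act_fn = nn.SiLU()

hidden_states = torch.randn(tokens, hidden_dim)
routing_weights = torch.rand(tokens, top_k)                       # per-token weight of each chosen expert
selected_experts = torch.randint(0, expert_num, (tokens, top_k))  # per-token expert ids

final_hidden_states = torch.zeros_like(hidden_states)
# (expert_num, top_k, tokens): which tokens chose which expert in which top-k slot
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=expert_num).permute(2, 1, 0)

for expert_idx in range(expert_num):
    if not expert_mask[expert_idx].any():
        continue                                          # no tokens routed to this expert
    idx, top_x = torch.where(expert_mask[expert_idx])     # idx: top-k slot, top_x: token index
    current_state = hidden_states[top_x]                  # only the routed tokens
    h = act_fn(gate_projs[expert_idx](current_state)) * up_projs[expert_idx](current_state)
    out = down_projs[expert_idx](h) * routing_weights[top_x, idx, None]
    final_hidden_states.index_add_(0, top_x, out)         # scatter partial results back per token

print(final_hidden_states.shape)  # (tokens, hidden_dim)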
@@ -138,10 +138,10 @@ class KLinearTorch(KLinearBase):
         if w is None: w = self.load_weight(device=device)

         if isinstance(w, nn.Parameter):
-            self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w.to(dtype=self.dtype).T
             self.has_bias = False
         elif isinstance(w, tuple):
-            self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
+            self.w = w[0].to(dtype=self.dtype).T
             self.bias = w[1].to(dtype=self.dtype)
             self.has_bias = True
         else:
@@ -0,0 +1,143 @@
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\."
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
+  replace:
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
+  replace:
+    class: ktransformers.operators.experts.KDeepseekV3MoE  # MLP module with custom forward function
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
+  replace:
+    class: ktransformers.operators.gate.KMoEGate  # gate module with custom forward function
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
+    kwargs:
+      prefill_device: "cuda:0"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda:0"
+  recursive: False  # don't recursively inject submodules of this module
+
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE kernel with expert parallelism
+    kwargs:
+      prefill_device: "cuda:1"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda:1"
+  recursive: False  # don't recursively inject submodules of this module
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+- match:
+    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
+  replace:
+    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+- match:
+    name: "^model$"
+  replace:
+    class: "ktransformers.operators.models.KDeepseekV2Model"
+    kwargs:
+      per_layer_prefill_intput_threshold: 0  # 0 turns layer-wise prefill off
+      transfer_map:
+        30: "cuda:1"
+
+- match:
+    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:0"
+      prefill_device: "cuda:0"
+
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
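This new rule file splits the model across two GPUs: rotary embeddings, linears, MoE wrappers, gates and attention for layers 0-29 are placed on cuda:0, layers 30-60 on cuda:1, expert weights stay on CPU during generation (KExpertsCPU) with outputs routed to the matching GPU, and the wrapped model hands hidden states over to cuda:1 at layer 30 via transfer_map. A rough sketch of how such regex-plus-class match rules can be applied to a module tree follows; the rule shape mirrors the YAML above, but the matching logic is a simplified illustration, not ktransformers' actual optimizer:

import re
import torch.nn as nn

# simplified rule set in the same shape as the YAML above (regex name + optional class filter)
rules = [
    {"match": {"name": r"^model\.layers\.(0|[1-9]|[12][0-9])\.", "class": nn.Linear},
     "replace": {"kwargs": {"generate_device": "cuda:0", "prefill_device": "cuda:0"}}},
    {"match": {"name": r"^model\.layers\.([3456][0-9])\.", "class": nn.Linear},
     "replace": {"kwargs": {"generate_device": "cuda:1", "prefill_device": "cuda:1"}}},
]

def match_rule(name, module):
    # return the first rule whose name regex matches and whose class (if given) matches
    for rule in rules:
        m = rule["match"]
        if re.search(m["name"], name) is None:
            continue
        if "class" in m and not isinstance(module, m["class"]):
            continue
        return rule
    return None

# toy stand-ins for submodule names of the real model
named_modules = {
    "model.layers.5.mlp.gate_proj": nn.Linear(4, 4),
    "model.layers.42.mlp.gate_proj": nn.Linear(4, 4),
    "model.norm": nn.LayerNorm(4),
}
for name, module in named_modules.items():
    rule = match_rule(name, module)
    device = rule["replace"]["kwargs"]["generate_device"] if rule else "no rule"
    print(name, "->", device)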
@@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
+        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
@@ -71,10 +71,11 @@ class KTransformersInterface(TransformersInterface):
         self.streamer = TextStreamer(self.tokenizer)

     def decode_one_tokens(self):
-        if not hasattr(self, "cuda_graph_runner"):
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
+        if self.args.use_cuda_graph:
+            if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
                     self.model,
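With this change decode_one_tokens only builds and captures the CUDAGraphRunner when args.use_cuda_graph is set, and can fall back to an eager forward pass otherwise. A generic sketch of that guarded capture-once, replay-afterwards pattern using torch.cuda.CUDAGraph directly; the repo's runner class wraps more state, so this only illustrates the idea and assumes a CUDA device is available when the flag is on:

import torch
import torch.nn as nn

torch.set_grad_enabled(False)                 # inference only, as in the interface code
use_cuda_graph = torch.cuda.is_available()    # stands in for args.use_cuda_graph
device = "cuda:0" if use_cuda_graph else "cpu"

model = nn.Linear(16, 16).to(device)
static_input = torch.randn(1, 16, device=device)
graph, static_output = None, None

def decode_step(x):
    global graph, static_output
    if use_cuda_graph:
        if graph is None:  # capture only once, on the first decode step
            # warm up on a side stream before capture, as the PyTorch docs recommend
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                model(static_input)
            torch.cuda.current_stream().wait_stream(s)
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                static_output = model(static_input)
        static_input.copy_(x)   # refresh the static buffer, then replay the captured graph
        graph.replay()
        return static_output.clone()
    return model(x)             # eager fallback when CUDA graphs are disabled

print(decode_step(torch.randn(1, 16, device=device)).shape)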
@@ -93,6 +93,8 @@ class Config(metaclass=Singleton):
         self.model_name: str = self.model.get("name", "")
         self.model_device: str = self.model.get("device", "cuda:0")
         self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
+        self.use_cuda_graph = self.model.get("use_cuda_graph", True)
+        self.trust_remote_code = self.model.get("trust_remote_code", True)
         # self.model_cache_lens = self.model.get("cache_lens")
         self.optimize_config_path: Optional[str] = self.model.get(
             "optimize_config_path", None
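The two new Config fields follow the existing pattern of reading optional keys from the model section of the server config with dict.get() defaults, so older config files that lack the keys keep their previous behaviour. A minimal sketch of that pattern, with a hypothetical dict standing in for the parsed model section:

# hypothetical stand-in for the parsed "model" section of the server's config file
model_section = {"name": "DeepSeek-V3", "device": "cuda:0"}

# .get() with defaults keeps older config files working when the new keys are absent
use_cuda_graph = model_section.get("use_cuda_graph", True)
trust_remote_code = model_section.get("trust_remote_code", True)
print(use_cuda_graph, trust_remote_code)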