mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 14:51:06 +00:00)

commit 590fcb41cd (parent 613f0b7c37)
support smt and glm4

5 changed files with 95 additions and 7 deletions
@@ -28,7 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
-from transformers.integrations import use_kernel_forward_from_hub
+# from transformers.integrations import use_kernel_forward_from_hub
 from transformers.masking_utils import create_causal_mask
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import GradientCheckpointingLayer
@@ -36,9 +36,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
-from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
-# from transformers.utils import auto_docstring, can_return_tuple
-from transformers.utils.generic import check_model_inputs
+# from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils import auto_docstring, can_return_tuple
+# from transformers.utils.generic import check_model_inputs
 from .configuration_glm4_moe import Glm4MoeConfig

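The symbols commented out above (use_kernel_forward_from_hub, TransformersKwargs, check_model_inputs) only exist in very recent transformers releases; dropping them presumably keeps this vendored GLM4-MoE modeling file importable on the transformers version ktransformers pins. A guarded-import sketch of the same idea, purely illustrative and not what the commit does:

# Illustrative alternative to hard-commenting (not part of this commit):
# import the newer transformers helpers when present, otherwise install
# harmless stand-ins so the rest of the module keeps working.
try:
    from transformers.utils import TransformersKwargs  # newer releases only
except ImportError:
    TransformersKwargs = dict  # assumed fallback: treat the kwargs bundle as a plain dict

try:
    from transformers.integrations import use_kernel_forward_from_hub
except ImportError:
    def use_kernel_forward_from_hub(kernel_name):
        # No kernel hub available: return the decorated class unchanged.
        def decorator(cls):
            return cls
        return decorator
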
@@ -1388,6 +1388,78 @@ class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
         else:
             raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

+
+class KGlm4Experts(BaseInjectedModule, KExpertsBase):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 # device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 prefill_op: str | None = "KExpertsTorch",
+                 generate_device: str = "cpu",
+                 generate_op: str | None = "KExpertsCPU",
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        if generate_op is not None:
+            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+        else:
+            self.generate_experts = None
+        if prefill_op is not None:
+            self.prefill_experts = None
+        self.gpu_mlp_type = prefill_op
+        self.cpu_mlp_type = generate_op
+        self.mode = InferenceState.UNLOAD
+
+    def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
+        # TODO support w as input
+        if not mode: mode = InferenceState.GENERATE
+        if mode == InferenceState.GENERATE:
+            # self.prefill_experts.unload()
+            self.generate_experts.load(w, warmup=warmup)
+            self.device = self.generate_experts.device
+            self.mode = mode
+        elif mode == InferenceState.PREFILL:
+            self.generate_experts.unload()
+            self.prefill_experts.load(w, warmup=warmup)
+            self.device = self.prefill_experts.device
+            self.mode = mode
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+            self.mode = mode
+            self.device = self.generate_experts.device
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+    def unload(self):
+        if self.generate_experts is not None:
+            self.generate_experts.unload()
+        if self.prefill_experts is not None:
+            self.prefill_experts.unload()
+        self.device = self.generate_experts.device
+
+    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
+        if self.mode == InferenceState.GENERATE:
+            assert self.generate_experts is not None, "generate_experts is None"
+            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        elif self.mode == InferenceState.PREFILL:
+            assert self.prefill_experts is not None, "prefill_experts is None"
+            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        else:
+            raise ValueError("load or set_inference_mode before forward")
+
+    def set_inference_mode(self, mode: InferenceState):
+        if mode == InferenceState.GENERATE:
+            self.load(mode=InferenceState.GENERATE, warmup=False)
+        elif mode == InferenceState.PREFILL:
+            self.load(mode=InferenceState.PREFILL, warmup=False)
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+
 class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
     def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):

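KGlm4Experts follows the same pattern as the KSmallthinkerExperts wrapper above it: a CPU expert backend built from EXPERTS_MAP[generate_op] serves decode, an optional GPU backend would serve prefill, and InferenceState decides which one is resident. A stripped-down, self-contained sketch of that switching logic, with a dummy backend standing in for the real EXPERTS_MAP classes (all names below are placeholders, not ktransformers code):

# Standalone sketch of the GENERATE / PREFILL / UNLOAD switching that the
# injected experts wrappers implement; DummyBackend is a placeholder for
# KExpertsCPU / KExpertsTorch.
from enum import Enum, auto


class InferenceState(Enum):
    GENERATE = auto()
    PREFILL = auto()
    UNLOAD = auto()


class DummyBackend:
    def __init__(self, device):
        self.device = device
        self.loaded = False

    def load(self, warmup=True):
        self.loaded = True

    def unload(self):
        self.loaded = False

    def forward(self, x, expert_ids, weights):
        return x  # a real backend would run the routed expert MLPs here


class ExpertsSwitch:
    def __init__(self, generate_device="cpu", prefill_device="cuda"):
        self.generate_experts = DummyBackend(generate_device)
        self.prefill_experts = DummyBackend(prefill_device)
        self.mode = InferenceState.UNLOAD

    def set_inference_mode(self, mode):
        # Keep at most one backend resident, mirroring load()/unload() above.
        if mode == InferenceState.GENERATE:
            self.prefill_experts.unload()
            self.generate_experts.load(warmup=False)
        elif mode == InferenceState.PREFILL:
            self.generate_experts.unload()
            self.prefill_experts.load(warmup=False)
        elif mode == InferenceState.UNLOAD:
            self.generate_experts.unload()
            self.prefill_experts.unload()
        else:
            raise ValueError("unknown InferenceState")
        self.mode = mode
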
@@ -49,7 +49,7 @@
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE Kernel with expert parallelism
+    class: ktransformers.operators.experts.KGlm4Experts  # custom MoE Kernel with expert parallelism
     kwargs:
       prefill_device: "cuda"
       prefill_op: None

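The injection rule for the experts modules now instantiates the new KGlm4Experts operator instead of KTransformersExpertsV2. A quick standalone check of which module paths the match pattern selects (illustrative example, not part of the commit):

# The rule's regex, unescaped from YAML: it matches the experts container of
# each decoder layer, but not individual experts nested below it.
import re

pattern = re.compile(r"^model\.layers\..*\.mlp\.experts$")
print(bool(pattern.match("model.layers.0.mlp.experts")))    # True
print(bool(pattern.match("model.layers.17.mlp.experts")))   # True
print(bool(pattern.match("model.layers.0.mlp.experts.3")))  # False (a single expert)
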
@@ -2,6 +2,9 @@ import argparse
 from ktransformers.server.backend.args import ConfigArgs, default_args
 from ktransformers.util.utils import get_free_ports
 from transformers import AutoConfig
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
+from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig

 class ArgumentParser:
     def __init__(self, cfg):
@@ -136,7 +139,20 @@ class ArgumentParser:
         self.cfg.server_port = args.port
         self.cfg.user_force_think = args.force_think

-        model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        if args.model_name == "Qwen3MoeForCausalLM":
+            model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        elif args.model_name == "Glm4MoeForCausalLM":
+            model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        elif args.model_name == "SmallthinkerForCausalLM":
+            model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            model_config._attn_implementation = "eager"
+        else:
+            try:
+                model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            except:
+                raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")

         if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" :
             args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim
         args.architectures = model_config.architectures[0]

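The server argument parser now resolves the model config explicitly from args.model_name, forces eager attention for Smallthinker, and only falls back to AutoConfig (with a clearer error) for anything else. A table-driven sketch of the same dispatch; the config classes are the ones imported in the hunk above, while load_model_config itself is a hypothetical helper, not part of the commit:

# Sketch only: same behaviour as the new if/elif chain, expressed as a lookup table.
from transformers import AutoConfig
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig

CONFIG_BY_NAME = {
    "Qwen3MoeForCausalLM": Qwen3MoeConfig,
    "Glm4MoeForCausalLM": Glm4MoeConfig,
    "SmallthinkerForCausalLM": SmallthinkerConfig,
}

def load_model_config(model_name: str, model_dir: str):
    cfg_cls = CONFIG_BY_NAME.get(model_name)
    if cfg_cls is None:
        # Unknown architecture: let transformers resolve it, as the else-branch does.
        return AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    config = cfg_cls.from_pretrained(model_dir, trust_remote_code=True)
    if model_name == "SmallthinkerForCausalLM":
        config._attn_implementation = "eager"  # Smallthinker special case from the diff
    return config
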
@@ -139,7 +139,7 @@ class SafeTensorLoader(ModelLoader):
         experts_count = 0

         key_no_proj = False
-        if self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
+        if self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
             key_no_proj = True

         # First, count how many experts we have by checking for expert 0's up_proj

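The loader change narrows key_no_proj so the flag is only set when the checkpoint names its expert tensors with the bare up/gate/down scheme (e.g. ...experts.0.up.weight) rather than the usual up_proj form; previously a checkpoint using up_proj keys could also trip the flag. A tiny illustration of the distinction, with made-up key names standing in for a real safetensors index:

# Illustrative only: decide which expert-tensor naming a checkpoint uses by
# probing expert 0's up projection, the same probe SafeTensorLoader performs.
def uses_bare_naming(tensor_keys: set, base_key: str) -> bool:
    return f"{base_key}.0.up.weight" in tensor_keys

base = "model.layers.0.mlp.experts"          # hypothetical expert container key
proj_style = {f"{base}.0.up_proj.weight"}    # "_proj" naming (assumed example)
bare_style = {f"{base}.0.up.weight"}         # bare naming (assumed example)

print(uses_bare_naming(proj_style, base))    # False -> key_no_proj stays False
print(uses_bare_naming(bare_style, base))    # True  -> key_no_proj is set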