diff --git a/ktransformers/models/modeling_glm4_moe.py b/ktransformers/models/modeling_glm4_moe.py
index 12227b2..0f96c87 100644
--- a/ktransformers/models/modeling_glm4_moe.py
+++ b/ktransformers/models/modeling_glm4_moe.py
@@ -28,7 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
-from transformers.integrations import use_kernel_forward_from_hub
+# from transformers.integrations import use_kernel_forward_from_hub
 from transformers.masking_utils import create_causal_mask
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import GradientCheckpointingLayer
@@ -36,9 +36,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
-from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
-# from transformers.utils import auto_docstring, can_return_tuple
-from transformers.utils.generic import check_model_inputs
+# from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils import auto_docstring, can_return_tuple
+# from transformers.utils.generic import check_model_inputs
 from .configuration_glm4_moe import Glm4MoeConfig
 
 
diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py
index d242855..0620af0 100644
--- a/ktransformers/operators/experts.py
+++ b/ktransformers/operators/experts.py
@@ -1388,6 +1388,78 @@ class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
         else:
             raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
 
+class KGlm4Experts(BaseInjectedModule, KExpertsBase):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 # device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 prefill_op: str | None = "KExpertsTorch",
+                 generate_device: str = "cpu",
+                 generate_op: str | None = "KExpertsCPU",
+                 **kwargs):
+
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        if generate_op is not None:
+            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+        else:
+            self.generate_experts = None
+        if prefill_op is not None:
+            self.prefill_experts = None
+        self.gpu_mlp_type = prefill_op
+        self.cpu_mlp_type = generate_op
+        self.mode = InferenceState.UNLOAD
+
+    def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
+        # TODO support w as input
+        if not mode: mode = InferenceState.GENERATE
+        if mode == InferenceState.GENERATE:
+            # self.prefill_experts.unload()
+            self.generate_experts.load(w, warmup=warmup)
+            self.device = self.generate_experts.device
+            self.mode = mode
+        elif mode == InferenceState.PREFILL:
+            self.generate_experts.unload()
+            self.prefill_experts.load(w, warmup=warmup)
+            self.device = self.prefill_experts.device
+            self.mode = mode
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+            self.mode = mode
+            self.device = self.generate_experts.device
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+    def unload(self):
+        if self.generate_experts is not None:
+            self.generate_experts.unload()
+        if self.prefill_experts is not None:
+            self.prefill_experts.unload()
+        self.device = self.generate_experts.device
+
+    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
+        if self.mode == InferenceState.GENERATE:
+            assert self.generate_experts is not None, "generate_experts is None"
+            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        elif self.mode == InferenceState.PREFILL:
+            assert self.prefill_experts is not None, "prefill_experts is None"
+            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        else:
+            raise ValueError("load or set_inference_mode before forward")
+
+    def set_inference_mode(self, mode: InferenceState):
+        if mode == InferenceState.GENERATE:
+            self.load(mode=InferenceState.GENERATE, warmup=False)
+        elif mode == InferenceState.PREFILL:
+            self.load(mode=InferenceState.PREFILL, warmup=False)
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
 class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
 
     def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
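Note on the modeling_glm4_moe.py hunk above: it pins the file to an older transformers release by commenting out imports (TransformersKwargs, check_model_inputs) that only exist in newer versions. A minimal sketch of a version-tolerant alternative, assuming only the import names already shown in that hunk (not part of the patch), would guard the imports instead of hard-commenting them:

try:
    # Newer transformers releases expose these helpers.
    from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
    from transformers.utils.generic import check_model_inputs
except ImportError:
    # Older releases only provide the docstring / return-tuple helpers.
    from transformers.utils import auto_docstring, can_return_tuple
    TransformersKwargs = dict   # typing-only fallback, assumed safe for annotations
    check_model_inputs = None   # callers must skip this decorator when it is None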
diff --git a/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
index b3397e5..58dc887 100644
--- a/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
+++ b/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
@@ -49,7 +49,7 @@
 - match:
     name: "^model\\.layers\\..*\\.mlp\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism
+    class: ktransformers.operators.experts.KGlm4Experts     # custom MoE Kernel with expert paralleism
     kwargs:
       prefill_device: "cuda"
       prefill_op: None
diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py
index 748bd47..8bb34b7 100644
--- a/ktransformers/server/args.py
+++ b/ktransformers/server/args.py
@@ -2,6 +2,9 @@ import argparse
 from ktransformers.server.backend.args import ConfigArgs, default_args
 from ktransformers.util.utils import get_free_ports
 from transformers import AutoConfig
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
+from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
 
 class ArgumentParser:
     def __init__(self, cfg):
@@ -136,7 +139,20 @@ class ArgumentParser:
         self.cfg.server_port = args.port
         self.cfg.user_force_think = args.force_think
 
-        model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        if args.model_name == "Qwen3MoeForCausalLM":
+            model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        elif args.model_name == "Glm4MoeForCausalLM":
+            model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        elif args.model_name == "SmallthinkerForCausalLM":
+            model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            model_config._attn_implementation = "eager"
+        else:
+            try:
+                model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            except:
+                raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
+
+
         if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" :
             args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim
         args.architectures = model_config.architectures[0]
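The args.py hunk above dispatches on args.model_name to pick a bundled config class and only falls back to AutoConfig in the else branch, where the bare except hides the underlying loading error. A compact sketch of the same dispatch (the names CONFIG_BY_ARCH and load_model_config are illustrative, not part of the patch) that keeps the original exception chained:

from transformers import AutoConfig
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig

# Architecture name -> bundled config class, mirroring the if/elif chain in the hunk.
CONFIG_BY_ARCH = {
    "Qwen3MoeForCausalLM": Qwen3MoeConfig,
    "Glm4MoeForCausalLM": Glm4MoeConfig,
    "SmallthinkerForCausalLM": SmallthinkerConfig,
}

def load_model_config(model_name: str, model_dir: str):
    config_cls = CONFIG_BY_ARCH.get(model_name)
    if config_cls is not None:
        model_config = config_cls.from_pretrained(model_dir, trust_remote_code=True)
        if model_name == "SmallthinkerForCausalLM":
            model_config._attn_implementation = "eager"  # Smallthinker path forces eager attention, as in the patch
        return model_config
    try:
        return AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    except Exception as exc:
        # Chaining keeps the real cause (missing files, bad JSON) visible to the caller.
        raise ValueError(f"Model {model_name} not supported. Please check your model directory or model name.") from exc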
diff --git a/ktransformers/util/custom_loader.py b/ktransformers/util/custom_loader.py
index f45a226..ee08e47 100644
--- a/ktransformers/util/custom_loader.py
+++ b/ktransformers/util/custom_loader.py
@@ -139,7 +139,7 @@ class SafeTensorLoader(ModelLoader):
         experts_count = 0
         key_no_proj = False
 
-        if self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
+        if self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
            key_no_proj = True
 
         # First, count how many experts we have by checking for expert 0's up_proj
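The custom_loader.py hunk narrows the key_no_proj probe: only checkpoints whose expert tensors are named "<base_key>.<i>.up.weight" (no "_proj" suffix) take the alternate naming path, so safetensors that store ".up_proj.weight" keep the default path. A standalone sketch of that probe, assuming a has_tensor(name) -> bool callable such as SafeTensorLoader.has_tensor (helper name here is illustrative):

def expert_keys_drop_proj_suffix(has_tensor, base_key: str, expert_idx: int = 0) -> bool:
    """True when expert 0 stores 'up.weight' rather than 'up_proj.weight'."""
    # Mirrors the patched check: probe a single representative tensor name.
    return has_tensor(f"{base_key}.{expert_idx}.up.weight")

# e.g. expert_keys_drop_proj_suffix(loader.has_tensor, "model.layers.3.mlp.experts")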