Mirror of https://github.com/kvcache-ai/ktransformers.git
support smt and glm4

commit 613f0b7c37 (parent b66d96db97)
8 changed files with 115 additions and 28 deletions
@@ -19,8 +19,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ...configuration_utils import PretrainedConfig
-from ...modeling_rope_utils import rope_config_validation
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation


 class Glm4MoeConfig(PretrainedConfig):

@@ -154,6 +154,8 @@ class SmallthinkerConfig(PretrainedConfig):
         self.moe_num_primary_experts = moe_num_primary_experts
         self.moe_shared_primary_experts = moe_shared_primary_experts
         self.moe_ffn_hidden_size = moe_ffn_hidden_size
+        self.num_experts_per_tok = moe_num_active_primary_experts
+        self.moe_intermediate_size = moe_ffn_hidden_size
         self.moe_enable_early_router = moe_enable_early_router
         self.moe_primary_router_apply_softmax = moe_primary_router_apply_softmax
         self.moe_num_active_primary_experts = moe_num_active_primary_experts

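The two added lines expose SmallThinker's native config fields under the attribute names that Hugging Face-style MoE code (and ktransformers' injected operators) typically read. A minimal sketch of the aliasing idea, using a hypothetical stripped-down config class rather than the real SmallthinkerConfig:

# Illustration only: a hypothetical config showing why the aliases help.
class TinyMoeConfig:
    def __init__(self, moe_num_active_primary_experts=4, moe_ffn_hidden_size=768):
        # SmallThinker's native names
        self.moe_num_active_primary_experts = moe_num_active_primary_experts
        self.moe_ffn_hidden_size = moe_ffn_hidden_size
        # Aliases under the HF/Qwen-style names that generic MoE code tends to query
        self.num_experts_per_tok = moe_num_active_primary_experts
        self.moe_intermediate_size = moe_ffn_hidden_size

cfg = TinyMoeConfig()
assert cfg.num_experts_per_tok == cfg.moe_num_active_primary_experts
assert cfg.moe_intermediate_size == cfg.moe_ffn_hidden_size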
@@ -25,19 +25,20 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
-from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
-from ...masking_utils import create_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+# from transformers.utils import auto_docstring, can_return_tuple
+from transformers.utils.generic import check_model_inputs
 from .configuration_glm4_moe import Glm4MoeConfig


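This hunk and the configuration hunk above make the same change: package-relative imports (`from ...configuration_utils import ...`) become absolute `transformers.*` imports. The relative form only resolves when the file sits three packages deep inside transformers itself (e.g. under transformers/models/<name>/); once the module is vendored into another package tree, Python raises an ImportError, which the absolute form avoids. A small sketch of the pattern, assuming transformers is installed:

# Sketch: a vendored module imports transformers as an installed library instead of
# relying on its original position inside the transformers package tree.
try:
    from ...configuration_utils import PretrainedConfig  # only resolves inside transformers/models/<name>/
except ImportError:
    from transformers.configuration_utils import PretrainedConfig  # works from any package

print(PretrainedConfig.__module__)  # -> transformers.configuration_utils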
@@ -61,7 +62,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs: Unpack[TransformersKwargs],
+    # **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)

@@ -268,7 +269,7 @@ class Glm4MoeTopkRouter(nn.Module):
         return topk_indices, topk_weights


-@use_kernel_forward_from_hub("RMSNorm")
+# @use_kernel_forward_from_hub("RMSNorm")
 class Glm4MoeRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """

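With `@use_kernel_forward_from_hub("RMSNorm")` commented out, `Glm4MoeRMSNorm` falls back to its plain PyTorch forward rather than a hub-provided kernel. For reference, the standard RMSNorm computation that such classes implement looks roughly like the sketch below (a generic illustration, not the exact body of `Glm4MoeRMSNorm`):

import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    """Generic RMSNorm: scale features by the reciprocal root-mean-square."""
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)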
@@ -369,7 +370,7 @@ class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

@@ -487,7 +488,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @check_model_inputs
+    # @check_model_inputs
     @auto_docstring
     def forward(
         self,

@@ -498,7 +499,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

@@ -594,7 +595,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

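Across these modeling hunks the new-style hooks are commented out rather than deleted: the `**kwargs: Unpack[TransformersKwargs]` parameters, `@use_kernel_forward_from_hub("RMSNorm")`, and `@check_model_inputs`. The diff does not state why; keeping the vendored file importable under the transformers version pinned later in this commit is the most plausible reading, but that is an inference. If one wanted to keep such hooks optional instead of commenting them out, a guarded import is a hedged alternative (an illustration, not what this commit does):

# Illustration only: feature-detect TransformersKwargs instead of commenting it out.
try:
    from transformers.utils import TransformersKwargs  # exported by some transformers releases
    HAS_TRANSFORMERS_KWARGS = True
except ImportError:
    TransformersKwargs = None
    HAS_TRANSFORMERS_KWARGS = False

# Call sites can then branch on HAS_TRANSFORMERS_KWARGS before relying on the typed kwargs.
print(HAS_TRANSFORMERS_KWARGS)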
@@ -1315,6 +1315,80 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
         else:
             raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

+
+class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 # device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 prefill_op: str | None = "KExpertsTorch",
+                 generate_device: str = "cpu",
+                 generate_op: str | None = "KExpertsCPU",
+                 **kwargs):
+
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        if generate_op is not None:
+            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+        else:
+            self.generate_experts = None
+        if prefill_op is not None:
+            self.prefill_experts = None
+        self.gpu_mlp_type = prefill_op
+        self.cpu_mlp_type = generate_op
+        self.mode = InferenceState.UNLOAD
+
+    def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
+        # TODO support w as input
+        if not mode: mode = InferenceState.GENERATE
+        if mode == InferenceState.GENERATE:
+            # self.prefill_experts.unload()
+            self.generate_experts.load(w, warmup=warmup)
+            self.device = self.generate_experts.device
+            self.mode = mode
+        elif mode == InferenceState.PREFILL:
+            self.generate_experts.unload()
+            self.prefill_experts.load(w, warmup=warmup)
+            self.device = self.prefill_experts.device
+            self.mode = mode
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+            self.mode = mode
+            self.device = self.generate_experts.device
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+    def unload(self):
+        if self.generate_experts is not None:
+            self.generate_experts.unload()
+        if self.prefill_experts is not None:
+            self.prefill_experts.unload()
+        self.device = self.generate_experts.device
+
+    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
+        if self.mode == InferenceState.GENERATE:
+            assert self.generate_experts is not None, "generate_experts is None"
+            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        elif self.mode == InferenceState.PREFILL:
+            assert self.prefill_experts is not None, "prefill_experts is None"
+            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        else:
+            raise ValueError("load or set_inference_mode before forward")
+
+    def set_inference_mode(self, mode: InferenceState):
+        if mode == InferenceState.GENERATE:
+            self.load(mode=InferenceState.GENERATE, warmup=False)
+        elif mode == InferenceState.PREFILL:
+            self.load(mode=InferenceState.PREFILL, warmup=False)
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+
 class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
     def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):

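The new `KSmallthinkerExperts` mirrors the state machine of `KTransformersExpertsV2` above: `set_inference_mode` loads or unloads the backend for the requested phase, and `forward` then dispatches to `generate_experts` or `prefill_experts`. A hypothetical driver sketch, using only the methods visible in this hunk and assuming an already-injected instance plus the `InferenceState` enum referenced throughout the class:

# Hypothetical usage sketch; `experts` stands for an injected KSmallthinkerExperts
# instance produced by ktransformers' optimizer rules, and InferenceState is the
# enum the class above relies on.
def run_generate_step(experts, hidden, expert_ids, weights, bsz_tensor):
    experts.set_inference_mode(InferenceState.GENERATE)  # load the generate backend (CPU experts by default)
    out = experts.forward(hidden, expert_ids, weights, bsz_tensor)  # dispatches to generate_experts.forward
    experts.set_inference_mode(InferenceState.UNLOAD)  # release expert weights afterwards
    return out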
@@ -49,7 +49,7 @@
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism
+    class: ktransformers.operators.experts.KSmallthinkerExperts     # custom MoE Kernel with expert paralleism
     kwargs:
       prefill_device: "cuda"
       prefill_op: None

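The optimize-rule change points the existing `block_sparse_moe.experts` matcher at the new `KSmallthinkerExperts` class instead of `KTransformersExpertsV2`. The `name` field is a regular expression over module paths; a small standalone sketch of that matching semantics (not the actual ktransformers rule engine) is shown below:

import re

# The regex from the rule above: matches every layer's block_sparse_moe.experts module.
pattern = re.compile(r"^model\.layers\..*\.block_sparse_moe\.experts$")

# Hypothetical module paths as they would appear in model.named_modules().
names = [
    "model.layers.0.block_sparse_moe.experts",
    "model.layers.0.block_sparse_moe.gate",
    "model.layers.17.block_sparse_moe.experts",
]
print([n for n in names if pattern.match(n)])
# -> ['model.layers.0.block_sparse_moe.experts', 'model.layers.17.block_sparse_moe.experts']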
@@ -138,8 +138,12 @@ class SafeTensorLoader(ModelLoader):
         base_key = key  # e.g. "model.layers.3.mlp.experts"
         experts_count = 0

+        key_no_proj = False
+        if self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
+            key_no_proj = True
+
         # First, count how many experts we have by checking for expert 0's up_proj
-        while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight"):
+        while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
             experts_count += 1

         if experts_count == 0:

@@ -152,6 +156,12 @@ class SafeTensorLoader(ModelLoader):

         # Load all expert weights
         for expert_id in range(experts_count):
+
+            if key_no_proj:
+                up_key = f"{base_key}.{expert_id}.up.weight"
+                gate_key = f"{base_key}.{expert_id}.gate.weight"
+                down_key = f"{base_key}.{expert_id}.down.weight"
+            else:
                 up_key = f"{base_key}.{expert_id}.up_proj.weight"
                 gate_key = f"{base_key}.{expert_id}.gate_proj.weight"
                 down_key = f"{base_key}.{expert_id}.down_proj.weight"

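Both loader hunks let SafeTensorLoader handle checkpoints whose expert tensors are named `...up.weight` / `gate.weight` / `down.weight` rather than the more common `up_proj.weight` variants: expert counting accepts either scheme, and `key_no_proj` later selects which set of key names to read. A self-contained sketch of the probing idea, with a plain set standing in for the real `has_tensor` lookup and a simplified detection condition that may differ from the committed code:

# Illustration only: stand-in for SafeTensorLoader's key probing.
available = {
    "model.layers.3.mlp.experts.0.up.weight",
    "model.layers.3.mlp.experts.0.gate.weight",
    "model.layers.3.mlp.experts.0.down.weight",
    "model.layers.3.mlp.experts.1.up.weight",
    "model.layers.3.mlp.experts.1.gate.weight",
    "model.layers.3.mlp.experts.1.down.weight",
}

def has_tensor(key: str) -> bool:
    return key in available

base_key = "model.layers.3.mlp.experts"
# Simplified detection: no "_proj" keys present, but bare "up" keys are.
key_no_proj = not has_tensor(f"{base_key}.0.up_proj.weight") and has_tensor(f"{base_key}.0.up.weight")

experts_count = 0
while has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or has_tensor(f"{base_key}.{experts_count}.up.weight"):
    experts_count += 1

suffix = "up.weight" if key_no_proj else "up_proj.weight"
print(experts_count, suffix)  # -> 2 up.weight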
@@ -16,7 +16,7 @@ dynamic = ["version"]

 dependencies = [
     "torch >= 2.3.0",
-    "transformers == 4.51.3",
+    "transformers == 4.53.3",
     "fastapi >= 0.111.0",
     "uvicorn >= 0.30.1",
     "langchain >= 0.2.0",

@@ -1,5 +1,5 @@
 fire
-transformers==4.51.3
+transformers==4.53.3
 numpy
 torch>=2.3.0
 packaging