support smt and glm4

author djw 2025-07-24 08:40:58 +00:00
parent 1677e90092
commit b66d96db97
18 changed files with 3519 additions and 16 deletions


@@ -26,6 +26,8 @@ from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_loader import GGUFLoader
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding
from ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding
import torch
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
@@ -437,4 +439,93 @@ class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
    def load(self):
        self.orig_module.__init__(
            self.orig_module.config
        )

class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            config
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(
            self.orig_module.config,
            device = self.generate_device,
        )

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        # print(inv_freq_expanded.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
            freqs_cis = freqs_cis * self.attention_scaling
        return freqs_cis
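
Both new wrappers return the rotary position information as a single complex tensor built with torch.polar, rather than the cos/sin pair produced by the Qwen-style embeddings above, so the matching attention code is expected to rotate query/key states by complex multiplication. The helper below is only an illustrative sketch of that application pattern; apply_rotary_emb and the assumed (batch, seq_len, num_heads, head_dim) layout are not part of this commit.

    import torch

    # Hypothetical helper, not from this diff: apply a complex freqs_cis, as returned by
    # forward() above, to query/key states shaped (batch, seq_len, num_heads, head_dim).
    def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Pair adjacent channels into complex numbers: (..., head_dim) -> (..., head_dim // 2) complex.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # freqs_cis is assumed to be (batch, seq_len, head_dim // 2); add a heads axis so it broadcasts.
        rotated = x_complex * freqs_cis.unsqueeze(2)
        # Return to the real view and the caller's dtype/layout.
        return torch.view_as_real(rotated).flatten(-2).type_as(x)
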
class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            config
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(
            self.orig_module.config,
            device = self.generate_device,
        )

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        # print(inv_freq_expanded.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
            freqs_cis = freqs_cis * self.attention_scaling
        return freqs_cis
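
For context, these wrappers are meant to replace the stock SmallthinkerRotaryEmbedding / Glm4MoeRotaryEmbedding modules through ktransformers' injection mechanism (normally driven by the optimize-rule YAML files, which this excerpt does not show). The snippet below is only a hypothetical sketch of that substitution; the inject_rotary_embedding helper and the hard-coded "cuda" devices are assumptions, not code from this commit.

    import torch.nn as nn

    # Hypothetical sketch, not from this diff: swap an injected rotary embedding into a model.
    def inject_rotary_embedding(model: nn.Module, key: str, wrapper_cls, gguf_loader, config):
        # Resolve the dotted module path, e.g. "model.rotary_emb", into parent module and attribute name.
        parent_name, _, child_name = key.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        orig_module = getattr(parent, child_name)

        # Wrap the original module; the wrapper's __init__ above re-runs orig_module.__init__(config).
        wrapped = wrapper_cls(key, gguf_loader, config, orig_module,
                              generate_device="cuda", prefill_device="cuda")
        # load() re-initializes the original module on the configured generate device.
        wrapped.load()

        # Swap the submodule in place so downstream attention layers call the wrapper's forward().
        setattr(parent, child_name, wrapped)
        return wrapped
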