mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
support smt and glm4
This commit is contained in:
parent
1677e90092
commit
b66d96db97
18 changed files with 3519 additions and 16 deletions
|
@ -26,6 +26,8 @@ from ktransformers.operators.base_operator import BaseInjectedModule
|
|||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
from ktransformers.util.utils import InferenceState
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding
|
||||
import torch
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
|
||||
|
@ -437,4 +439,93 @@ class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
|
|||
def load(self):
|
||||
self.orig_module.__init__(
|
||||
self.orig_module.config
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(
|
||||
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||
)
|
||||
self.orig_module.__init__(
|
||||
config
|
||||
)
|
||||
self.generate_device = generate_device
|
||||
self.prefill_device = prefill_device
|
||||
|
||||
def load(self):
|
||||
self.orig_module.__init__(
|
||||
self.orig_module.config,
|
||||
device = self.generate_device,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
# print(inv_freq_expanded.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
freqs_cis = freqs_cis * self.attention_scaling
|
||||
return freqs_cis
|
||||
|
||||
class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(
|
||||
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||
)
|
||||
self.orig_module.__init__(
|
||||
config
|
||||
)
|
||||
self.generate_device = generate_device
|
||||
self.prefill_device = prefill_device
|
||||
|
||||
def load(self):
|
||||
self.orig_module.__init__(
|
||||
self.orig_module.config,
|
||||
device = self.generate_device,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
# print(inv_freq_expanded.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
freqs_cis = freqs_cis * self.attention_scaling
|
||||
return freqs_cis
|
Loading…
Add table
Add a link
Reference in a new issue