Fix RoPE; update MoE gate

This commit is contained in:
Azure 2025-02-01 18:05:45 +00:00
parent f873558a89
commit f748cd29f0
5 changed files with 54 additions and 21 deletions

View file

@ -12,6 +12,9 @@ from ktransformers.models.modeling_llama import (
LlamaLinearScalingRotaryEmbedding,
LlamaDynamicNTKScalingRotaryEmbedding,
)
from ktransformers.models.modeling_deepseek_v3 import (
DeepseekV3RotaryEmbedding
)
from ktransformers.models.modeling_deepseek import (
DeepseekV2YarnRotaryEmbedding,
DeepseekV2RotaryEmbedding,
@ -134,6 +137,31 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
self.orig_module.mscale_all_dim,
)
class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
    """Injectable rotary-embedding operator built on DeepseekV3RotaryEmbedding.

    The constructor forwards its arguments to ``BaseInjectedModule.__init__``
    and stores the two device names; ``load`` re-runs the wrapped module's
    initialiser with the generate device.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        """Set up the injected module and remember its target devices.

        Args:
            key: Forwarded unchanged to ``BaseInjectedModule.__init__``.
            gguf_loader: Forwarded unchanged to ``BaseInjectedModule.__init__``.
            config: Forwarded unchanged to ``BaseInjectedModule.__init__``.
            orig_module: The module being wrapped; forwarded to the base class.
            generate_device: Device name stored on the instance and later
                passed to the wrapped module's initialiser in ``load``.
            prefill_device: Device name stored on the instance; ``load`` does
                not use it yet (see the TODO there).
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        # generate_device is passed positionally, in the slot right after
        # orig_module, so the base class decides what that slot means.
        BaseInjectedModule.__init__(
            self,
            key,
            gguf_loader,
            config,
            orig_module,
            generate_device,
            **kwargs,
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        """Re-initialise the wrapped module on the generate device."""
        # TODO support perlayer prefill
        target_device = self.generate_device
        self.orig_module.__init__(self.config, device=target_device)
class DynamicNTKScalingRotaryEmbedding(
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding