mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 14:51:06 +00:00
fix rope; update moegate
This commit is contained in:
parent
f873558a89
commit
f748cd29f0
5 changed files with 54 additions and 21 deletions
|
@@ -12,6 +12,9 @@ from ktransformers.models.modeling_llama import (
|
|||
LlamaLinearScalingRotaryEmbedding,
|
||||
LlamaDynamicNTKScalingRotaryEmbedding,
|
||||
)
|
||||
from ktransformers.models.modeling_deepseek_v3 import (
|
||||
DeepseekV3RotaryEmbedding
|
||||
)
|
||||
from ktransformers.models.modeling_deepseek import (
|
||||
DeepseekV2YarnRotaryEmbedding,
|
||||
DeepseekV2RotaryEmbedding,
|
||||
|
@@ -134,6 +137,31 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
|
|||
self.orig_module.mscale_all_dim,
|
||||
)
|
||||
|
||||
class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
    """Injected rotary-embedding wrapper for DeepSeek-V3.

    Wraps an original ``DeepseekV3RotaryEmbedding`` through the injection
    framework so the embedding can be (re-)initialized on a chosen device
    when ``load`` is called.
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        # The injected-module base only receives the generate device;
        # prefill_device is stored separately for later use.
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        # TODO support perlayer prefill
        # Re-run the wrapped module's __init__ so its buffers are created
        # on the generate device.
        self.orig_module.__init__(self.config, device=self.generate_device)
        return
|
||||
class DynamicNTKScalingRotaryEmbedding(
|
||||
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue