Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-16 01:59:42 +00:00
fix precision bug introduced by position_ids in 0.2.0
This commit is contained in:
parent b84524622e
commit 038bc30888

10 changed files with 471 additions and 45 deletions
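Every hunk shown below applies the same change: the injected rotary-embedding wrappers stop forwarding only generate_device to BaseInjectedModule.__init__ and instead pass prefill_device and generate_device as separate arguments, which the commit message ties to the position_ids precision bug from 0.2.0. A minimal sketch of the before/after pattern follows, with a hypothetical stand-in for BaseInjectedModule (its real signature lives in ktransformers and may differ):

import torch.nn as nn

class BaseInjectedModule:  # hypothetical stand-in, not the real ktransformers class
    def __init__(self, key, gguf_loader, config, orig_module,
                 prefill_device="cuda", generate_device="cuda", **kwargs):
        self.prefill_device = prefill_device    # device used for the prefill pass
        self.generate_device = generate_device  # device used for decode/generation

class RotaryEmbedding(BaseInjectedModule):
    def __init__(self, key, gguf_loader, config, orig_module: nn.Module,
                 prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs):
        # Before this commit only generate_device was forwarded:
        #   BaseInjectedModule.__init__(self, key, gguf_loader, config,
        #                               orig_module, generate_device, **kwargs)
        # so the base class never received a separate prefill device. After the
        # fix both devices are passed through explicitly:
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module,
                                    prefill_device, generate_device, **kwargs)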
@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim, orig_module.max_position_embeddings, orig_module.base
@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
 #         **kwargs,
 #     ):
 #         BaseInjectedModule.__init__(
-#             self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+#             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
 #         )
 #         self.generate_device = generate_device
 #         self.prefill_device = prefill_device
@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
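The last hunk also changes the constructor shape itself: DynamicNTKScalingRotaryEmbedding drops its single device parameter in favour of the prefill_device/generate_device pair the other wrappers already accept. Once both devices are stored on the instance (as the RotaryEmbeddingV3 and YarnRotaryEmbeddingV3 hunks do via self.prefill_device and self.generate_device), each phase can place its tensors, including position_ids, on its own device. A small illustrative sketch of that selection pattern, not ktransformers' actual forward code:

import torch

class PhaseAwareRoPE:
    """Illustration only: per-phase device selection of the kind the injected
    modules can do once prefill_device and generate_device are both stored."""

    def __init__(self, prefill_device: str = "cpu", generate_device: str = "cpu"):
        self.prefill_device = prefill_device
        self.generate_device = generate_device

    def position_ids(self, seq_len: int, prefill: bool) -> torch.Tensor:
        # Pick the device for the current phase instead of assuming a single one.
        device = self.prefill_device if prefill else self.generate_device
        # Integer position indices built directly on the selected device.
        return torch.arange(seq_len, dtype=torch.long, device=device)

print(PhaseAwareRoPE().position_ids(8, prefill=True))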