fix precision bug imported by position_ids in 0.2.0

This commit is contained in:
Atream 2025-02-17 09:23:14 +00:00
parent b84524622e
commit 038bc30888
10 changed files with 471 additions and 45 deletions

View file

@ -93,11 +93,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device