fix precision bug imported by position_ids in 0.2.0

This commit is contained in:
Atream 2025-02-17 09:23:14 +00:00
parent b84524622e
commit 038bc30888
10 changed files with 471 additions and 45 deletions

View file

@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
nn.Module.__init__(self)
nn.Module.__setattr__(self, "orig_module", orig_module)
object.__setattr__(self, "key", key)
object.__setattr__(self, "gguf_loader", gguf_loader)
object.__setattr__(self, "config", config)
object.__setattr__(self, "device", device)
object.__setattr__(self, "prefill_device", prefill_device)
object.__setattr__(self, "generate_device", generate_device)
object.__setattr__(self, "device", generate_device)
def __getattr__(self, name: str) -> Any:
# __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,