mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
fix precision bug imported by position_ids in 0.2.0
This commit is contained in:
parent
b84524622e
commit
038bc30888
10 changed files with 471 additions and 45 deletions
|
@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
|
|||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
nn.Module.__init__(self)
|
||||
nn.Module.__setattr__(self, "orig_module", orig_module)
|
||||
object.__setattr__(self, "key", key)
|
||||
object.__setattr__(self, "gguf_loader", gguf_loader)
|
||||
object.__setattr__(self, "config", config)
|
||||
object.__setattr__(self, "device", device)
|
||||
object.__setattr__(self, "prefill_device", prefill_device)
|
||||
object.__setattr__(self, "generate_device", generate_device)
|
||||
object.__setattr__(self, "device", generate_device)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
# __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue