Mirror of https://github.com/facebookresearch/blt.git (synced 2025-01-19 00:47:44 +00:00)
allow loading of the entropy model directly
This commit is contained in:
Parent: 6ffeb66b53
Commit: 420326184a
|
@ -6,6 +6,7 @@ from enum import Enum
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from bytelatent.distributed import get_local_rank
|
from bytelatent.distributed import get_local_rank
|
||||||
|
@ -24,8 +25,12 @@ class PatchingModeEnum(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class PatcherArgs(BaseModel):
|
class PatcherArgs(BaseModel):
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
|
patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
|
||||||
patching_device: str = "cuda"
|
patching_device: str = "cuda"
|
||||||
|
entropy_model: nn.Module | None = None
|
||||||
entropy_model_checkpoint_dir: str | None = None
|
entropy_model_checkpoint_dir: str | None = None
|
||||||
realtime_patching: bool = False
|
realtime_patching: bool = False
|
||||||
threshold: float = 1.335442066192627
|
threshold: float = 1.335442066192627
|
||||||
|
@ -451,14 +456,19 @@ class Patcher:
|
||||||
self.patching_mode = patcher_args.patching_mode
|
self.patching_mode = patcher_args.patching_mode
|
||||||
self.realtime_patching = patcher_args.realtime_patching
|
self.realtime_patching = patcher_args.realtime_patching
|
||||||
if self.realtime_patching:
|
if self.realtime_patching:
|
||||||
assert (
|
if patcher_args.entropy_model is not None:
|
||||||
patcher_args.entropy_model_checkpoint_dir is not None
|
self.entropy_model = patcher_args.entropy_model
|
||||||
), "Cannot require realtime patching without an entropy model checkpoint"
|
else:
|
||||||
entropy_model = load_entropy_model(
|
assert (
|
||||||
patcher_args.entropy_model_checkpoint_dir
|
patcher_args.entropy_model_checkpoint_dir is not None
|
||||||
)
|
), "Cannot require realtime patching without an entropy model checkpoint"
|
||||||
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
|
entropy_model = load_entropy_model(
|
||||||
self.entropy_model = entropy_model
|
patcher_args.entropy_model_checkpoint_dir
|
||||||
|
)
|
||||||
|
entropy_model, _ = to_device(
|
||||||
|
entropy_model, patcher_args.patching_device
|
||||||
|
)
|
||||||
|
self.entropy_model = entropy_model
|
||||||
else:
|
else:
|
||||||
self.entropy_model = None
|
self.entropy_model = None
|
||||||
self.threshold = patcher_args.threshold
|
self.threshold = patcher_args.threshold
|
||||||
|
|
Loading…
Reference in a new issue