mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-19 00:47:44 +00:00
allow loading of the entropy model directly
This commit is contained in:
parent
6ffeb66b53
commit
420326184a
|
@ -6,6 +6,7 @@ from enum import Enum
|
|||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from bytelatent.distributed import get_local_rank
|
||||
|
@ -24,8 +25,12 @@ class PatchingModeEnum(str, Enum):
|
|||
|
||||
|
||||
class PatcherArgs(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
|
||||
patching_device: str = "cuda"
|
||||
entropy_model: nn.Module | None = None
|
||||
entropy_model_checkpoint_dir: str | None = None
|
||||
realtime_patching: bool = False
|
||||
threshold: float = 1.335442066192627
|
||||
|
@ -451,14 +456,19 @@ class Patcher:
|
|||
self.patching_mode = patcher_args.patching_mode
|
||||
self.realtime_patching = patcher_args.realtime_patching
|
||||
if self.realtime_patching:
|
||||
assert (
|
||||
patcher_args.entropy_model_checkpoint_dir is not None
|
||||
), "Cannot require realtime patching without an entropy model checkpoint"
|
||||
entropy_model = load_entropy_model(
|
||||
patcher_args.entropy_model_checkpoint_dir
|
||||
)
|
||||
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
|
||||
self.entropy_model = entropy_model
|
||||
if patcher_args.entropy_model is not None:
|
||||
self.entropy_model = patcher_args.entropy_model
|
||||
else:
|
||||
assert (
|
||||
patcher_args.entropy_model_checkpoint_dir is not None
|
||||
), "Cannot require realtime patching without an entropy model checkpoint"
|
||||
entropy_model = load_entropy_model(
|
||||
patcher_args.entropy_model_checkpoint_dir
|
||||
)
|
||||
entropy_model, _ = to_device(
|
||||
entropy_model, patcher_args.patching_device
|
||||
)
|
||||
self.entropy_model = entropy_model
|
||||
else:
|
||||
self.entropy_model = None
|
||||
self.threshold = patcher_args.threshold
|
||||
|
|
Loading…
Reference in a new issue