allow loading of the entropy model directly

This commit is contained in:
Luciferian Ink 2025-01-17 18:03:18 -06:00
parent 6ffeb66b53
commit 420326184a

View file

@@ -6,6 +6,7 @@ from enum import Enum
import torch import torch
from pydantic import BaseModel from pydantic import BaseModel
from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from bytelatent.distributed import get_local_rank from bytelatent.distributed import get_local_rank
@@ -24,8 +25,12 @@ class PatchingModeEnum(str, Enum):
class PatcherArgs(BaseModel): class PatcherArgs(BaseModel):
class Config:
arbitrary_types_allowed = True
patching_mode: PatchingModeEnum = PatchingModeEnum.entropy patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
patching_device: str = "cuda" patching_device: str = "cuda"
entropy_model: nn.Module | None = None
entropy_model_checkpoint_dir: str | None = None entropy_model_checkpoint_dir: str | None = None
realtime_patching: bool = False realtime_patching: bool = False
threshold: float = 1.335442066192627 threshold: float = 1.335442066192627
@@ -451,14 +456,19 @@ class Patcher:
self.patching_mode = patcher_args.patching_mode self.patching_mode = patcher_args.patching_mode
self.realtime_patching = patcher_args.realtime_patching self.realtime_patching = patcher_args.realtime_patching
if self.realtime_patching: if self.realtime_patching:
assert ( if patcher_args.entropy_model is not None:
patcher_args.entropy_model_checkpoint_dir is not None self.entropy_model = patcher_args.entropy_model
), "Cannot require realtime patching without an entropy model checkpoint" else:
entropy_model = load_entropy_model( assert (
patcher_args.entropy_model_checkpoint_dir patcher_args.entropy_model_checkpoint_dir is not None
) ), "Cannot require realtime patching without an entropy model checkpoint"
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) entropy_model = load_entropy_model(
self.entropy_model = entropy_model patcher_args.entropy_model_checkpoint_dir
)
entropy_model, _ = to_device(
entropy_model, patcher_args.patching_device
)
self.entropy_model = entropy_model
else: else:
self.entropy_model = None self.entropy_model = None
self.threshold = patcher_args.threshold self.threshold = patcher_args.threshold