From 420326184a41876b03bf796abc6cd691aefd73de Mon Sep 17 00:00:00 2001 From: Luciferian Ink Date: Fri, 17 Jan 2025 18:03:18 -0600 Subject: [PATCH] allow loading of the entropy model directly --- bytelatent/data/patcher.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index f8477a3..d32f168 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -6,6 +6,7 @@ from enum import Enum import torch from pydantic import BaseModel +from torch import nn from torch.nn import functional as F from bytelatent.distributed import get_local_rank @@ -24,8 +25,12 @@ class PatchingModeEnum(str, Enum): class PatcherArgs(BaseModel): + class Config: + arbitrary_types_allowed = True + patching_mode: PatchingModeEnum = PatchingModeEnum.entropy patching_device: str = "cuda" + entropy_model: nn.Module | None = None entropy_model_checkpoint_dir: str | None = None realtime_patching: bool = False threshold: float = 1.335442066192627 @@ -451,14 +456,19 @@ class Patcher: self.patching_mode = patcher_args.patching_mode self.realtime_patching = patcher_args.realtime_patching if self.realtime_patching: - assert ( - patcher_args.entropy_model_checkpoint_dir is not None - ), "Cannot require realtime patching without an entropy model checkpoint" - entropy_model = load_entropy_model( - patcher_args.entropy_model_checkpoint_dir - ) - entropy_model, _ = to_device(entropy_model, patcher_args.patching_device) - self.entropy_model = entropy_model + if patcher_args.entropy_model is not None: + self.entropy_model = patcher_args.entropy_model + else: + assert ( + patcher_args.entropy_model_checkpoint_dir is not None + ), "Cannot require realtime patching without an entropy model checkpoint" + entropy_model = load_entropy_model( + patcher_args.entropy_model_checkpoint_dir + ) + entropy_model, _ = to_device( + entropy_model, patcher_args.patching_device + ) + self.entropy_model = entropy_model else: self.entropy_model = None self.threshold = patcher_args.threshold