Remove non-serializable type from model config

Gustaf Ahdritz 2025-06-06 11:14:55 -07:00
parent 4c5e51e4de
commit 48256248b5
8 changed files with 57 additions and 35 deletions

@@ -403,27 +403,23 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
-    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        train_args.distributed.model_dtype
-    ]
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = LMTransformer(model_args)
     else:
         model_args = train_args.model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = ByteLatentTransformer(args=model_args)
     model = model.eval()
     tokenizer = train_args.data.tokenizer_args.build()
-    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
-        st_dict = torch.load(f, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
+        st_dict = torch.load(fp, weights_only=True)
     model.load_state_dict(st_dict["model"])
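For context: torch.dtype objects are not JSON-serializable, which is why the config now stores the plain string from train_args.distributed.model_dtype and resolves it to a dtype only where a dtype is actually needed. A minimal sketch of the issue follows; ModelConfig and DTYPE_MAP are illustrative stand-ins, not names from this repo.

# Minimal sketch of the serialization issue this commit fixes.
# ModelConfig and DTYPE_MAP are hypothetical, not names from this repo.
import json

import torch
from pydantic import BaseModel


class ModelConfig(BaseModel):
    init_device: str = "cpu"
    init_dtype: str = "bf16"  # keep the string form in the config


# Resolve the string to a torch.dtype only at the point of use.
DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

cfg = ModelConfig()
print(cfg.model_dump_json())  # fine: every field is a plain string
param_dtype = DTYPE_MAP[cfg.init_dtype]  # torch.bfloat16, resolved lazily

# Storing the torch.dtype itself would break JSON round-tripping:
try:
    json.dumps({"init_dtype": torch.bfloat16})
except TypeError as err:
    print(f"not serializable: {err}")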