Mirror of https://github.com/facebookresearch/blt.git
Remove non-serializable type from model config
parent 4c5e51e4de
commit 48256248b5
8 changed files with 57 additions and 35 deletions
@@ -403,27 +403,23 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa
     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
-    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        train_args.distributed.model_dtype
-    ]
-
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = LMTransformer(model_args)
     else:
         model_args = train_args.model
         model_args.init_device = "cuda"
-        model_args.init_dtype = param_dtype
+        model_args.init_dtype = train_args.distributed.model_dtype
         model = ByteLatentTransformer(args=model_args)
 
     model = model.eval()
 
     tokenizer = train_args.data.tokenizer_args.build()
 
-    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
-        st_dict = torch.load(f, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp:
+        st_dict = torch.load(fp, weights_only=True)
 
     model.load_state_dict(st_dict["model"])
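
What the change buys: `init_dtype` previously held a `torch.dtype` object, which has no JSON representation, so a config carrying it could not be round-tripped the way `TrainArgs.model_validate_json` round-trips the rest of the train args. Storing the plain string from `train_args.distributed.model_dtype` (e.g. "bf16") keeps the config serializable and defers the string-to-dtype mapping to wherever the dtype is actually consumed. A minimal sketch of that pattern, assuming a pydantic-v2 style config; `ModelArgs` and `DTYPE_MAP` below are illustrative names, not the repository's actual definitions:

    import torch
    from pydantic import BaseModel

    # Illustrative config: keeping init_dtype as a string ("fp32" | "fp16" | "bf16")
    # lets the config dump to / load from JSON without any custom encoder.
    class ModelArgs(BaseModel):
        init_device: str = "cpu"
        init_dtype: str = "fp32"

    # Resolve the serializable string to a torch.dtype only at the point of use,
    # e.g. inside the model constructor.
    DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

    args = ModelArgs(init_device="cuda", init_dtype="bf16")
    print(args.model_dump_json())        # {"init_device":"cuda","init_dtype":"bf16"}
    dtype = DTYPE_MAP[args.init_dtype]   # torch.bfloat16, resolved at model-construction time

By contrast, declaring `init_dtype: torch.dtype` would typically require `arbitrary_types_allowed` and still fail to serialize to JSON, which is the non-serializable type the commit title refers to.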