mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-10 14:27:49 +00:00
Fix demo.py
Summary: Test Plan:
This commit is contained in:
parent
d286bbe415
commit
eb1f0fc318
3 changed files with 6 additions and 6 deletions
6
demo.py
6
demo.py
|
@ -11,6 +11,8 @@ from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
|||
|
||||
|
||||
def main(prompt: str, model_name: str = "blt-1b"):
|
||||
assert model_name in ['blt-1b', 'blt-7b']
|
||||
model_name = model_name.replace('-', '_')
|
||||
distributed_args = DistributedArgs()
|
||||
distributed_args.configure_world()
|
||||
if not torch.distributed.is_initialized():
|
||||
|
@ -25,9 +27,7 @@ def main(prompt: str, model_name: str = "blt-1b"):
|
|||
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
|
||||
patcher_args.realtime_patching = True
|
||||
print("Loading entropy model and patcher")
|
||||
patcher_args.entropy_model_checkpoint_dir = os.path.join(
|
||||
checkpoint_path, "entropy_model"
|
||||
)
|
||||
patcher_args.entropy_model_checkpoint_dir = os.path.join("hf-weights", "entropy_model")
|
||||
patcher = patcher_args.build()
|
||||
prompts = [prompt]
|
||||
outputs = generate_nocache(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue