From 3dd5e9fb6211ba5cf92d7f0c3b2119d7c836698f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 22 May 2025 12:15:57 -0700 Subject: [PATCH] Fix demo.py (#116) Summary: Test Plan: --- README.md | 1 + demo.py | 6 +++--- download_blt_weights.py | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ebd3a26..d1fdd26 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ The main benefit of this method is that the build is reproducible since there is uv pip install --group pre_build --no-build-isolation uv pip install --group compile_xformers --no-build-isolation uv sync +uv run python download_blt_weights.py uv run python demo.py "A BLT has" ``` diff --git a/demo.py b/demo.py index f79c896..e32dd74 100644 --- a/demo.py +++ b/demo.py @@ -11,6 +11,8 @@ from bytelatent.tokenizers.blt_tokenizer import BltTokenizer def main(prompt: str, model_name: str = "blt-1b"): + assert model_name in ['blt-1b', 'blt-7b'] + model_name = model_name.replace('-', '_') distributed_args = DistributedArgs() distributed_args.configure_world() if not torch.distributed.is_initialized(): @@ -25,9 +27,7 @@ def main(prompt: str, model_name: str = "blt-1b"): patcher_args = train_cfg.data.patcher_args.model_copy(deep=True) patcher_args.realtime_patching = True print("Loading entropy model and patcher") - patcher_args.entropy_model_checkpoint_dir = os.path.join( - checkpoint_path, "entropy_model" - ) + patcher_args.entropy_model_checkpoint_dir = os.path.join("hf-weights", "entropy_model") patcher = patcher_args.build() prompts = [prompt] outputs = generate_nocache( diff --git a/download_blt_weights.py b/download_blt_weights.py index 0035494..d63306e 100644 --- a/download_blt_weights.py +++ b/download_blt_weights.py @@ -4,11 +4,10 @@ import typer from huggingface_hub import snapshot_download -def main(models: list[str] = ["blt-1b", "blt-7b"]): +def main(): if not os.path.exists("hf-weights"): os.makedirs("hf-weights") - for model in models: - snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}") + snapshot_download(f"facebook/blt", local_dir=f"hf-weights") if __name__ == "__main__":