mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-05 03:59:04 +00:00
parent
d286bbe415
commit
3dd5e9fb62
3 changed files with 6 additions and 6 deletions
|
@ -72,6 +72,7 @@ The main benefit of this method is that the build is reproducible since there is
|
||||||
uv pip install --group pre_build --no-build-isolation
|
uv pip install --group pre_build --no-build-isolation
|
||||||
uv pip install --group compile_xformers --no-build-isolation
|
uv pip install --group compile_xformers --no-build-isolation
|
||||||
uv sync
|
uv sync
|
||||||
|
uv run python download_blt_weights.py
|
||||||
uv run python demo.py "A BLT has"
|
uv run python demo.py "A BLT has"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
6
demo.py
6
demo.py
|
@ -11,6 +11,8 @@ from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
||||||
|
|
||||||
|
|
||||||
def main(prompt: str, model_name: str = "blt-1b"):
|
def main(prompt: str, model_name: str = "blt-1b"):
|
||||||
|
assert model_name in ['blt-1b', 'blt-7b']
|
||||||
|
model_name = model_name.replace('-', '_')
|
||||||
distributed_args = DistributedArgs()
|
distributed_args = DistributedArgs()
|
||||||
distributed_args.configure_world()
|
distributed_args.configure_world()
|
||||||
if not torch.distributed.is_initialized():
|
if not torch.distributed.is_initialized():
|
||||||
|
@ -25,9 +27,7 @@ def main(prompt: str, model_name: str = "blt-1b"):
|
||||||
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
|
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
|
||||||
patcher_args.realtime_patching = True
|
patcher_args.realtime_patching = True
|
||||||
print("Loading entropy model and patcher")
|
print("Loading entropy model and patcher")
|
||||||
patcher_args.entropy_model_checkpoint_dir = os.path.join(
|
patcher_args.entropy_model_checkpoint_dir = os.path.join("hf-weights", "entropy_model")
|
||||||
checkpoint_path, "entropy_model"
|
|
||||||
)
|
|
||||||
patcher = patcher_args.build()
|
patcher = patcher_args.build()
|
||||||
prompts = [prompt]
|
prompts = [prompt]
|
||||||
outputs = generate_nocache(
|
outputs = generate_nocache(
|
||||||
|
|
|
@ -4,11 +4,10 @@ import typer
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
def main(models: list[str] = ["blt-1b", "blt-7b"]):
|
def main():
|
||||||
if not os.path.exists("hf-weights"):
|
if not os.path.exists("hf-weights"):
|
||||||
os.makedirs("hf-weights")
|
os.makedirs("hf-weights")
|
||||||
for model in models:
|
snapshot_download(f"facebook/blt", local_dir=f"hf-weights")
|
||||||
snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Add table
Reference in a new issue