mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-01 10:09:06 +00:00
parent
3dd5e9fb62
commit
4ae7a62594
6 changed files with 74 additions and 12 deletions
|
@ -93,7 +93,7 @@ From here there are two options: (1) load weights in our train script and (2) lo
|
|||
In your terminal:
|
||||
|
||||
```bash
|
||||
python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b hub --prompt "My test prompt"
|
||||
python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b --prompt "My test prompt" hub
|
||||
```
|
||||
|
||||
In your own code:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import copy
|
||||
from typing import Type, TypeVar, Any
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
import omegaconf
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
@ -69,6 +69,7 @@ def parse_args_with_default(
|
|||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def get_pydantic_default_args(args_cls: Type[T]) -> dict[str, Any]:
|
||||
defaults = {}
|
||||
for field, info in args_cls.model_fields.items():
|
||||
|
@ -76,8 +77,11 @@ def get_pydantic_default_args(args_cls: Type[T]) -> dict[str, Any]:
|
|||
defaults[field] = info.default
|
||||
return defaults
|
||||
|
||||
|
||||
def parse_args_to_pydantic_model(
|
||||
args_cls: Type[T], cli_args: DictConfig | None = None, instantiate_default_cls: bool = True
|
||||
args_cls: Type[T],
|
||||
cli_args: DictConfig | None = None,
|
||||
instantiate_default_cls: bool = True,
|
||||
) -> T:
|
||||
if instantiate_default_cls:
|
||||
default_cfg = OmegaConf.create(args_cls().model_dump())
|
||||
|
|
|
@ -26,8 +26,10 @@ class StoolArgs(BaseModel):
|
|||
dirs_exists_ok: bool = (
|
||||
False # Whether to copy new code and config and run regardless of whether the dir exists
|
||||
)
|
||||
override: bool = False # Whether to delete dump dir and restart, requires confirmation
|
||||
force_override: bool = False # Does not require interaction
|
||||
override: bool = (
|
||||
False # Whether to delete dump dir and restart, requires confirmation
|
||||
)
|
||||
force_override: bool = False # Does not require interaction
|
||||
nodes: int = -1 # The number of nodes to run the job on.
|
||||
ngpu: int = 8 # The number of GPUs required per node.
|
||||
ncpu: int = 16 # The number of CPUs allocated per GPU.
|
||||
|
@ -43,7 +45,6 @@ class StoolArgs(BaseModel):
|
|||
dry_run: bool = False
|
||||
|
||||
|
||||
|
||||
def copy_dir(input_dir: str, output_dir: str) -> None:
|
||||
print(f"Copying : {input_dir}\n" f"to : {output_dir} ...")
|
||||
assert os.path.isdir(input_dir), f"{input_dir} is not a directory"
|
||||
|
@ -130,7 +131,9 @@ def launch_job(args: StoolArgs):
|
|||
job_name = args.name or args.model_conf["name"]
|
||||
dump_dir = os.path.join(args.dump_dir, job_name) or args.model_conf["dump_dir"]
|
||||
print("Creating directories...")
|
||||
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override or args.force_override)
|
||||
os.makedirs(
|
||||
dump_dir, exist_ok=args.dirs_exists_ok or args.override or args.force_override
|
||||
)
|
||||
if args.override or args.force_override:
|
||||
if args.force_override:
|
||||
shutil.rmtree(dump_dir)
|
||||
|
@ -161,10 +164,10 @@ def launch_job(args: StoolArgs):
|
|||
else ""
|
||||
)
|
||||
env = jinja2.Environment(
|
||||
loader=jinja2.PackageLoader('bytelatent'),
|
||||
loader=jinja2.PackageLoader("bytelatent"),
|
||||
autoescape=jinja2.select_autoescape(),
|
||||
)
|
||||
template = env.get_template('stool_template.sh.jinja')
|
||||
template = env.get_template("stool_template.sh.jinja")
|
||||
sbatch_jinja = template.render(
|
||||
name=job_name,
|
||||
script=args.script,
|
||||
|
|
8
demo.py
8
demo.py
|
@ -11,8 +11,8 @@ from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
|||
|
||||
|
||||
def main(prompt: str, model_name: str = "blt-1b"):
|
||||
assert model_name in ['blt-1b', 'blt-7b']
|
||||
model_name = model_name.replace('-', '_')
|
||||
assert model_name in ["blt-1b", "blt-7b"]
|
||||
model_name = model_name.replace("-", "_")
|
||||
distributed_args = DistributedArgs()
|
||||
distributed_args.configure_world()
|
||||
if not torch.distributed.is_initialized():
|
||||
|
@ -27,7 +27,9 @@ def main(prompt: str, model_name: str = "blt-1b"):
|
|||
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
|
||||
patcher_args.realtime_patching = True
|
||||
print("Loading entropy model and patcher")
|
||||
patcher_args.entropy_model_checkpoint_dir = os.path.join("hf-weights", "entropy_model")
|
||||
patcher_args.entropy_model_checkpoint_dir = os.path.join(
|
||||
"hf-weights", "entropy_model"
|
||||
)
|
||||
patcher = patcher_args.build()
|
||||
prompts = [prompt]
|
||||
outputs = generate_nocache(
|
||||
|
|
|
@ -46,7 +46,9 @@ pre_build = [
|
|||
]
|
||||
compile_xformers = ['xformers']
|
||||
dev = [
|
||||
"black==24.8.0",
|
||||
"ipython>=9.2.0",
|
||||
"isort>=6.0.1",
|
||||
"pudb>=2025.1",
|
||||
]
|
||||
|
||||
|
|
51
uv.lock
generated
51
uv.lock
generated
|
@ -166,6 +166,26 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "24.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "mypy-extensions" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pathspec" },
|
||||
{ name = "platformdirs" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/04/b0/46fb0d4e00372f4a86a6f8efa3cb193c9f64863615e39010b1477e010578/black-24.8.0.tar.gz", hash = "sha256:2500945420b6784c38b9ee885af039f5e7471ef284ab03fa35ecdde4688cd83f", size = 644810, upload-time = "2024-08-02T17:43:18.405Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/a8/05fb14195cfef32b7c8d4585a44b7499c2a4b205e1662c427b941ed87054/black-24.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7c046c1d1eeb7aea9335da62472481d3bbf3fd986e093cffd35f4385c94ae368", size = 1646132, upload-time = "2024-08-02T17:49:52.843Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/41/77/8d9ce42673e5cb9988f6df73c1c5c1d4e9e788053cccd7f5fb14ef100982/black-24.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:649f6d84ccbae73ab767e206772cc2d7a393a001070a4c814a546afd0d423aed", size = 1448665, upload-time = "2024-08-02T17:47:54.479Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/94/eff1ddad2ce1d3cc26c162b3693043c6b6b575f538f602f26fe846dfdc75/black-24.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b59b250fdba5f9a9cd9d0ece6e6d993d91ce877d121d161e4698af3eb9c1018", size = 1762458, upload-time = "2024-08-02T17:46:19.384Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/ea/18b8d86a9ca19a6942e4e16759b2fa5fc02bbc0eb33c1b866fcd387640ab/black-24.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:6e55d30d44bed36593c3163b9bc63bf58b3b30e4611e4d88a0c3c239930ed5b2", size = 1436109, upload-time = "2024-08-02T17:46:52.97Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/1e/83fa8a787180e1632c3d831f7e58994d7aaf23a0961320d21e84f922f919/black-24.8.0-py3-none-any.whl", hash = "sha256:972085c618ee94f402da1af548a4f218c754ea7e5dc70acb168bfaca4c2542ed", size = 206504, upload-time = "2024-08-02T17:43:15.747Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blt"
|
||||
version = "0.1.0"
|
||||
|
@ -200,7 +220,9 @@ compile-xformers = [
|
|||
{ name = "xformers" },
|
||||
]
|
||||
dev = [
|
||||
{ name = "black" },
|
||||
{ name = "ipython" },
|
||||
{ name = "isort" },
|
||||
{ name = "pudb" },
|
||||
]
|
||||
pre-build = [
|
||||
|
@ -238,7 +260,9 @@ requires-dist = [
|
|||
[package.metadata.requires-dev]
|
||||
compile-xformers = [{ name = "xformers", git = "https://github.com/facebookresearch/xformers.git?rev=de742ec3d64bd83b1184cc043e541f15d270c148" }]
|
||||
dev = [
|
||||
{ name = "black", specifier = "==24.8.0" },
|
||||
{ name = "ipython", specifier = ">=9.2.0" },
|
||||
{ name = "isort", specifier = ">=6.0.1" },
|
||||
{ name = "pudb", specifier = ">=2025.1" },
|
||||
]
|
||||
pre-build = [
|
||||
|
@ -610,6 +634,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isort"
|
||||
version = "6.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b8/21/1e2a441f74a653a144224d7d21afe8f4169e6c7c20bb13aec3a2dc3815e0/isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450", size = 821955, upload-time = "2025-02-26T21:13:16.955Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/11/114d0a5f4dabbdcedc1125dee0888514c3c3b16d3e9facad87ed96fad97c/isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615", size = 94186, upload-time = "2025-02-26T21:13:14.911Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jedi"
|
||||
version = "0.19.2"
|
||||
|
@ -901,6 +934,15 @@ wheels = [
|
|||
{ url = "https://download.pytorch.org/whl/nightly/multiprocess-0.70.16-py39-none-any.whl" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mypy-extensions"
|
||||
version = "1.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "narwhals"
|
||||
version = "1.37.1"
|
||||
|
@ -1197,6 +1239,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650, upload-time = "2024-04-05T09:43:53.299Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pathspec"
|
||||
version = "0.12.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pathvalidate"
|
||||
version = "3.2.3"
|
||||
|
|
Loading…
Add table
Reference in a new issue