diff --git a/README.md b/README.md index d1fdd26..2e621a1 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ From here there are two options: (1) load weights in our train script and (2) lo In your terminal: ```bash -python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b hub --prompt "My test prompt" +python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b --prompt "My test prompt" hub ``` In your own code: diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py index b4bda35..5630c44 100644 --- a/bytelatent/config_parser.py +++ b/bytelatent/config_parser.py @@ -1,5 +1,5 @@ import copy -from typing import Type, TypeVar, Any +from typing import Any, Type, TypeVar import omegaconf from omegaconf import DictConfig, OmegaConf @@ -69,6 +69,7 @@ def parse_args_with_default( T = TypeVar("T", bound=BaseModel) + def get_pydantic_default_args(args_cls: Type[T]) -> dict[str, Any]: defaults = {} for field, info in args_cls.model_fields.items(): @@ -76,8 +77,11 @@ def get_pydantic_default_args(args_cls: Type[T]) -> dict[str, Any]: defaults[field] = info.default return defaults + def parse_args_to_pydantic_model( - args_cls: Type[T], cli_args: DictConfig | None = None, instantiate_default_cls: bool = True + args_cls: Type[T], + cli_args: DictConfig | None = None, + instantiate_default_cls: bool = True, ) -> T: if instantiate_default_cls: default_cfg = OmegaConf.create(args_cls().model_dump()) diff --git a/bytelatent/stool.py b/bytelatent/stool.py index 45fb701..618d264 100644 --- a/bytelatent/stool.py +++ b/bytelatent/stool.py @@ -26,8 +26,10 @@ class StoolArgs(BaseModel): dirs_exists_ok: bool = ( False # Wether to copy new code and config and run regardless that dir exists ) - override: bool = False # Whether to delete dump dir and restart, requires confirmation - force_override: bool = False # Does not require interaction + override: bool = ( + False # Whether to delete dump dir and restart, requires confirmation + ) + force_override: bool = False # Does not require interaction nodes: int = -1 # The number of nodes to run the job on. ngpu: int = 8 # The number of GPUs required per node. ncpu: int = 16 # The number of CPUs allocated per GPU. @@ -43,7 +45,6 @@ class StoolArgs(BaseModel): dry_run: bool = False - def copy_dir(input_dir: str, output_dir: str) -> None: print(f"Copying : {input_dir}\n" f"to : {output_dir} ...") assert os.path.isdir(input_dir), f"{input_dir} is not a directory" @@ -130,7 +131,9 @@ def launch_job(args: StoolArgs): job_name = args.name or args.model_conf["name"] dump_dir = os.path.join(args.dump_dir, job_name) or args.model_conf["dump_dir"] print("Creating directories...") - os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override or args.force_override) + os.makedirs( + dump_dir, exist_ok=args.dirs_exists_ok or args.override or args.force_override + ) if args.override or args.force_override: if args.force_override: shutil.rmtree(dump_dir) @@ -161,10 +164,10 @@ def launch_job(args: StoolArgs): else "" ) env = jinja2.Environment( - loader=jinja2.PackageLoader('bytelatent'), + loader=jinja2.PackageLoader("bytelatent"), autoescape=jinja2.select_autoescape(), ) - template = env.get_template('stool_template.sh.jinja') + template = env.get_template("stool_template.sh.jinja") sbatch_jinja = template.render( name=job_name, script=args.script, diff --git a/demo.py b/demo.py index e32dd74..2b0e31c 100644 --- a/demo.py +++ b/demo.py @@ -11,8 +11,8 @@ from bytelatent.tokenizers.blt_tokenizer import BltTokenizer def main(prompt: str, model_name: str = "blt-1b"): - assert model_name in ['blt-1b', 'blt-7b'] - model_name = model_name.replace('-', '_') + assert model_name in ["blt-1b", "blt-7b"] + model_name = model_name.replace("-", "_") distributed_args = DistributedArgs() distributed_args.configure_world() if not torch.distributed.is_initialized(): @@ -27,7 +27,9 @@ def main(prompt: str, model_name: str = "blt-1b"): patcher_args = train_cfg.data.patcher_args.model_copy(deep=True) patcher_args.realtime_patching = True print("Loading entropy model and patcher") - patcher_args.entropy_model_checkpoint_dir = os.path.join("hf-weights", "entropy_model") + patcher_args.entropy_model_checkpoint_dir = os.path.join( + "hf-weights", "entropy_model" + ) patcher = patcher_args.build() prompts = [prompt] outputs = generate_nocache( diff --git a/pyproject.toml b/pyproject.toml index 5fadd7f..1a14fdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,9 @@ pre_build = [ ] compile_xformers = ['xformers'] dev = [ + "black==24.8.0", "ipython>=9.2.0", + "isort>=6.0.1", "pudb>=2025.1", ] diff --git a/uv.lock b/uv.lock index e4ab7a7..e546815 100644 --- a/uv.lock +++ b/uv.lock @@ -166,6 +166,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] +[[package]] +name = "black" +version = "24.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "mypy-extensions" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/b0/46fb0d4e00372f4a86a6f8efa3cb193c9f64863615e39010b1477e010578/black-24.8.0.tar.gz", hash = "sha256:2500945420b6784c38b9ee885af039f5e7471ef284ab03fa35ecdde4688cd83f", size = 644810, upload-time = "2024-08-02T17:43:18.405Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/a8/05fb14195cfef32b7c8d4585a44b7499c2a4b205e1662c427b941ed87054/black-24.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7c046c1d1eeb7aea9335da62472481d3bbf3fd986e093cffd35f4385c94ae368", size = 1646132, upload-time = "2024-08-02T17:49:52.843Z" }, + { url = "https://files.pythonhosted.org/packages/41/77/8d9ce42673e5cb9988f6df73c1c5c1d4e9e788053cccd7f5fb14ef100982/black-24.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:649f6d84ccbae73ab767e206772cc2d7a393a001070a4c814a546afd0d423aed", size = 1448665, upload-time = "2024-08-02T17:47:54.479Z" }, + { url = "https://files.pythonhosted.org/packages/cc/94/eff1ddad2ce1d3cc26c162b3693043c6b6b575f538f602f26fe846dfdc75/black-24.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b59b250fdba5f9a9cd9d0ece6e6d993d91ce877d121d161e4698af3eb9c1018", size = 1762458, upload-time = "2024-08-02T17:46:19.384Z" }, + { url = "https://files.pythonhosted.org/packages/28/ea/18b8d86a9ca19a6942e4e16759b2fa5fc02bbc0eb33c1b866fcd387640ab/black-24.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:6e55d30d44bed36593c3163b9bc63bf58b3b30e4611e4d88a0c3c239930ed5b2", size = 1436109, upload-time = "2024-08-02T17:46:52.97Z" }, + { url = "https://files.pythonhosted.org/packages/27/1e/83fa8a787180e1632c3d831f7e58994d7aaf23a0961320d21e84f922f919/black-24.8.0-py3-none-any.whl", hash = "sha256:972085c618ee94f402da1af548a4f218c754ea7e5dc70acb168bfaca4c2542ed", size = 206504, upload-time = "2024-08-02T17:43:15.747Z" }, +] + [[package]] name = "blt" version = "0.1.0" @@ -200,7 +220,9 @@ compile-xformers = [ { name = "xformers" }, ] dev = [ + { name = "black" }, { name = "ipython" }, + { name = "isort" }, { name = "pudb" }, ] pre-build = [ @@ -238,7 +260,9 @@ requires-dist = [ [package.metadata.requires-dev] compile-xformers = [{ name = "xformers", git = "https://github.com/facebookresearch/xformers.git?rev=de742ec3d64bd83b1184cc043e541f15d270c148" }] dev = [ + { name = "black", specifier = "==24.8.0" }, { name = "ipython", specifier = ">=9.2.0" }, + { name = "isort", specifier = ">=6.0.1" }, { name = "pudb", specifier = ">=2025.1" }, ] pre-build = [ @@ -610,6 +634,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, ] +[[package]] +name = "isort" +version = "6.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b8/21/1e2a441f74a653a144224d7d21afe8f4169e6c7c20bb13aec3a2dc3815e0/isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450", size = 821955, upload-time = "2025-02-26T21:13:16.955Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/11/114d0a5f4dabbdcedc1125dee0888514c3c3b16d3e9facad87ed96fad97c/isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615", size = 94186, upload-time = "2025-02-26T21:13:14.911Z" }, +] + [[package]] name = "jedi" version = "0.19.2" @@ -901,6 +934,15 @@ wheels = [ { url = "https://download.pytorch.org/whl/nightly/multiprocess-0.70.16-py39-none-any.whl" }, ] +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + [[package]] name = "narwhals" version = "1.37.1" @@ -1197,6 +1239,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650, upload-time = "2024-04-05T09:43:53.299Z" }, ] +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, +] + [[package]] name = "pathvalidate" version = "3.2.3"