Summary:
With >1 GPU but only 1 node, the all-reduces in the gradient-norm computation fail when inputs are not bf16. This PR uses a modified copy of torch's grad-norm utility to avoid these failures.
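For context, a minimal sketch of the workaround's idea (illustrative only, not the PR's actual code): upcast each gradient norm to a single dtype (fp32 here) before the collective, so mixed bf16/fp32 gradients cannot produce a dtype mismatch in the all-reduce. `clip_grad_norm_fp32` and its arguments are hypothetical names.

```python
import torch
import torch.distributed as dist

def clip_grad_norm_fp32(parameters, max_norm: float) -> torch.Tensor:
    # Hypothetical sketch, not the PR's code.
    grads = [p.grad for p in parameters if p.grad is not None]
    # Compute per-gradient norms in fp32 so the stacked tensor has one dtype.
    norms = torch.stack(
        [torch.linalg.vector_norm(g.detach(), 2).float() for g in grads]
    )
    total_norm = torch.linalg.vector_norm(norms, 2)
    if dist.is_available() and dist.is_initialized():
        # Assuming sharded (FSDP-style) gradients, each rank holds a partial
        # norm; reduce the squared norms in fp32, then take the square root.
        total_norm = total_norm.pow(2)
        dist.all_reduce(total_norm)
        total_norm = total_norm.sqrt()
    # Standard clipping step, applied in each gradient's own dtype.
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef.to(g.dtype))
    return total_norm
```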
Test Plan:
- Run unit tests:
- Run single-GPU training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
- Run single-node, multi-GPU training: `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
Summary:
- Refactor the local model configs to be separate and clearer (see the sketch after this list)
- Add attention arguments and correct which attention implementation is used in the local models
- Prepare for adding an entropy training script
- Fix failing unit tests
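A rough sketch of what "separate and clearer" configs could look like; the class and field names below are illustrative assumptions, not the PR's actual definitions:

```python
# Hypothetical sketch: give the local encoder/decoder their own config objects
# and make the attention choice an explicit argument rather than something
# inherited implicitly from the global model. Names and fields are illustrative.
from dataclasses import dataclass

@dataclass
class LocalModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # Explicit attention argument, so train scripts and tests can pin down
    # exactly which attention implementation the local models use.
    attn_impl: str = "sdpa"  # e.g. "sdpa" or "flex_attention"

@dataclass
class LocalEncoderArgs(LocalModelArgs):
    cross_attn_encoder: bool = False

@dataclass
class LocalDecoderArgs(LocalModelArgs):
    cross_attn_decoder: bool = False
```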
Test Plan: