diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py
index 36726c9..e8a12b3 100644
--- a/bytelatent/entropy_model.py
+++ b/bytelatent/entropy_model.py
@@ -16,7 +16,6 @@ def load_entropy_model(
     with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
         reloaded = json.loads(fr.read())
 
-    # torch.set_default_dtype(dtype)
     model_params = reloaded["entropy_model"]
     logger.warning(
         "Update checkpoint to load attn and sliding window args from checkpoint"
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index 199c88b..ede27e3 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -280,30 +280,6 @@ def cross_attn_mask(
         def patch_mask(b, h, q_idx, kv_idx):
             return cross_mask_copy[b, q_idx, kv_idx]
 
-        # print(f"cross_mask: {cross_mask.shape}")
-        # print(f"bs: {bs}, q_len: {q_len}, kv_len: {kv_len}")
-        # print(cross_mask[0, 0, 0])
-        # for i in range(bs):
-        #     for j in range(q_len):
-        #         for k in range(kv_len):
-        #             y = cross_mask[i, j, k]
-
-        # import pickle
-
-        # with open("cross_mask.pkl", "wb") as f:
-        #     pickle.dump(cross_mask, f)
-
-        # global GLOBAL
-        # s = f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}"
-        # if s not in GLOBAL:
-        #     GLOBAL.add(s)
-        #     print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}")
-        # else:
-        #     print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len} (skipped)")
-
-        # if q_len >= 51 and kv_len >= 96:
-        #     breakpoint()
-
         block_mask = create_block_mask(
             patch_mask,
             B=bs,
@@ -313,8 +289,6 @@ def cross_attn_mask(
             _compile=True,
         )
 
-        # print(f"block_mask_shape: {block_mask.shape}")
-
        return block_mask
    else:
        return torch.where(
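
For context, the call this patch leaves in place is PyTorch FlexAttention's create_block_mask, which compiles a boolean predicate into a sparse BlockMask. Below is a minimal, self-contained sketch of that pattern, assuming a precomputed boolean cross-attention mask; the shapes, the random mask contents, and the device argument are illustrative and not taken from this patch (requires PyTorch >= 2.5):

import torch
from torch.nn.attention.flex_attention import create_block_mask

bs, q_len, kv_len = 2, 128, 256
# Hypothetical precomputed boolean mask: True where a query position
# is allowed to attend to a key/value position.
cross_mask = torch.rand(bs, q_len, kv_len) > 0.5

def patch_mask(b, h, q_idx, kv_idx):
    # mask_mod signature expected by FlexAttention: index tensors in,
    # boolean tensor out. The head index h is unused, mirroring H=None
    # in the patched call.
    return cross_mask[b, q_idx, kv_idx]

# device defaults to "cuda" in create_block_mask; the patched code
# additionally passes _compile=True.
block_mask = create_block_mask(
    patch_mask, B=bs, H=None, Q_LEN=q_len, KV_LEN=kv_len, device="cpu"
)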