Remove debugging lines

This commit is contained in:
Gustaf Ahdritz 2025-06-10 20:41:16 -07:00
parent f501584a13
commit 321b4b8df8
2 changed files with 0 additions and 27 deletions

View file

@@ -16,7 +16,6 @@ def load_entropy_model(
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
reloaded = json.loads(fr.read())
# torch.set_default_dtype(dtype)
model_params = reloaded["entropy_model"]
logger.warning(
"Update checkpoint to load attn and sliding window args from checkpoint"

View file

@@ -280,30 +280,6 @@ def cross_attn_mask(
def patch_mask(b, h, q_idx, kv_idx):
return cross_mask_copy[b, q_idx, kv_idx]
# print(f"cross_mask: {cross_mask.shape}")
# print(f"bs: {bs}, q_len: {q_len}, kv_len: {kv_len}")
# print(cross_mask[0, 0, 0])
# for i in range(bs):
# for j in range(q_len):
# for k in range(kv_len):
# y = cross_mask[i, j, k]
# import pickle
# with open("cross_mask.pkl", "wb") as f:
# pickle.dump(cross_mask, f)
# global GLOBAL
# s = f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}"
# if s not in GLOBAL:
# GLOBAL.add(s)
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}")
# else:
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len} (skipped)")
# if q_len >= 51 and kv_len >= 96:
# breakpoint()
block_mask = create_block_mask(
patch_mask,
B=bs,
@@ -313,8 +289,6 @@ def cross_attn_mask(
_compile=True,
)
# print(f"block_mask_shape: {block_mask.shape}")
return block_mask
else:
return torch.where(