mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-04 19:49:06 +00:00
Remove debugging lines
This commit is contained in:
parent
f501584a13
commit
321b4b8df8
2 changed files with 0 additions and 27 deletions
|
@ -16,7 +16,6 @@ def load_entropy_model(
|
|||
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
|
||||
reloaded = json.loads(fr.read())
|
||||
|
||||
# torch.set_default_dtype(dtype)
|
||||
model_params = reloaded["entropy_model"]
|
||||
logger.warning(
|
||||
"Update checkpoint to load attn and sliding window args from checkpoint"
|
||||
|
|
|
@ -280,30 +280,6 @@ def cross_attn_mask(
|
|||
def patch_mask(b, h, q_idx, kv_idx):
|
||||
return cross_mask_copy[b, q_idx, kv_idx]
|
||||
|
||||
# print(f"cross_mask: {cross_mask.shape}")
|
||||
# print(f"bs: {bs}, q_len: {q_len}, kv_len: {kv_len}")
|
||||
# print(cross_mask[0, 0, 0])
|
||||
# for i in range(bs):
|
||||
# for j in range(q_len):
|
||||
# for k in range(kv_len):
|
||||
# y = cross_mask[i, j, k]
|
||||
|
||||
# import pickle
|
||||
|
||||
# with open("cross_mask.pkl", "wb") as f:
|
||||
# pickle.dump(cross_mask, f)
|
||||
|
||||
# global GLOBAL
|
||||
# s = f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}"
|
||||
# if s not in GLOBAL:
|
||||
# GLOBAL.add(s)
|
||||
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}")
|
||||
# else:
|
||||
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len} (skipped)")
|
||||
|
||||
# if q_len >= 51 and kv_len >= 96:
|
||||
# breakpoint()
|
||||
|
||||
block_mask = create_block_mask(
|
||||
patch_mask,
|
||||
B=bs,
|
||||
|
@ -313,8 +289,6 @@ def cross_attn_mask(
|
|||
_compile=True,
|
||||
)
|
||||
|
||||
# print(f"block_mask_shape: {block_mask.shape}")
|
||||
|
||||
return block_mask
|
||||
else:
|
||||
return torch.where(
|
||||
|
|
Loading…
Add table
Reference in a new issue