mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-07 04:59:07 +00:00
Remove debugging lines
This commit is contained in:
parent
f501584a13
commit
321b4b8df8
2 changed files with 0 additions and 27 deletions
|
@ -16,7 +16,6 @@ def load_entropy_model(
|
||||||
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
|
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
|
||||||
reloaded = json.loads(fr.read())
|
reloaded = json.loads(fr.read())
|
||||||
|
|
||||||
# torch.set_default_dtype(dtype)
|
|
||||||
model_params = reloaded["entropy_model"]
|
model_params = reloaded["entropy_model"]
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Update checkpoint to load attn and sliding window args from checkpoint"
|
"Update checkpoint to load attn and sliding window args from checkpoint"
|
||||||
|
|
|
@ -280,30 +280,6 @@ def cross_attn_mask(
|
||||||
def patch_mask(b, h, q_idx, kv_idx):
|
def patch_mask(b, h, q_idx, kv_idx):
|
||||||
return cross_mask_copy[b, q_idx, kv_idx]
|
return cross_mask_copy[b, q_idx, kv_idx]
|
||||||
|
|
||||||
# print(f"cross_mask: {cross_mask.shape}")
|
|
||||||
# print(f"bs: {bs}, q_len: {q_len}, kv_len: {kv_len}")
|
|
||||||
# print(cross_mask[0, 0, 0])
|
|
||||||
# for i in range(bs):
|
|
||||||
# for j in range(q_len):
|
|
||||||
# for k in range(kv_len):
|
|
||||||
# y = cross_mask[i, j, k]
|
|
||||||
|
|
||||||
# import pickle
|
|
||||||
|
|
||||||
# with open("cross_mask.pkl", "wb") as f:
|
|
||||||
# pickle.dump(cross_mask, f)
|
|
||||||
|
|
||||||
# global GLOBAL
|
|
||||||
# s = f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}"
|
|
||||||
# if s not in GLOBAL:
|
|
||||||
# GLOBAL.add(s)
|
|
||||||
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len}")
|
|
||||||
# else:
|
|
||||||
# print(f"bs_{bs}_q_len_{q_len}_kv_len_{kv_len} (skipped)")
|
|
||||||
|
|
||||||
# if q_len >= 51 and kv_len >= 96:
|
|
||||||
# breakpoint()
|
|
||||||
|
|
||||||
block_mask = create_block_mask(
|
block_mask = create_block_mask(
|
||||||
patch_mask,
|
patch_mask,
|
||||||
B=bs,
|
B=bs,
|
||||||
|
@ -313,8 +289,6 @@ def cross_attn_mask(
|
||||||
_compile=True,
|
_compile=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# print(f"block_mask_shape: {block_mask.shape}")
|
|
||||||
|
|
||||||
return block_mask
|
return block_mask
|
||||||
else:
|
else:
|
||||||
return torch.where(
|
return torch.where(
|
||||||
|
|
Loading…
Add table
Reference in a new issue