Some fixes for entropy model predictions ()

Co-authored-by: Srini Iyer <sviyer@meta.com>
Author: Srinivasan Iyer, 2025-03-13 10:28:42 -07:00 (committed by GitHub)
parent 083656ce55
commit fc946a1918
2 changed files with 5 additions and 5 deletions


@@ -91,7 +91,7 @@ def calculate_entropies(
             split = split.reshape(-1, max_length)
             if device is not None:
                 split = split.to(device)
-            assert torch.all(split >= 0) and torch.all(split < 260)
+            # assert torch.all(split >= 0) and torch.all(split < 260)
             pred = entropy_model(split)
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
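
The disabled line was a range check on the byte-token ids fed to the entropy model. A minimal sketch of what it verified, assuming a 260-entry byte vocabulary (256 raw byte values plus a few special tokens; that breakdown is an assumption, not taken from this diff):

import torch

vocab_size = 260                              # assumed entropy-model vocab
split = torch.randint(0, vocab_size, (4, 8))  # toy batch of byte-token ids

# Equivalent of the commented-out assertion: every id must be a valid index
# into the model's embedding table, i.e. fall in [0, vocab_size).
in_range = torch.all(split >= 0) and torch.all(split < vocab_size)
print(bool(in_range))                         # True for this toy batch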
@@ -103,7 +103,7 @@ def calculate_entropies(
         concat_entropies = torch.cat(entropies, dim=0)
         concat_entropies = concat_entropies.reshape(tokens.shape)
         concat_preds = torch.cat(preds, dim=0)
-        concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
+        concat_preds = concat_preds.reshape(tokens.shape[0], -1)
     return concat_entropies, concat_preds
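
For orientation, a rough shape walk-through of the new reshape at the end of calculate_entropies, assuming tokens is a [batch, seq_len] tensor and the model emits one logit row per byte (the vocab size of 260 is only an illustrative choice):

import torch

batch, seq_len, vocab = 2, 6, 260                # toy sizes; vocab is assumed
tokens = torch.zeros(batch, seq_len, dtype=torch.long)

# preds from all splits, concatenated along dim 0: one logit row per token
concat_preds = torch.randn(batch * seq_len, vocab)

# new behavior: keep the batch dimension, fold seq_len and vocab together
concat_preds = concat_preds.reshape(tokens.shape[0], -1)
print(concat_preds.shape)                        # torch.Size([2, 1560]), i.e. 6 * 260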


@@ -15,7 +15,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
         reloaded = json.loads(fr.read())
     torch.set_default_dtype(torch.bfloat16)
-    model_params = reloaded["model"]
+    model_params = reloaded["entropy_model"]
     logger.warning(
         "Update checkpoint to load attn and sliding window args from checkpoint"
     )
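
This change, together with the max_seqlen rename below, assumes params.json nests the entropy model's hyper-parameters under an "entropy_model" key. A hypothetical layout consistent with the new code (the numeric values are placeholders, not taken from the repo):

import json

params_json = """
{
  "entropy_model": {
    "dim": 768,
    "n_layers": 14,
    "n_heads": 12,
    "max_seqlen": 8192,
    "ffn_dim_multiplier": 1.0,
    "vocab_size": 260
  }
}
"""
reloaded = json.loads(params_json)
model_params = reloaded["entropy_model"]  # new key; the old code read reloaded["model"]
assert "max_seqlen" in model_params       # new field name; the old code read "max_length"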
@@ -24,7 +24,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
             dim=model_params["dim"],
             n_layers=model_params["n_layers"],
             n_heads=model_params["n_heads"],
-            max_seqlen=model_params["max_length"],
+            max_seqlen=model_params["max_seqlen"],
             ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
             vocab_size=model_params["vocab_size"],
             attn_bias_type="local_block_causal",
@@ -34,7 +34,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
     )
     entropy_model.load_state_dict(
-        torch.load(state_dict_path, map_location=device), strict=False
+        torch.load(state_dict_path, map_location=device)["model"], strict=False
     )
     entropy_model.to(device)
     entropy_model = entropy_model.eval()
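
The final hunk assumes the weights file wraps the model parameters under a "model" key, as a training checkpoint that also carries other state typically does. A minimal save/reload sketch of that layout (the file name and the stand-in module are made up for illustration):

import torch
import torch.nn as nn

entropy_model = nn.Linear(4, 4)  # stand-in for the real entropy transformer

# Assumed checkpoint layout: state dict nested under "model", possibly next to
# optimizer or training state that load_entropy_model does not need.
torch.save({"model": entropy_model.state_dict()}, "entropy_ckpt.pth")

state = torch.load("entropy_ckpt.pth", map_location="cpu")["model"]
entropy_model.load_state_dict(state, strict=False)
entropy_model = entropy_model.eval()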