Some fixes for entropy model predictions

This commit is contained in:
Srini Iyer 2025-03-13 05:12:56 +00:00
parent c110f6be2a
commit f50157f3e2
2 changed files with 5 additions and 5 deletions

View file

@ -91,7 +91,7 @@ def calculate_entropies(
split = split.reshape(-1, max_length)
if device is not None:
split = split.to(device)
assert torch.all(split >= 0) and torch.all(split < 260)
# assert torch.all(split >= 0) and torch.all(split < 260)
pred = entropy_model(split)
pred = pred.reshape(-1, pred.shape[-1])[
: split.numel() - pad_size, :
@ -103,7 +103,7 @@ def calculate_entropies(
concat_entropies = torch.cat(entropies, dim=0)
concat_entropies = concat_entropies.reshape(tokens.shape)
concat_preds = torch.cat(preds, dim=0)
concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
concat_preds = concat_preds.reshape(tokens.shape[0], -1)
return concat_entropies, concat_preds

View file

@ -15,7 +15,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
reloaded = json.loads(fr.read())
torch.set_default_dtype(torch.bfloat16)
model_params = reloaded["model"]
model_params = reloaded["entropy_model"]
logger.warning(
"Update checkpoint to load attn and sliding window args from checkpoint"
)
@ -24,7 +24,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
dim=model_params["dim"],
n_layers=model_params["n_layers"],
n_heads=model_params["n_heads"],
max_seqlen=model_params["max_length"],
max_seqlen=model_params["max_seqlen"],
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
vocab_size=model_params["vocab_size"],
attn_bias_type="local_block_causal",
@ -34,7 +34,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
)
entropy_model.load_state_dict(
torch.load(state_dict_path, map_location=device), strict=False
torch.load(state_dict_path, map_location=device)["model"], strict=False
)
entropy_model.to(device)
entropy_model = entropy_model.eval()