diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index 44ff5e9..328b716 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -91,7 +91,7 @@ def calculate_entropies( split = split.reshape(-1, max_length) if device is not None: split = split.to(device) - assert torch.all(split >= 0) and torch.all(split < 260) + # assert torch.all(split >= 0) and torch.all(split < 260) pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : @@ -103,7 +103,7 @@ def calculate_entropies( concat_entropies = torch.cat(entropies, dim=0) concat_entropies = concat_entropies.reshape(tokens.shape) concat_preds = torch.cat(preds, dim=0) - concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1) + concat_preds = concat_preds.reshape(tokens.shape[0], -1) return concat_entropies, concat_preds diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py index 30754ee..0e11a60 100644 --- a/bytelatent/entropy_model.py +++ b/bytelatent/entropy_model.py @@ -15,7 +15,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp reloaded = json.loads(fr.read()) torch.set_default_dtype(torch.bfloat16) - model_params = reloaded["model"] + model_params = reloaded["entropy_model"] logger.warning( "Update checkpoint to load attn and sliding window args from checkpoint" ) @@ -24,7 +24,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp dim=model_params["dim"], n_layers=model_params["n_layers"], n_heads=model_params["n_heads"], - max_seqlen=model_params["max_length"], + max_seqlen=model_params["max_seqlen"], ffn_dim_multiplier=model_params["ffn_dim_multiplier"], vocab_size=model_params["vocab_size"], attn_bias_type="local_block_causal", @@ -34,7 +34,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp ) entropy_model.load_state_dict( - torch.load(state_dict_path, map_location=device), strict=False + torch.load(state_dict_path, map_location=device)["model"], strict=False ) entropy_model.to(device) entropy_model = entropy_model.eval()