auto rope scaling changes

This commit is contained in:
Concedo 2024-04-19 23:08:55 +08:00
parent 5b6ac9cc6e
commit b01820dec7
2 changed files with 10 additions and 5 deletions

View file

@ -803,8 +803,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
float factor = file_format_meta.n_ctx_train/2048;
effectivenctx = effectivenctx/factor;
}
rope_freq_base = (effectivenctx <= 2048 ? 10000.0f : (effectivenctx <= 3072 ? 26000.0f : (effectivenctx <= 4096 ? 32000.0f : (effectivenctx <= 6144 ? 54000.0f :
(effectivenctx <= 8192 ? 82684.0f : (effectivenctx <= 12288 ? 140000.0f : (effectivenctx <= 16384 ? 200000.0f : (effectivenctx <= 24576 ? 320000.0f : 440000.0f))))))));
float magic_multiplier = 8.0f;
float base_multiplier = effectivenctx*magic_multiplier;
float base_raw = 10000.0f;
rope_freq_base = (effectivenctx <= 2048 ? base_raw : base_multiplier);
}
@ -1049,7 +1051,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
else
{
//if the model modifes rope in any way, use the model values. Otherwise, use our automatic ones
if(llamamodel->hparams.rope_freq_base_train!=10000.0f ||
//special exception for llama, which uses auto scale
if((llamamodel->hparams.rope_freq_base_train!=10000.0f && llamamodel->hparams.rope_freq_base_train!=500000.0f) ||
llamamodel->hparams.rope_freq_scale_train!=1.0f ||
llamamodel->hparams.rope_scaling_type_train==2)
{
@ -1057,6 +1060,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
else
{
float multiplier_rope_base = llamamodel->hparams.rope_freq_base_train/10000.0f;
rope_freq_base *= multiplier_rope_base;
llama_ctx_params.rope_freq_base = rope_freq_base;
llama_ctx_params.rope_freq_scale = rope_freq_scale;
printf("Automatic RoPE Scaling: Using (scale:%.3f, base:%.1f).\n", rope_freq_scale, rope_freq_base);