add overridenativecontext flag, stop nagging me

Concedo 2025-08-14 22:54:45 +08:00
parent 7ac0102ed3
commit 5a921a40f9
3 changed files with 69 additions and 17 deletions

@@ -2021,7 +2021,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     float rope_freq_scale = 1.0f;
     float rope_freq_base = 10000.0f;
     bool overwriteRope = false;
-    if(inputs.rope_freq_scale>0.0f)
+    if(inputs.rope_freq_scale>0.0f && inputs.overridenativecontext==0)
     {
         rope_freq_scale = inputs.rope_freq_scale;
         rope_freq_base = inputs.rope_freq_base;
@@ -2030,8 +2030,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     }
     else
     {
+        const int maxctxtrain = (inputs.overridenativecontext>0?inputs.overridenativecontext:2048);
         //Set freq base for all, including non GGUF. If we are using GGUF, this will be overwritten with more accurate values later.
-        rope_freq_base = CalcGradientAIRopeFreqBase(10000.0f,2048,kcpp_data->n_ctx, GGUFArch::ARCH_DEFAULT);
+        rope_freq_base = CalcGradientAIRopeFreqBase(10000.0f,maxctxtrain,kcpp_data->n_ctx, GGUFArch::ARCH_DEFAULT);
         if(file_format==FileFormat::GGUF_GENERIC)
         {
             printf("Using automatic RoPE scaling for GGUF. If the model has custom RoPE settings, they'll be used directly instead!\n");
@@ -2369,7 +2370,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         {
             //if the model modifies rope in any way, or uses yarn, use the model values. Otherwise, use our automatic ones
             //special exception for llama, which uses auto scale
-            if((llamamodel->hparams.rope_freq_base_train!=10000.0f && llamamodel->hparams.rope_freq_base_train!=500000.0f) ||
+            if(inputs.overridenativecontext > 0)
+            {
+                printf("Automatic RoPE Scaling: Adjust based on override train context of %d.\n",inputs.overridenativecontext);
+                rope_freq_base = CalcGradientAIRopeFreqBase(llamamodel->hparams.rope_freq_base_train, inputs.overridenativecontext, kcpp_data->n_ctx, file_format_meta.model_architecture);
+                llama_ctx_params.rope_freq_base = rope_freq_base;
+                llama_ctx_params.rope_freq_scale = rope_freq_scale;
+                printf("Automatic RoPE Scaling: Using (scale:%.3f, base:%.1f).\n", rope_freq_scale, rope_freq_base);
+            }
+            else if((llamamodel->hparams.rope_freq_base_train!=10000.0f && llamamodel->hparams.rope_freq_base_train!=500000.0f) ||
                 llamamodel->hparams.rope_freq_scale_train!=1.0f ||
                 llamamodel->hparams.rope_scaling_type_train==2)
             {
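
Taken together, the patch makes the previously hard-coded 2048-token assumed training context configurable: a non-zero overridenativecontext both suppresses the manual rope_freq_scale path (first hunk) and feeds the override into CalcGradientAIRopeFreqBase as the train context (second and third hunks). Below is a minimal standalone sketch of that control flow; the frequency-base formula is a placeholder (the real CalcGradientAIRopeFreqBase is more involved and is not reproduced here), and the input values are hypothetical.

#include <cstdio>

// Placeholder for CalcGradientAIRopeFreqBase: a stand-in growth rule, NOT the
// real koboldcpp formula. It only illustrates which inputs the patch changes.
static float calc_rope_freq_base_sketch(float base_train, int n_ctx_train, int n_ctx)
{
    if (n_ctx <= n_ctx_train) return base_train;     // within the native window: keep the trained base
    float ratio = (float)n_ctx / (float)n_ctx_train; // how far the run context extends past training
    return base_train * ratio;                       // hypothetical scaling rule
}

int main()
{
    int overridenativecontext = 8192; // hypothetical user override of the native train context
    int n_ctx = 16384;                // hypothetical requested runtime context

    // Mirrors the patched else-branch: the override replaces the old
    // hard-coded 2048 assumed train context.
    const int maxctxtrain = (overridenativecontext > 0 ? overridenativecontext : 2048);
    float rope_freq_base = calc_rope_freq_base_sketch(10000.0f, maxctxtrain, n_ctx);
    printf("train ctx %d, run ctx %d -> rope_freq_base %.1f\n", maxctxtrain, n_ctx, rope_freq_base);
    return 0;
}

Note that with the override set, any user-supplied rope_freq_scale is ignored entirely, so the override acts as the single source of truth for how far the runtime context stretches past the model's native window.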