Reworking rope WIP

Concedo 2023-07-19 00:54:41 +08:00
commit 374fffb9c6
24 changed files with 600 additions and 256 deletions


@@ -348,12 +348,32 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     = gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx = gpt2_ctx_v3.hparams.n_ctx
     = mpt_ctx_v3.hparams.n_ctx = params.n_ctx;
 
-    //handle linear rope
-    if(inputs.linear_rope)
+    //determine rope scaling params
+    float rope_freq_scale = 1.0f;
+    float rope_freq_base = 10000.0f;
+    if(inputs.rope_freq_scale>0.0f)
     {
-        printf("Using Linear RoPE scaling instead of NTK-Aware scaling.\n");
+        rope_freq_scale = inputs.rope_freq_scale;
+        rope_freq_base = inputs.rope_freq_base;
+        printf("Using Custom RoPE scaling (scale:%.3f, base:%.1f).\n",rope_freq_scale,rope_freq_base);
     }
-    set_ntk_rope_scale_mode(!inputs.linear_rope);
+    else
+    {
+        rope_freq_scale = 1.0f;
+        if (params.n_ctx <= 2048) //normie mode
+        {
+            rope_freq_base = 10000.0f;
+        }
+        else
+        {
+            //approximate NTK aware ctx
+            rope_freq_base = (params.n_ctx <= 4096 ? 40880.0f : 82684.0f);
+        }
+        printf("Using automatic RoPE scaling (scale:%.3f, base:%.1f)\n",rope_freq_scale,rope_freq_base);
+    }
+
+    gptj_ctx_v3.hparams.rope_freq_scale = neox_ctx_v3.hparams.rope_freq_scale = rope_freq_scale;
+    gptj_ctx_v3.hparams.rope_freq_base = neox_ctx_v3.hparams.rope_freq_base = rope_freq_base;
 
     //handle custom token bans
     banned_tokens.clear();
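
The two magic numbers in the automatic branch appear to come from the NTK-aware RoPE rule base' = 10000 * alpha^(d/(d-2)) with LLaMA's head dimension d = 128: alpha = 4 gives ~40889 and alpha = 8 gives ~82685, in line with the 40880.0f and 82684.0f constants above. A minimal standalone sketch of the selection logic follows; the helper name pick_rope_params is hypothetical, not part of this patch.

#include <cstdio>
#include <initializer_list>
#include <utility>

//hypothetical helper mirroring the branch added in gpttype_load_model above
static std::pair<float, float> pick_rope_params(float user_scale, float user_base, int n_ctx)
{
    if (user_scale > 0.0f)
    {
        return {user_scale, user_base}; //custom mode: caller overrides both knobs
    }
    float base = 10000.0f; //stock RoPE base for ctx <= 2048
    if (n_ctx > 2048)
    {
        base = (n_ctx <= 4096 ? 40880.0f : 82684.0f); //approximate NTK-aware bases
    }
    return {1.0f, base}; //automatic mode never rescales positions
}

int main()
{
    for (int ctx : {2048, 4096, 8192})
    {
        auto [scale, base] = pick_rope_params(0.0f, 0.0f, ctx);
        printf("n_ctx=%d -> scale:%.3f base:%.1f\n", ctx, scale, base);
    }
    return 0;
}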
@@ -444,6 +464,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     llama_ctx_params.use_mlock = inputs.use_mlock;
     llama_ctx_params.n_gpu_layers = inputs.gpulayers;
     llama_ctx_params.main_gpu = cu_parseinfo_maindevice;
+    llama_ctx_params.rope_freq_base = rope_freq_base;
+    llama_ctx_params.rope_freq_scale = rope_freq_scale;
     llama_ctx_v3 = llama_init_from_file(modelname.c_str(), llama_ctx_params);
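
For context, freq_scale and freq_base act on the standard RoPE rotation angle: channel pair i at position pos rotates by (pos * freq_scale) * freq_base^(-2i/d). Raising the base slows the low-frequency channels the most while leaving high frequencies nearly intact, which is why the automatic path keeps freq_scale at 1.0 and only bumps the base. A short illustrative sketch, using the textbook RoPE formula rather than code from this commit:

#include <cmath>
#include <cstdio>

int main()
{
    const int head_dim = 128; //LLaMA head dimension
    const float freq_base = 82684.0f; //the automatic pick for n_ctx > 4096
    const float freq_scale = 1.0f; //linear scaling would set this below 1.0
    const int pos = 6000; //a position beyond the original 2048 training window

    for (int i = 0; i < head_dim / 2; i += 16)
    {
        float theta = powf(freq_base, -2.0f * (float)i / (float)head_dim);
        printf("pair %3d: angle = %.6f\n", i, (float)pos * freq_scale * theta);
    }
    return 0;
}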