mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Reworking rope WIP
This commit is contained in:
commit
374fffb9c6
24 changed files with 600 additions and 256 deletions
|
@ -348,12 +348,32 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
= gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx = gpt2_ctx_v3.hparams.n_ctx
|
||||
= mpt_ctx_v3.hparams.n_ctx = params.n_ctx;
|
||||
|
||||
//handle linear rope
|
||||
if(inputs.linear_rope)
|
||||
//determine rope scaling params
|
||||
float rope_freq_scale = 1.0f;
|
||||
float rope_freq_base = 10000.0f;
|
||||
if(inputs.rope_freq_scale>0.0f)
|
||||
{
|
||||
printf("Using Linear RoPE scaling instead of NTK-Aware scaling.\n");
|
||||
rope_freq_scale = inputs.rope_freq_scale;
|
||||
rope_freq_base = inputs.rope_freq_base;
|
||||
printf("Using Custom RoPE scaling (scale:%.3f, base:%.1f).\n",rope_freq_scale,rope_freq_base);
|
||||
}
|
||||
set_ntk_rope_scale_mode(!inputs.linear_rope);
|
||||
else
|
||||
{
|
||||
rope_freq_scale = 1.0f;
|
||||
if (params.n_ctx <= 2048) //normie mode
|
||||
{
|
||||
rope_freq_base = 10000.0f;
|
||||
}
|
||||
else
|
||||
{
|
||||
//approximate NTK aware ctx
|
||||
rope_freq_base = (params.n_ctx <= 4096 ? 40880.0f : 82684.0f);
|
||||
}
|
||||
|
||||
printf("Using automatic RoPE scaling (scale:%.3f, base:%.1f)\n",rope_freq_scale,rope_freq_base);
|
||||
}
|
||||
gptj_ctx_v3.hparams.rope_freq_scale = neox_ctx_v3.hparams.rope_freq_scale = rope_freq_scale;
|
||||
gptj_ctx_v3.hparams.rope_freq_base = neox_ctx_v3.hparams.rope_freq_base = rope_freq_base;
|
||||
|
||||
//handle custom token bans
|
||||
banned_tokens.clear();
|
||||
|
@ -444,6 +464,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
llama_ctx_params.use_mlock = inputs.use_mlock;
|
||||
llama_ctx_params.n_gpu_layers = inputs.gpulayers;
|
||||
llama_ctx_params.main_gpu = cu_parseinfo_maindevice;
|
||||
llama_ctx_params.rope_freq_base = rope_freq_base;
|
||||
llama_ctx_params.rope_freq_scale = rope_freq_scale;
|
||||
|
||||
llama_ctx_v3 = llama_init_from_file(modelname.c_str(), llama_ctx_params);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue