GradientAI Auto ROPE Base calculation (#910)

* GradientAI Auto ROPE Base calculation

https://gradient.ai/blog/scaling-rotational-embeddings-for-long-context-language-models
has a formula that better fits the ideal rope scaling. 

Tested with Lllama3, checked calculation is correct for llama2. Retains logic for not scaling rope if under trained CTX.

* add in solar scaling logic

Solar based models require the context values to be multiplied by 8. This is (i'm guessing) because the positions as based on a 32k context, but sliding window of 4k.

* Update model_adapter.h

adding in tensor count to identify solar models based on tensor count of 435.

* Update model_adapter.cpp

add in n_tensor count for solar identification

* refactor and cleanup GradientAI rope scaling

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
askmyteapot 2024-06-13 20:12:00 +10:00 committed by GitHub
parent 49e4c3fd7b
commit 1e72b65c38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 39 additions and 22 deletions

View file

@ -271,6 +271,9 @@ void print_tok_vec(std::vector<float> &embd)
if(modelarch!="" && fileformatmeta!=nullptr)
{
int n_tensors = gguf_get_n_tensors(ctx);
float freq_base_train = 0;
std::string fkey = modelarch+".context_length";
int keyidx = gguf_find_key(ctx, fkey.c_str());
if (keyidx != -1) {
@ -281,8 +284,14 @@ void print_tok_vec(std::vector<float> &embd)
if (keyidx != -1) {
fileformatmeta->n_expert_count = gguf_get_val_u32(ctx, keyidx);
}
fkey = modelarch+".rope.freq_base";
keyidx = gguf_find_key(ctx, fkey.c_str());
if (keyidx != -1) {
freq_base_train = gguf_get_val_f32(ctx, keyidx);
}
int filever = gguf_get_version(ctx);
fileformatmeta->fileversion = filever;
fileformatmeta->model_architecture = GGUFArch::ARCH_DEFAULT;
if(modelarch=="phi2")
@ -297,7 +306,12 @@ void print_tok_vec(std::vector<float> &embd)
{
fileformatmeta->model_architecture = GGUFArch::ARCH_MAMBA;
}
else if(modelarch=="llama" && freq_base_train==10000.0f && n_tensors==435)
{
fileformatmeta->model_architecture = GGUFArch::ARCH_SOLAR;
}
}
gguf_free(ctx);
}
@ -531,4 +545,4 @@ void print_tok_vec(std::vector<float> &embd)
//remove all tokens between start part and start of LCS in new prompt, thus avoiding shift
//if LCS not found or mismatched, regenerate. chop new prompt and repeat from step B
}
}
}