llava support is now fully functioning

Concedo 2024-03-11 15:55:32 +08:00
parent d943c739a8
commit 484d90c330
2 changed files with 193 additions and 47 deletions


@@ -33,6 +33,11 @@
 #include "examples/llava/clip.h"
 #include "examples/llava/llava.h"
+//const
+const int extra_context_handle_fragmentation = 80;
+const int LLAVA_TOKEN_IDENTIFIER_A = -998; //alternate between both, changing when image changes
+const int LLAVA_TOKEN_IDENTIFIER_B = -999;
 //shared
 std::string executable_path = "";
 std::string lora_filename = "";
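Note: both identifier constants are deliberately negative, so they can never collide with real vocabulary tokens, which llama.cpp numbers from 0 upward. A minimal sketch of the idea; is_llava_marker is a hypothetical helper, not part of this commit:

// Sketch: negative sentinel IDs are safe in-band markers because real
// vocab token IDs are always non-negative.
constexpr int LLAVA_TOKEN_IDENTIFIER_A = -998;
constexpr int LLAVA_TOKEN_IDENTIFIER_B = -999;

inline bool is_llava_marker(int token)
{
    return token == LLAVA_TOKEN_IDENTIFIER_A || token == LLAVA_TOKEN_IDENTIFIER_B;
}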
@@ -80,6 +85,8 @@ static llama_context * llama_ctx_v4;
 static clip_ctx * clp_ctx = nullptr; //for llava
 static clip_image_u8 * clp_img_data = nullptr; //most recent image
 static std::vector<llava_image> llava_images;
+static std::string llava_composite_image_signature = ""; //for identifying when the llava images change, we need to invalidate the cache
+static int current_llava_identifier = LLAVA_TOKEN_IDENTIFIER_A;
 static gpt_params * kcpp_params = nullptr;
 static int max_context_limit_at_load = 0;
@@ -105,8 +112,6 @@ static std::string concat_output_reader_copy_poll = ""; //for streaming
 static std::string concat_output_reader_copy_res = ""; //for gen response
 static std::vector<logit_bias> logit_biases;
-const int extra_context_handle_fragmentation = 80;
 inline bool IsNanCheck(float f)
 {
     const unsigned int u = *(unsigned int*)&f;
@@ -1080,7 +1085,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
     }
-    if(mmproj_filename != "")
+    if(mmproj_filename != "" && file_format==FileFormat::GGUF_GENERIC)
     {
         printf("\nAttempting to apply Multimodal Projector: %s\n", mmproj_filename.c_str());
         clp_ctx = clip_model_load(mmproj_filename.c_str(), /*verbosity=*/ 1);
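Note: the projector load is now gated on file_format==FileFormat::GGUF_GENERIC, presumably because the clip/llava code vendored from llama.cpp only understands GGUF models; a supplied mmproj is simply skipped for the legacy formats.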
@@ -1593,6 +1598,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
     }
     llava_images.clear();
+    std::string new_llava_composite = "";
     for(int x=0;x<images_max;++x)
     {
         std::string item = inputs.images[x];
@@ -1601,6 +1607,17 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
             llava_image lv;
             lv.b64data = item;
             llava_images.push_back(lv);
+            new_llava_composite += item;
         }
     }
+    if(llava_composite_image_signature!=new_llava_composite)
+    {
+        //images have changed. swap identifiers to force reprocessing
+        current_llava_identifier = (current_llava_identifier==LLAVA_TOKEN_IDENTIFIER_A?LLAVA_TOKEN_IDENTIFIER_B:LLAVA_TOKEN_IDENTIFIER_A);
+        llava_composite_image_signature = new_llava_composite;
+        if(debugmode==1)
+        {
+            printf("\nLLAVA images changed, existing cache invalidated");
+        }
+    }
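This swap is what invalidates the prompt cache: cached context is only reused up to the longest prefix of identical token IDs, so replacing every image placeholder with the other sentinel guarantees the match ends at the first image token and the images get reprocessed. A hedged sketch of the mechanism; shared_prefix is an illustrative helper, the real rewind logic is the npast computation later in gpttype_generate:

#include <cstddef>
#include <vector>

// Sketch: context reuse stops at the first mismatching token, so a swapped
// sentinel forces everything from the first image token onward to be redone.
static size_t shared_prefix(const std::vector<int> &cached, const std::vector<int> &fresh)
{
    size_t n = 0;
    while (n < cached.size() && n < fresh.size() && cached[n] == fresh[n])
    {
        ++n;
    }
    return n; //everything at index >= n must be re-evaluated
}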
@@ -1667,6 +1684,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     // tokenize the prompt
     std::vector<int> embd_inp;
     std::vector<int> embd_inp_mem; //for storing added memory
+    std::vector<int> llava_mem; //for storing dummy tokens that will be consumed by llava
+    int32_t nctx = kcpp_params->n_ctx;
     TokenizeString(kcpp_params->prompt, embd_inp, file_format);
     if(clp_ctx!=nullptr && clp_img_data!=nullptr)
@@ -1686,7 +1707,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 if (!llava_image_embed_make_with_clip_img(clp_ctx, kcpp_params->n_threads, clp_img_data, &llava_images[i].clp_img_embd, &llava_images[i].clp_image_tokens)) {
                     printf("\nError: Clip image %d failed to create embd!",i);
                 }
-                printf("\nLLAVA Clip Embed %i used Tokens: %d",i,llava_images[i].clp_image_tokens);
+                if(debugmode==1)
+                {
+                    printf("\nLLAVA Clip Embed %i used Tokens: %d",i,llava_images[i].clp_image_tokens);
+                }
+                if(llava_images[i].clp_image_tokens>0 && llava_images[i].clp_image_tokens < nctx)
+                {
+                    for(int n=0;n<llava_images[i].clp_image_tokens;++n)
+                    {
+                        llava_mem.push_back(current_llava_identifier);
+                    }
+                }else
+                {
+                    printf("\nWarning: LLAVA Image excluded - Context size too low or not enough clip tokens!\n");
+                }
             }
         }
     }
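Each image that passes the size check contributes exactly clp_image_tokens placeholder tokens to llava_mem, so the ordinary context-budget arithmetic further down sees the image's true footprint before any embeddings are evaluated. The rule restated as a sketch; reserve_image_slots is illustrative, not part of the commit:

#include <vector>

// Sketch: one placeholder per CLIP embedding, skipped entirely if the image
// alone could not fit in the context window.
static void reserve_image_slots(std::vector<int> &mem, int image_tokens, int sentinel, int n_ctx)
{
    if (image_tokens > 0 && image_tokens < n_ctx)
    {
        mem.insert(mem.end(), image_tokens, sentinel); //same effect as the push_back loop above
    }
}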
@@ -1697,8 +1731,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     }
     //truncate to front of the prompt if its too long
-    int32_t nctx = kcpp_params->n_ctx;
     if (embd_inp.size() + kcpp_params->n_predict > nctx)
     {
         //get bos token
@@ -1713,8 +1745,43 @@
         }
     }
+    if(llava_mem.size()>0) //stick the llava mem before the added mem
+    {
+        if(llava_mem.size() + kcpp_params->n_predict + 4 > nctx)
+        {
+            printf("\nWarning: Too many LLaVA tokens, max context exceeded! They will be ignored!\n");
+        }
+        else
+        {
+            std::vector<int> bos;
+            TokenizeString("", bos, file_format);
+            if(embd_inp_mem.size()>0) //remove existing bos if exists
+            {
+                if (bos.size()>0 && !embd_inp_mem.empty() && bos[0]==embd_inp_mem[0]) {
+                    embd_inp_mem.erase(embd_inp_mem.begin());
+                }
+            }
+            //append llava dummy tokens
+            embd_inp_mem.insert(embd_inp_mem.begin(), llava_mem.begin(), llava_mem.end());
+            if (bos.size() > 0 && embd_inp_mem.size() > 0)
+            {
+                embd_inp_mem.insert(embd_inp_mem.begin(), bos[0]); //insert bos at front
+            }
+            //shorten memory if needed
+            if (embd_inp_mem.size() + kcpp_params->n_predict + 4 > nctx)
+            {
+                int limit = nctx - (kcpp_params->n_predict + 4);
+                if (embd_inp_mem.size() > limit) {
+                    embd_inp_mem.resize(limit);
+                }
+            }
+        }
+    }
     //added special memory, overwrite if needed
-    if(addedmemory!="")
+    if(embd_inp_mem.size()>0)
     {
         //remove bos token from prompt, it'll be taken from memory
         std::vector<int> bos;
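After the block above, embd_inp_mem is laid out as [BOS][llava placeholders][added memory], trimmed so that its length plus n_predict plus a four-token headroom still fits in nctx. The trimming rule restated as a sketch; fit_memory is an illustrative name:

#include <vector>

// Sketch: the memory block must leave room for n_predict generated tokens
// plus a small headroom; excess is dropped from the tail, so the BOS token
// and the llava placeholders at the front survive.
static void fit_memory(std::vector<int> &mem, int n_predict, int n_ctx)
{
    const int headroom = 4;
    const int limit = n_ctx - (n_predict + headroom);
    if (limit > 0 && (int)mem.size() > limit)
    {
        mem.resize(limit);
    }
}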
@@ -1750,7 +1817,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         //stick memory to front of prompt
         embd_inp.insert(embd_inp.begin(), embd_inp_mem.begin(), embd_inp_mem.end());
     }
     //determine how much npast we have to rewind from the current state
@@ -2148,15 +2214,69 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
             // some user input remains from prompt or interaction, forward it to processing
             while ((int)embd_inp.size() > input_consumed)
             {
-                embd.push_back(embd_inp[input_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[input_consumed]);
-                current_context_tokens.push_back(embd_inp[input_consumed]);
-                ++input_consumed;
-                if ((int)embd.size() >= kcpp_params->n_batch)
+                int currtoken = embd_inp[input_consumed];
+                if(currtoken==LLAVA_TOKEN_IDENTIFIER_A || currtoken==LLAVA_TOKEN_IDENTIFIER_B) //special llava token hit
                 {
-                    break;
+                    //if partial batch, dispatch existing first
+                    if(embd.size()>0)
+                    {
+                        break;
+                    }
+                    else
+                    {
+                        //batch is empty, do image processing
+                        int llavatokenscounted = 0;
+                        int llavatokensevaled = 0;
+                        while(input_consumed < embd_inp.size() && (embd_inp[input_consumed]==LLAVA_TOKEN_IDENTIFIER_A || embd_inp[input_consumed]==LLAVA_TOKEN_IDENTIFIER_B))
+                        {
+                            last_n_tokens.erase(last_n_tokens.begin());
+                            last_n_tokens.push_back(currtoken);
+                            current_context_tokens.push_back(currtoken);
+                            ++input_consumed;
+                            ++llavatokenscounted;
+                        }
+                        for(int i=0;i<llava_images.size();++i)
+                        {
+                            if(allow_regular_prints)
+                            {
+                                printf("\rProcessing LLaVa Embedding %d (%d tokens)",(i+1), llava_images[i].clp_image_tokens);
+                            }
+                            bool err = kcpp_eval_image(llama_ctx_v4,llava_images[i].clp_img_embd,llava_images[i].clp_image_tokens,kcpp_params->n_batch,&n_past);
+                            llavatokensevaled += llava_images[i].clp_image_tokens;
+                            if(!err)
+                            {
+                                llava_composite_image_signature = ""; //force invalidate
+                                fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
+                                output.text = nullptr;
+                                output.status = 0;
+                                generation_finished = true;
+                                return output;
+                            }
+                        }
+                        if(llavatokenscounted!=llavatokensevaled)
+                        {
+                            llava_composite_image_signature = ""; //force invalidate
+                            fprintf(stderr, "\nLLAVA image tokens mismatch at %d! (%d vs %d tokens)\n",n_past,llavatokenscounted,llavatokensevaled);
+                            output.text = nullptr;
+                            output.status = 0;
+                            generation_finished = true;
+                            return output;
+                        }
+                    }
                 }
+                else
+                {
+                    embd.push_back(currtoken);
+                    last_n_tokens.erase(last_n_tokens.begin());
+                    last_n_tokens.push_back(currtoken);
+                    current_context_tokens.push_back(currtoken);
+                    ++input_consumed;
+                    if ((int)embd.size() >= kcpp_params->n_batch)
+                    {
+                        break;
+                    }
+                }
             }
         }
     }
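The reworked loop is effectively a three-way dispatch: an ordinary token is queued into the batch as before; a sentinel hit with a partially filled batch breaks out so the pending text is evaluated first; and a sentinel hit with an empty batch consumes the whole run of placeholders and feeds the image embeddings through kcpp_eval_image. Distilled into a sketch; the Action enum and classify helper are illustrative only:

#include <cstddef>

// Sketch of the dispatch order in the loop above, using the
// LLAVA_TOKEN_IDENTIFIER_A/B constants defined earlier.
enum class Action { QueueToken, FlushBatch, EvalImages };

static Action classify(int token, size_t batch_fill)
{
    const bool is_image = (token == LLAVA_TOKEN_IDENTIFIER_A || token == LLAVA_TOKEN_IDENTIFIER_B);
    if (!is_image)
    {
        return Action::QueueToken; //normal text token, batch it
    }
    return batch_fill > 0 ? Action::FlushBatch  //evaluate pending text first
                          : Action::EvalImages; //then eval the image embeddings
}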