diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index e7bb8351b..1d92d016c 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1514,6 +1514,66 @@ static void load_grammar(const std::string & gammarstr) } } +struct kcpp_embd_batch { //duplcated from llava_embd_batch + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) { + int32_t seq_id = 0; + pos.resize(n_tokens); + n_seq_id.resize(n_tokens); + seq_ids.resize(n_tokens + 1); + logits.resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = npast + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } + kcpp_embd_batch(std::vector & tokens, int32_t npast) { + int32_t seq_id = 0; + int32_t n_tokens = tokens.size(); + pos.resize(n_tokens); + n_seq_id.resize(n_tokens); + seq_ids.resize(n_tokens + 1); + logits.resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens.data(), + /*embd =*/ nullptr, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = npast + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + batch.logits[n_tokens - 1] = true; + } +}; static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) { int n_embd = llama_n_embd(llama_get_model(ctx_llama)); @@ -1522,8 +1582,9 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr,}; - if (llama_decode(ctx_llama, batch)) { + float * embd = img_embd+i*n_embd; + kcpp_embd_batch llava_batch = kcpp_embd_batch(embd, n_eval, *n_past); + if (llama_decode(ctx_llama, llava_batch.batch)) { fprintf(stderr, "\n%s : failed to eval image\n", __func__); return false; } @@ -3108,7 +3169,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs) } else if(file_format == FileFormat::GGUF_GENERIC) { - evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize))==0); + kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past); + evalres = (llama_decode(llama_ctx_v4, batch.batch)==0); } else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2) { @@ -3485,7 +3547,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs) if(i>0 && sepsize>0) { //add a separator between each image - auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize)); + kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past); + auto evr = llama_decode(llama_ctx_v4, batch.batch); if(evr!=0) { printf("\nError when appending llava separator: %d\n",evr);