mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
try new batch api (not actually batching)
This commit is contained in:
parent
8a7d53d838
commit
4b96c3bba8
1 changed files with 67 additions and 4 deletions
|
@ -1514,6 +1514,66 @@ static void load_grammar(const std::string & gammarstr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct kcpp_embd_batch { //duplcated from llava_embd_batch
|
||||||
|
std::vector<int32_t> pos;
|
||||||
|
std::vector<int32_t> n_seq_id;
|
||||||
|
std::vector<int32_t> seq_id_0;
|
||||||
|
std::vector<int32_t *> seq_ids;
|
||||||
|
std::vector<int8_t> logits;
|
||||||
|
llama_batch batch;
|
||||||
|
kcpp_embd_batch(float * embd, int32_t n_tokens, int32_t npast) {
|
||||||
|
int32_t seq_id = 0;
|
||||||
|
pos.resize(n_tokens);
|
||||||
|
n_seq_id.resize(n_tokens);
|
||||||
|
seq_ids.resize(n_tokens + 1);
|
||||||
|
logits.resize(n_tokens);
|
||||||
|
seq_id_0.resize(1);
|
||||||
|
seq_id_0[0] = seq_id;
|
||||||
|
seq_ids [n_tokens] = nullptr;
|
||||||
|
batch = {
|
||||||
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*tokens =*/ nullptr,
|
||||||
|
/*embd =*/ embd,
|
||||||
|
/*pos =*/ pos.data(),
|
||||||
|
/*n_seq_id =*/ n_seq_id.data(),
|
||||||
|
/*seq_id =*/ seq_ids.data(),
|
||||||
|
/*logits =*/ logits.data(),
|
||||||
|
};
|
||||||
|
for (int i = 0; i < n_tokens; i++) {
|
||||||
|
batch.pos [i] = npast + i;
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id [i] = seq_id_0.data();
|
||||||
|
batch.logits [i] = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kcpp_embd_batch(std::vector<llama_token> & tokens, int32_t npast) {
|
||||||
|
int32_t seq_id = 0;
|
||||||
|
int32_t n_tokens = tokens.size();
|
||||||
|
pos.resize(n_tokens);
|
||||||
|
n_seq_id.resize(n_tokens);
|
||||||
|
seq_ids.resize(n_tokens + 1);
|
||||||
|
logits.resize(n_tokens);
|
||||||
|
seq_id_0.resize(1);
|
||||||
|
seq_id_0[0] = seq_id;
|
||||||
|
seq_ids [n_tokens] = nullptr;
|
||||||
|
batch = {
|
||||||
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*tokens =*/ tokens.data(),
|
||||||
|
/*embd =*/ nullptr,
|
||||||
|
/*pos =*/ pos.data(),
|
||||||
|
/*n_seq_id =*/ n_seq_id.data(),
|
||||||
|
/*seq_id =*/ seq_ids.data(),
|
||||||
|
/*logits =*/ logits.data(),
|
||||||
|
};
|
||||||
|
for (int i = 0; i < n_tokens; i++) {
|
||||||
|
batch.pos [i] = npast + i;
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id [i] = seq_id_0.data();
|
||||||
|
batch.logits [i] = false;
|
||||||
|
}
|
||||||
|
batch.logits[n_tokens - 1] = true;
|
||||||
|
}
|
||||||
|
};
|
||||||
static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
|
static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
|
||||||
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
|
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||||
|
|
||||||
|
@ -1522,8 +1582,9 @@ static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num
|
||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr,};
|
float * embd = img_embd+i*n_embd;
|
||||||
if (llama_decode(ctx_llama, batch)) {
|
kcpp_embd_batch llava_batch = kcpp_embd_batch(embd, n_eval, *n_past);
|
||||||
|
if (llama_decode(ctx_llama, llava_batch.batch)) {
|
||||||
fprintf(stderr, "\n%s : failed to eval image\n", __func__);
|
fprintf(stderr, "\n%s : failed to eval image\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -3108,7 +3169,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
||||||
}
|
}
|
||||||
else if(file_format == FileFormat::GGUF_GENERIC)
|
else if(file_format == FileFormat::GGUF_GENERIC)
|
||||||
{
|
{
|
||||||
evalres = (llama_decode(llama_ctx_v4, llama_batch_get_one(embd.data(), embdsize))==0);
|
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past);
|
||||||
|
evalres = (llama_decode(llama_ctx_v4, batch.batch)==0);
|
||||||
}
|
}
|
||||||
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
|
||||||
{
|
{
|
||||||
|
@ -3485,7 +3547,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
||||||
if(i>0 && sepsize>0)
|
if(i>0 && sepsize>0)
|
||||||
{
|
{
|
||||||
//add a separator between each image
|
//add a separator between each image
|
||||||
auto evr = llama_decode(llama_ctx_v4, llama_batch_get_one(llava_sep.data(), sepsize));
|
kcpp_embd_batch batch = kcpp_embd_batch(embd, n_past);
|
||||||
|
auto evr = llama_decode(llama_ctx_v4, batch.batch);
|
||||||
if(evr!=0)
|
if(evr!=0)
|
||||||
{
|
{
|
||||||
printf("\nError when appending llava separator: %d\n",evr);
|
printf("\nError when appending llava separator: %d\n",evr);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue