mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
llama : add high-throughput mode (#14363)
* kv-cache : prepare K/V buffers for separation ggml-ci * batched-bench : fix oob write ggml-ci * llama : add "virtual sequences" ggml-ci * llama : use "stream" vs "virtual sequence" ggml-ci * graph : fix stream splitting when KV cache is not used ggml-ci * kv-cache : add multi-stream save/load support ggml-ci * llama : add "--attn-streams" flag ggml-ci * kv-cache : fix handling when find_slot fails ggml-ci * kv-cache : restore find_slot impl ggml-ci * kv-cache : add comments * kv-cache : add bounds checks for sequence id ggml-ci * cont : add n_seq_max to batch allocr ggml-ci * kv-cache : perform stream copies lazily after llama_synchronize ggml-ci * kv-cache : avoid throwing exceptions across the C boundary ggml-ci * CUDA: 4D FlashAttention support (#14628) * CUDA: 4D FlashAttention support * CUDA: fix WMMA FA kernel * llama : rename attn_streams -> kv_unified ggml-ci * common : rename kv_split -> kv_unified ggml-ci --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
ab14019821
commit
225e7a1438
30 changed files with 1080 additions and 460 deletions
|
@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
|
|||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all) {
|
||||
clear();
|
||||
|
||||
|
@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
|
|||
// validate input batch
|
||||
//
|
||||
|
||||
if (n_seq_max > LLAMA_MAX_SEQ) {
|
||||
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (batch.token) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
||||
|
@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
|
|||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
|
|||
|
||||
// initialize the starting position for each sequence based on the positions in the memory
|
||||
llama_pos p0[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (!memory) {
|
||||
// if no memory -> start from 0
|
||||
p0[s] = 0;
|
||||
|
@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
|
|||
// compute stats
|
||||
//
|
||||
|
||||
this->n_embd = n_embd;
|
||||
this->n_embd = n_embd;
|
||||
this->n_seq_max = n_seq_max;
|
||||
|
||||
// count the outputs in this batch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
|
@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
|
|||
seq_set_map[cur].push_back(i);
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
seq_idx[s] = seq_id_unq.size();
|
||||
seq_id_unq.push_back(s);
|
||||
|
@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
|
|||
// consistency checks
|
||||
//
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
|
|||
}
|
||||
|
||||
if (memory) {
|
||||
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
||||
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
||||
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
|
||||
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
|
||||
if (seq_cpl[s0][s1]) {
|
||||
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
||||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
||||
|
@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
|
|||
//
|
||||
{
|
||||
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_set[s].set();
|
||||
}
|
||||
|
||||
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_pos[s] = -1;
|
||||
}
|
||||
|
||||
|
@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||
}
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
|
||||
ubatch.seq_id_unq.push_back(s);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue