llama : add high-throughput mode (#14363)

* kv-cache : prepare K/V buffers for separation

ggml-ci

* batched-bench : fix oob write

ggml-ci

* llama : add "virtual sequences"

ggml-ci

* llama : use "stream" vs "virtual sequence"

ggml-ci

* graph : fix stream splitting when KV cache is not used

ggml-ci

* kv-cache : add multi-stream save/load support

ggml-ci

* llama : add "--attn-streams" flag

ggml-ci

* kv-cache : fix handling when find_slot fails

ggml-ci

* kv-cache : restore find_slot impl

ggml-ci

* kv-cache : add comments

* kv-cache : add bounds checks for sequence id

ggml-ci

* cont : add n_seq_max to batch allocr

ggml-ci

* kv-cache : perform stream copies lazily after llama_synchronize

ggml-ci

* kv-cache : avoid throwing exceptions across the C boundary

ggml-ci

* CUDA: 4D FlashAttention support (#14628)

* CUDA: 4D FlashAttention support

* CUDA: fix WMMA FA kernel

* llama : rename attn_streams -> kv_unified

ggml-ci

* common : rename kv_split -> kv_unified

ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
commit 225e7a1438 (parent ab14019821)
Georgi Gerganov, 2025-07-16 16:35:42 +03:00, committed by GitHub
30 changed files with 1080 additions and 460 deletions


@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--kv-unified", "-kvu"},
+        string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
+        [](common_params & params) {
+            params.kv_unified = true;
+        }
+    ).set_env("LLAMA_ARG_KV_SPLIT"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),


@@ -1163,6 +1163,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf    = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full   = params.swa_full;
+    cparams.kv_unified = params.kv_unified;

     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;


@@ -341,6 +341,7 @@ struct common_params {
     bool no_perf          = false; // disable performance metrics
     bool ctx_shift        = true;  // context shift on inifinite text generation
     bool swa_full         = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool kv_unified       = false; // enable unified KV cache
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap         = true;  // use mmap for faster loads


@@ -224,6 +224,7 @@ int main(int argc, char ** argv) {
         auto & client = clients[i];
         client.id = i;
         client.smpl = common_sampler_init(model, params.sampling);
+        //params.sampling.seed++;
     }

     std::vector<llama_token> tokens_system;
@@ -345,7 +346,7 @@ int main(int argc, char ** argv) {
             client.n_decoded = 0;
             client.i_batch   = batch.n_tokens - 1;

-            LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur);
+            LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);

             g_seq_id += 1;


@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;

     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }

-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

     if (jt*ncols1 + j >= ne01) {
         return;
     }

-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       += D                 * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled * D;

     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }

     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;

-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }

-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }

 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

         flash_attn_combine_results<DV>
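
The fixup and combine kernels above extend the flattened block counter kbc by a sequence dimension: it now runs over iter_k*iter_j*(ne02/ncols2)*ne03 blocks and is unpacked back into (sequence, head group, j tile, k tile). A standalone C++ sketch of that index arithmetic, with made-up example sizes, purely illustrative and not part of the commit:

    // Illustrative only: checks that the kbc decomposition used by the 4D FlashAttention
    // kernels round-trips. K = iter_k, J = iter_j, H = ne02/ncols2, S = ne03.
    #include <cassert>
    #include <cstdio>

    int main() {
        const int iter_k = 4; // K tiles per head
        const int iter_j = 3; // Q column tiles
        const int ne02   = 8; // Q heads
        const int ncols2 = 2; // heads processed together (GQA grouping)
        const int ne03   = 2; // sequences/streams (the new dimension)

        const int n_blocks = iter_k*iter_j*(ne02/ncols2)*ne03;
        for (int kbc = 0; kbc < n_blocks; ++kbc) {
            const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
            const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
            const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k;
            const int kt       = kbc % iter_k;

            // re-flatten and verify the round trip
            assert(kbc == ((sequence*(ne02/ncols2) + head)*iter_j + jt)*iter_k + kt);
        }
        printf("kbc decomposition consistent for %d blocks\n", n_blocks);
        return 0;
    }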


@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.

     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int channel = kbc / (iter_k*iter_j);
-        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+        const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

-        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
         return;
     }

-    const int channel = kbc / (iter_k*iter_j);
-    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

-    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel  = kb0_stop  * kb_niter;


@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }

+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum((float)kqsum_j);

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;

-            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }

-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
         }

         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);


@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }

+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;

-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }

-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
         }

         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else


@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);

-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }

     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);


@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);

-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }

     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);


@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float * Q_f   = (const float *) (Q + nb03* sequence + nb02* head              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
     const half2 * mask2 = (const half2 *) maskh;

     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);

-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;

         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
             if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+            dst[j_dst_unrolled*D + i] = dst_val;
         }

         if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);


@@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
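
In other words, the CUDA backend now accepts 4D FlashAttention (Q->ne[3] > 1, one slice per stream), but only when the mask tensor has ne[2] == 1. A small predicate capturing just that shape rule, written as a sketch under the assumption that op->src[3] is the mask, as in the check above:

    #include "ggml.h"

    // Sketch: the shape-only part of the updated supports_op logic for GGML_OP_FLASH_ATTN_EXT.
    static bool fattn_mask_shape_supported(const ggml_tensor * mask) {
        // 4D Q is allowed now; the mask, if present, must have a single slice in dim 2
        return mask == nullptr || mask->ne[2] == 1;
    }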


@@ -335,6 +335,9 @@ extern "C" {
         bool swa_full;   // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                          // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                          //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+        bool kv_unified; // use a unified buffer across the input sequences when computing the attention
+                         // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
+                         // ref: https://github.com/ggml-org/llama.cpp/pull/14363
     };

     // model quantization parameters
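
For reference, a minimal host-side sketch of how a client could request the new per-sequence KV streams through this parameter. It assumes a loaded model; the helper name and the sizes are made up, everything else is ordinary llama.h API usage:

    #include "llama.h"

    // Sketch: create a context with one KV stream per sequence (high-throughput mode).
    static llama_context * make_multi_stream_ctx(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();

        cparams.n_seq_max  = 32;       // number of parallel sequences
        cparams.n_ctx      = 32*4096;  // total context, shared across the sequences
        cparams.kv_unified = false;    // separate KV stream per sequence (the default after this commit)
        // cparams.kv_unified = true;  // opt back into the single shared KV buffer

        return llama_init_from_model(model, cparams);
    }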


@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //

+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -144,6 +150,7 @@ bool llama_batch_allocr::init(
     //

     this->n_embd = n_embd;
+    this->n_seq_max = n_seq_max;

     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }

-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //

-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
     }

     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }

         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_pos[s] = -1;
         }
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         }
     }

-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
             ubatch.seq_id_unq.push_back(s);
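
As a usage illustration, here is a hedged sketch of a caller-side batch whose sequence ids stay within n_seq_max, which is what the new bounds check validates. llama_batch_init and the llama_batch fields are the public API; the helper name and the token value are hypothetical:

    #include "llama.h"

    // Sketch: one token per sequence/stream, with sequence ids in [0, n_seq_max).
    static llama_batch make_one_token_per_stream(llama_token tok, int32_t n_seq_max) {
        llama_batch batch = llama_batch_init(/*n_tokens=*/ n_seq_max, /*embd=*/ 0, /*n_seq_max=*/ 1);

        for (int32_t s = 0; s < n_seq_max; ++s) {
            const int32_t i = batch.n_tokens++;

            batch.token   [i]    = tok;
            batch.pos     [i]    = 0;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = s;    // must satisfy 0 <= s < n_seq_max, otherwise init() now rejects the batch
            batch.logits  [i]    = true;
        }

        return batch;
    }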


@@ -48,6 +48,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);

     const llama_batch & get_batch() const;
@@ -100,6 +101,7 @@ private:
     const uint32_t n_pos_per_embd;

     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;

     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id


@@ -98,10 +98,20 @@ llama_context::llama_context(
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
     }

     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

     cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
+
+    {
+        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+        const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+        if (!supports_set_rows && !cparams.kv_unified) {
+            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
+            cparams.kv_unified = true;
+        }
+    }

     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -112,6 +122,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
@@ -267,7 +278,7 @@ llama_context::llama_context(
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -300,7 +311,7 @@ llama_context::llama_context(
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,6 +322,10 @@ llama_context::llama_context(
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+            //
+            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }

-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int32_t n_vocab = model.vocab.n_tokens();

     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }

     const uint32_t n_tokens = balloc->get_n_tokens();

+    // [TAG_NO_CACHE_PAD]
+    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
     const llama_ubatch ubatch = balloc->split_simple(n_tokens);

     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;

-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
             batch.logits [pos_batch] = true;
         }

-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
            return;
        }
@@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf     =*/ true,
         /*.op_offload  =*/ true,
        /*.swa_full    =*/ true,
+        /*.kv_unified  =*/ false,
     };

     return result;
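
The reservation change above is easiest to see with concrete numbers. A small sketch of the shape selection, mirroring `n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max`; the struct and function here are illustrative, not llama.cpp code:

    #include <algorithm>
    #include <cstdint>

    struct reserve_shape {
        uint32_t n_tokens;
        uint32_t n_seqs;
        uint32_t n_outputs;
    };

    // Sketch: worst-case prompt-processing reservation under the two KV cache modes.
    static reserve_shape worst_case_pp(uint32_t n_ctx, uint32_t n_ubatch, uint32_t n_seq_max, bool kv_unified) {
        const uint32_t n_seqs   = kv_unified ? 1 : n_seq_max;
        const uint32_t n_tokens = std::min(n_ctx, n_ubatch);
        return { n_tokens, n_seqs, n_tokens };
    }

    // e.g. n_ctx = 131072, n_ubatch = 512, n_seq_max = 32:
    //   unified KV cache     -> { 512,  1, 512 }
    //   per-sequence streams -> { 512, 32, 512 }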


@@ -11,8 +11,8 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing

     float rope_freq_base;
     float rope_freq_scale;
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
+    bool kv_unified;

     enum llama_pooling_type pooling_type;


@ -982,12 +982,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
float kq_scale) const { float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2]; const bool v_trans = v->nb[1] > v->nb[2];
// split the batch into streams if needed
const auto n_stream = k->ne[3];
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
q = ggml_permute(ctx0, q, 0, 2, 1, 3); q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3);
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];
const auto n_kv = k->ne[1]; const auto n_kv = k->ne[1];
ggml_tensor * cur; ggml_tensor * cur;
@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
#endif #endif
} }
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
} else { } else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); // recombine streams
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
if (!cparams.offload_kqv) { if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU // all nodes between the KV store and the attention output are run on the CPU
@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask(); const auto & kq_mask = inp->get_kq_mask();
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
assert(ubatch.equal_seqs == false);
ggml_tensor * q = q_cur; ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur; ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur; ggml_tensor * v = v_cur;
@ -1158,11 +1166,12 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
const auto n_kv = mctx_cur->get_n_kv(); const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens; const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask); ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur); auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{ {
const auto n_kv = mctx_cur->get_base()->get_n_kv(); const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask); ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
ggml_set_input(inp->self_kq_mask_swa); ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
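The mask allocations above now pad the per-stream token count and carry the stream count in ne[3], so one mask per stream can be broadcast across the heads. A small arithmetic sketch with assumed sizes (GGML_KQ_MASK_PAD is taken to be 64 here):

    // Illustrative arithmetic only; all sizes are example values.
    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    static uint32_t pad_to(uint32_t x, uint32_t p) { return (x + p - 1)/p*p; } // GGML_PAD

    int main() {
        const uint32_t mask_pad = 64;   // GGML_KQ_MASK_PAD (assumed)
        const uint32_t n_kv     = 4096; // KV cells visible to this ubatch
        const uint32_t n_tokens = 32;   // ubatch tokens
        const uint32_t n_seqs   = 4;    // distinct sequences in the ubatch

        for (const bool kv_unified : { true, false }) {
            const uint32_t n_stream = kv_unified ? 1 : n_seqs;
            // ggml_new_tensor_4d(F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream)
            printf("kv_unified=%d -> self_kq_mask [%u, %u, 1, %u]\n",
                   kv_unified, n_kv, pad_to(n_tokens/n_stream, mask_pad), n_stream);
        }
        return 0;
    }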


@ -255,10 +255,10 @@ public:
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
@ -289,14 +289,14 @@ public:
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;


@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv; return n_embd_head_v * n_head_kv;
} }
bool llama_hparams::is_n_embd_k_gqa_variable() const {
const uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_k_gqa(il)) {
return true;
}
}
return false;
}
bool llama_hparams::is_n_embd_v_gqa_variable() const {
const uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_v_gqa(il)) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_embd_k_gqa_max() const {
uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_k_gqa(il));
}
return val;
}
uint32_t llama_hparams::n_embd_v_gqa_max() const {
uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_v_gqa(il));
}
return val;
}
uint32_t llama_hparams::n_embd_r() const { uint32_t llama_hparams::n_embd_r() const {
if (wkv_head_size != 0) { if (wkv_head_size != 0) {
// for RWKV models // for RWKV models
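Some models use a different number of K/V heads per layer, so a buffer that is shared across layers has to be sized for the widest one; the new helpers report whether the per-layer GQA dimensions vary and what their maximum is. A self-contained analogue (illustrative; the real helpers walk n_embd_k_gqa(il)/n_embd_v_gqa(il) exactly as above):

    // Standalone analogue of the variable/max GQA helpers.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct hparams_view {
        std::vector<uint32_t> n_embd_k_gqa_per_layer; // one entry per layer

        // true if any layer deviates from layer 0
        bool is_n_embd_k_gqa_variable() const {
            const uint32_t ref = n_embd_k_gqa_per_layer.at(0);
            return std::any_of(n_embd_k_gqa_per_layer.begin(), n_embd_k_gqa_per_layer.end(),
                               [ref](uint32_t v) { return v != ref; });
        }

        // size a shared buffer for the widest layer
        uint32_t n_embd_k_gqa_max() const {
            return *std::max_element(n_embd_k_gqa_per_layer.begin(), n_embd_k_gqa_per_layer.end());
        }
    };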


@ -191,6 +191,14 @@ struct llama_hparams {
// dimension of value embeddings across all k-v heads // dimension of value embeddings across all k-v heads
uint32_t n_embd_v_gqa(uint32_t il = 0) const; uint32_t n_embd_v_gqa(uint32_t il = 0) const;
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
bool is_n_embd_k_gqa_variable() const;
bool is_n_embd_v_gqa_variable() const;
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
uint32_t n_embd_k_gqa_max() const;
uint32_t n_embd_v_gqa_max() const;
// dimension of the rolling state embeddings // dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size // corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_r() const; uint32_t n_embd_r() const;


@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
bool v_trans, bool v_trans,
bool offload, bool offload,
bool swa_full, bool swa_full,
bool unified,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_ubatch, uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams) { uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size; const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) { if (swa_full) {
@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
kv_base = std::make_unique<llama_kv_cache_unified>( kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v, model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, n_seq_max, n_pad, v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE); 0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>( kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v, model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, n_seq_max, n_pad, v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type); hparams.n_swa, hparams.swa_type);
} }
@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
// first try simple split // first try simple split
do { do {
if (!unified) {
// requires equal splits, so we skip the simple split
break;
}
balloc.split_reset(); balloc.split_reset();
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (true) { while (true) {
auto ubatch = balloc.split_equal(n_ubatch, false); auto ubatch = balloc.split_equal(n_ubatch, !unified);
if (ubatch.n_tokens == 0) { if (ubatch.n_tokens == 0) {
break; break;
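With per-sequence streams each stream holds a single sequence, so the SWA cache drops the n_seq_max multiplier from its size. A worked sizing example under assumed values:

    // Illustrative sizing only; n_swa/n_ubatch/n_pad/n_seq_max/size_base are made up.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    static uint32_t pad_to(uint32_t x, uint32_t p) { return (x + p - 1)/p*p; } // GGML_PAD

    int main() {
        const uint32_t n_swa = 1024, n_ubatch = 512, n_pad = 256, n_seq_max = 4, size_base = 8192;

        for (const bool unified : { true, false }) {
            // mirrors: std::min(size_base, GGML_PAD(n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad))
            const uint32_t size_swa = std::min(size_base, pad_to(n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
            printf("unified=%d -> size_swa = %u cells%s\n", unified, size_swa, unified ? "" : " per stream");
        }
        return 0;
    }
    // prints: unified=1 -> size_swa = 4608 cells
    //         unified=0 -> size_swa = 1536 cells per stream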


@ -20,6 +20,7 @@ public:
bool v_trans, bool v_trans,
bool offload, bool offload,
bool swa_full, bool swa_full,
bool unified,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_ubatch, uint32_t n_ubatch,
@ -68,6 +69,8 @@ public:
private: private:
const llama_hparams & hparams; const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache_unified> kv_base; std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa; std::unique_ptr<llama_kv_cache_unified> kv_swa;
}; };

[diff omitted: file too large to display]


@ -35,16 +35,50 @@ public:
std::vector<uint32_t> ids; std::vector<uint32_t> ids;
}; };
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
return ssrc.empty();
}
std::vector<uint32_t> ssrc;
std::vector<uint32_t> sdst;
};
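A sequence copy between sequences that live in different streams can no longer be expressed as a cell-level copy inside a single buffer, so the (source stream, destination stream) pair is queued here and the actual buffer copy is applied later, during the next cache update (see the extra stream_copy_info parameter of update() below). A minimal recording sketch with hypothetical helper names:

    // Illustrative recorder for pending cross-stream copies.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct stream_copy_info_sketch {
        std::vector<uint32_t> ssrc; // source stream of each pending copy
        std::vector<uint32_t> sdst; // destination stream of each pending copy

        void push(uint32_t src, uint32_t dst) {
            ssrc.push_back(src);
            sdst.push_back(dst);
        }

        // drain the queue; copy_stream(src, dst) stands in for the real buffer copy
        template <typename F>
        void apply_and_clear(F && copy_stream) {
            assert(ssrc.size() == sdst.size());
            for (size_t i = 0; i < ssrc.size(); ++i) {
                copy_stream(ssrc[i], sdst[i]);
            }
            ssrc.clear();
            sdst.clear();
        }
    };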
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
struct slot_info { struct slot_info {
// data for ggml_set_rows // data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>; using idx_vec_t = std::vector<uint32_t>;
idx_vec_t idxs; // number of streams: ns = s1 - s0 + 1
llama_seq_id s0;
llama_seq_id s1;
std::vector<llama_seq_id> strm; // [ns]
std::vector<idx_vec_t> idxs; // [ns]
uint32_t head() const { uint32_t head() const {
return idxs.at(0); GGML_ASSERT(idxs.size() == 1);
GGML_ASSERT(!idxs[0].empty());
return idxs[0][0];
}
void resize(size_t n) {
strm.resize(n);
idxs.resize(n);
}
size_t size() const {
GGML_ASSERT(idxs.size() == strm.size());
GGML_ASSERT(!idxs.empty());
return idxs[0].size();
}
size_t n_stream() const {
return strm.size();
} }
bool empty() const { bool empty() const {
@ -54,9 +88,6 @@ public:
void clear() { void clear() {
idxs.clear(); idxs.clear();
} }
// TODO: implement
//std::vector<idx_vec_t> seq_idxs;
}; };
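slot_info now resolves a ubatch to one cell-index list per participating stream instead of a single flat list, while head() keeps its old single-stream meaning (first cell of the only list). A consumption sketch with hypothetical names, showing how entry i of stream slot s maps to cell idxs[s][i]:

    // Illustrative traversal of a multi-stream slot_info.
    #include <cstdint>
    #include <vector>

    struct slot_info_sketch {
        std::vector<uint32_t>              strm; // stream id per slot,     size ns
        std::vector<std::vector<uint32_t>> idxs; // cell indices per slot,  size ns
    };

    // store_token(stream, cell, i) stands in for writing K/V rows into the cache
    template <typename F>
    void for_each_cell(const slot_info_sketch & sinfo, F && store_token) {
        for (size_t s = 0; s < sinfo.idxs.size(); ++s) {
            for (size_t i = 0; i < sinfo.idxs[s].size(); ++i) {
                store_token(sinfo.strm[s], sinfo.idxs[s][i], i);
            }
        }
    }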
using slot_info_vec_t = std::vector<slot_info>; using slot_info_vec_t = std::vector<slot_info>;
@ -68,6 +99,7 @@ public:
ggml_type type_v, ggml_type type_v,
bool v_trans, bool v_trans,
bool offload, bool offload,
bool unified,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_pad, uint32_t n_pad,
@ -112,6 +144,7 @@ public:
// //
uint32_t get_size() const; uint32_t get_size() const;
uint32_t get_n_stream() const;
bool get_has_shift() const; bool get_has_shift() const;
@ -122,8 +155,8 @@ public:
uint32_t get_n_kv() const; uint32_t get_n_kv() const;
// get views of the current state of the cache // get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location // store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@ -137,7 +170,7 @@ public:
// return empty vector on failure // return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches); slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
// find a slot of kv cells that can hold the ubatch // find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous // if cont == true, then the slot must be continuous
@ -157,8 +190,9 @@ public:
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift(ggml_tensor * dst) const; void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private: private:
@ -172,15 +206,15 @@ private:
ggml_tensor * k; ggml_tensor * k;
ggml_tensor * v; ggml_tensor * v;
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
}; };
bool v_trans = true; // the value tensor is transposed bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0;
const uint32_t n_seq_max = 1; const uint32_t n_seq_max = 1;
const uint32_t n_stream = 1;
// required padding // required padding
const uint32_t n_pad = 1; const uint32_t n_pad = 1;
@ -200,7 +234,17 @@ private:
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
llama_kv_cells_unified cells; // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
// pending stream copies that will be applied during the next update
stream_copy_info sc_info;
std::vector<kv_layer> layers; std::vector<kv_layer> layers;
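The single cell map and search head become one per stream, and seq_to_stream routes every sequence-level operation to the right pair; with a unified buffer there is only stream 0, so all sequences resolve to it. A lookup sketch (member names follow the diff, the helper itself is hypothetical):

    // Illustrative routing of a sequence id to its stream.
    #include <cstdint>
    #include <vector>

    struct cache_view {
        bool                  unified;       // single shared stream
        std::vector<uint32_t> seq_to_stream; // indexed by sequence id

        uint32_t stream_of(int32_t seq_id) const {
            // with a unified buffer every sequence shares stream 0;
            // otherwise operate on v_cells[seq_to_stream[seq_id]] / v_heads[...]
            return unified ? 0 : seq_to_stream.at(seq_id);
        }
    };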
@ -237,11 +281,17 @@ private:
ggml_cgraph * gf, ggml_cgraph * gf,
const defrag_info & dinfo) const; const defrag_info & dinfo) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const; struct cell_ranges_t {
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const; uint32_t strm;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); };
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
}; };
class llama_kv_cache_unified_context : public llama_memory_context_i { class llama_kv_cache_unified_context : public llama_memory_context_i {
@ -249,6 +299,7 @@ public:
// some shorthands // some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info; using defrag_info = llama_kv_cache_unified::defrag_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
// used for errors // used for errors
llama_kv_cache_unified_context(llama_memory_status status); llama_kv_cache_unified_context(llama_memory_status status);
@ -262,7 +313,8 @@ public:
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_context * lctx, llama_context * lctx,
bool do_shift, bool do_shift,
defrag_info dinfo); defrag_info dinfo,
stream_copy_info sc_info);
// used to create a batch processing context from a batch // used to create a batch processing context from a batch
llama_kv_cache_unified_context( llama_kv_cache_unified_context(
@ -320,6 +372,8 @@ private:
defrag_info dinfo; defrag_info dinfo;
stream_copy_info sc_info;
// //
// batch processing context // batch processing context
// //


@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
offload, offload,
kv_size, kv_size,
n_seq_max, n_seq_max,
1,
n_pad, n_pad,
n_swa, n_swa,
swa_type swa_type


@ -16647,7 +16647,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
} else { } else {
const auto padding = llama_kv_cache_unified::get_padding(cparams); const auto padding = llama_kv_cache_unified::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); uint32_t n_ctx_per_stream = cparams.n_ctx;
if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
} else {
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
cparams.n_ctx = n_ctx_per_stream;
}
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
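When the cache is split per sequence, the requested n_ctx is divided across n_seq_max streams, rounded up to the cache padding, and the padded total is written back to cparams.n_ctx. A worked example with assumed values:

    // Illustrative arithmetic; n_ctx/n_seq_max/padding are example values.
    #include <cstdint>
    #include <cstdio>

    static uint32_t pad_to(uint32_t x, uint32_t p) { return (x + p - 1)/p*p; } // GGML_PAD

    int main() {
        const uint32_t n_ctx = 10000, n_seq_max = 4, padding = 256;

        uint32_t n_ctx_per_stream = (n_ctx + n_seq_max - 1)/n_seq_max; // 2500
        n_ctx_per_stream = pad_to(n_ctx_per_stream, padding);          // 2560
        const uint32_t n_ctx_total = n_ctx_per_stream*n_seq_max;       // 10240, reported as cparams.n_ctx

        printf("per stream: %u, total: %u\n", n_ctx_per_stream, n_ctx_total);
        return 0;
    }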
@ -16661,7 +16672,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn, !cparams.flash_attn,
cparams.offload_kqv, cparams.offload_kqv,
params.swa_full, params.swa_full,
cparams.n_ctx, cparams.kv_unified,
n_ctx_per_stream,
cparams.n_seq_max, cparams.n_seq_max,
cparams.n_ubatch, cparams.n_ubatch,
padding); padding);
@ -16675,7 +16687,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
params.type_v, params.type_v,
!cparams.flash_attn, !cparams.flash_attn,
cparams.offload_kqv, cparams.offload_kqv,
cparams.n_ctx, cparams.kv_unified,
n_ctx_per_stream,
cparams.n_seq_max, cparams.n_seq_max,
padding, padding,
hparams.n_swa, hparams.n_swa,


@ -4282,7 +4282,7 @@ struct test_flash_attn_ext : public test_case {
ggml_tensor * m = nullptr; ggml_tensor * m = nullptr;
if (mask) { if (mask) {
m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]); m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
ggml_set_name(m, "m"); ggml_set_name(m, "m");
} }


@ -127,10 +127,9 @@ int main(int argc, char ** argv) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
for (int i = 0; i < pp; ++i) { for (int i = 0; i < pp; ++i) {
common_batch_add(batch, 0, i, { j }, false); common_batch_add(batch, 0, i, { j }, i == pp - 1);
} }
} }
batch.logits[batch.n_tokens - 1] = true;
const auto t_pp_start = ggml_time_us(); const auto t_pp_start = ggml_time_us();
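The prompt loop now requests logits directly on the last token of each prompt as it is appended, instead of flipping batch.logits[batch.n_tokens - 1] after the loop. A standalone analogue of the pattern (not the real common_batch_add):

    // Minimal analogue: mark logits only on the final token of a prompt.
    #include <vector>

    struct mini_batch {
        std::vector<int>  pos;    // token positions
        std::vector<bool> logits; // whether logits are requested for each token
    };

    static void add_token(mini_batch & b, int p, bool need_logits) {
        b.pos.push_back(p);
        b.logits.push_back(need_logits);
    }

    static void add_prompt(mini_batch & b, int pp) {
        for (int i = 0; i < pp; ++i) {
            add_token(b, i, i == pp - 1); // logits only for the last prompt token
        }
    }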