llama: use f16 mask for FA to save VRAM (#23764)

* llama: use f16 mask for FA

* review: add llama_cast + formatting

* simplify
This commit is contained in:
Aman Gupta 2026-05-29 15:44:43 +08:00 committed by GitHub
parent fe12e422ad
commit 031ddb2e08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 123 additions and 86 deletions

View file

@ -29,7 +29,10 @@ static ggml_tensor * build_attn_inp_kq_mask(
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
// flash attention requires an f16 mask
const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");
@ -381,7 +384,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
template <typename T>
static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
const char * swa_type_str = "unknown";
@ -405,7 +409,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64
for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
LLAMA_LOG_DEBUG(" %2d ", i);
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
float val = data[i * n_kv + j];
float val = llama_cast<float>(data[i * n_kv + j]);
if (val == -INFINITY) {
LLAMA_LOG_DEBUG("");
} else {
@ -420,7 +424,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
using T = std::remove_reference_t<decltype(*data)>;
std::fill(data, data + ne, llama_cast<T>(-INFINITY));
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
@ -446,38 +453,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
continue;
}
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
}
}
if (debug) {
print_mask(data, n_tokens, n_kv, n_swa, swa_type);
}
};
{
GGML_ASSERT(self_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
float * data = (float *) self_kq_mask->data;
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
if (debug) {
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
}
GGML_ASSERT(self_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
if (self_kq_mask->type == GGML_TYPE_F16) {
fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
} else {
fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(self_kq_mask_swa);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
float * data = (float *) self_kq_mask_swa->data;
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
fill_mask(data, hparams.n_swa, hparams.swa_type);
if (debug) {
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
if (self_kq_mask_swa->type == GGML_TYPE_F16) {
fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
} else {
fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
}
}
}
@ -601,23 +600,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
float * data = (float *) cross_kq_mask->data;
const auto fill_mask = [&](auto * data) {
using T = std::remove_reference_t<decltype(*data)>;
for (int i = 0; i < n_tokens; ++i) {
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY;
for (int i = 0; i < n_tokens; ++i) {
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f;
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f;
}
}
}
data[i*n_enc + j] = f;
data[i*n_enc + j] = llama_cast<T>(f);
}
}
};
if (cross_kq_mask->type == GGML_TYPE_F16) {
fill_mask((ggml_fp16_t *) cross_kq_mask->data);
} else {
fill_mask((float *) cross_kq_mask->data);
}
}
@ -2121,17 +2127,20 @@ ggml_tensor * llm_graph_context::build_attn_mha(
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// flash attention requires an f16 mask
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
} else {
inp->self_kq_mask_swa = nullptr;
inp->self_kq_mask_swa_cnv = nullptr;
@ -2208,7 +2217,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
}
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
@ -2315,7 +2324,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
}
return inp;
@ -2479,10 +2488,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
// flash attention requires an f16 mask
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1);
ggml_set_input(inp->cross_kq_mask);
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
inp->cross_kq_mask_cnv = inp->cross_kq_mask;
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
}
@ -2543,7 +2555,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
}
{
@ -2553,7 +2565,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
}
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
@ -2722,7 +2734,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask;
}
{
@ -2730,7 +2742,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa;
}
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);