mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-06-01 06:00:36 +00:00
llama: use f16 mask for FA to save VRAM (#23764)
* llama: use f16 mask for FA * review: add llama_cast + formatting * simplify
This commit is contained in:
parent
fe12e422ad
commit
031ddb2e08
4 changed files with 123 additions and 86 deletions
|
|
@ -29,7 +29,10 @@ static ggml_tensor * build_attn_inp_kq_mask(
|
|||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
// flash attention requires an f16 mask
|
||||
const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(res);
|
||||
ggml_set_name(res, "attn_inp_kq_mask");
|
||||
|
||||
|
|
@ -381,7 +384,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
template <typename T>
|
||||
static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
|
||||
const char * swa_type_str = "unknown";
|
||||
|
||||
|
|
@ -405,7 +409,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64
|
|||
for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
|
||||
LLAMA_LOG_DEBUG(" %2d ", i);
|
||||
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
|
||||
float val = data[i * n_kv + j];
|
||||
float val = llama_cast<float>(data[i * n_kv + j]);
|
||||
if (val == -INFINITY) {
|
||||
LLAMA_LOG_DEBUG(" ∞");
|
||||
} else {
|
||||
|
|
@ -420,7 +424,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
const int64_t n_kv = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
||||
const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
|
||||
using T = std::remove_reference_t<decltype(*data)>;
|
||||
std::fill(data, data + ne, llama_cast<T>(-INFINITY));
|
||||
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
const llama_pos p1 = ubatch->pos[i1];
|
||||
|
|
@ -446,38 +453,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
continue;
|
||||
}
|
||||
|
||||
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, n_swa, swa_type);
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
GGML_ASSERT(self_kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
|
||||
|
||||
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
|
||||
}
|
||||
GGML_ASSERT(self_kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
if (self_kq_mask->type == GGML_TYPE_F16) {
|
||||
fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
|
||||
} else {
|
||||
fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
|
||||
}
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(self_kq_mask_swa);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask_swa->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
|
||||
|
||||
fill_mask(data, hparams.n_swa, hparams.swa_type);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
||||
if (self_kq_mask_swa->type == GGML_TYPE_F16) {
|
||||
fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
|
||||
} else {
|
||||
fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -601,23 +600,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|||
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
||||
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
||||
|
||||
float * data = (float *) cross_kq_mask->data;
|
||||
const auto fill_mask = [&](auto * data) {
|
||||
using T = std::remove_reference_t<decltype(*data)>;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
|
||||
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
||||
f = 0.0f;
|
||||
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data[i*n_enc + j] = f;
|
||||
data[i*n_enc + j] = llama_cast<T>(f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (cross_kq_mask->type == GGML_TYPE_F16) {
|
||||
fill_mask((ggml_fp16_t *) cross_kq_mask->data);
|
||||
} else {
|
||||
fill_mask((float *) cross_kq_mask->data);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2121,17 +2127,20 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
|
||||
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
||||
|
||||
// flash attention requires an f16 mask
|
||||
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
|
||||
} else {
|
||||
inp->self_kq_mask_swa = nullptr;
|
||||
inp->self_kq_mask_swa_cnv = nullptr;
|
||||
|
|
@ -2208,7 +2217,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
|
||||
|
|
@ -2315,7 +2324,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
|||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
}
|
||||
|
||||
return inp;
|
||||
|
|
@ -2479,10 +2488,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|||
|
||||
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
|
||||
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
|
||||
// flash attention requires an f16 mask
|
||||
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->cross_kq_mask);
|
||||
|
||||
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
||||
inp->cross_kq_mask_cnv = inp->cross_kq_mask;
|
||||
|
||||
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
|
@ -2543,7 +2555,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2553,7 +2565,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
|
||||
|
|
@ -2722,7 +2734,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2730,7 +2742,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue