llama: use f16 mask for FA to save VRAM (#23764)

* llama: use f16 mask for FA

* review: add llama_cast + formatting

* simplify
This commit is contained in:
Aman Gupta 2026-05-29 15:44:43 +08:00 committed by GitHub
parent fe12e422ad
commit 031ddb2e08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 123 additions and 86 deletions

View file

@ -29,7 +29,10 @@ static ggml_tensor * build_attn_inp_kq_mask(
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
// flash attention requires an f16 mask
const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");
@ -381,7 +384,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
template <typename T>
static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
const char * swa_type_str = "unknown";
@ -405,7 +409,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64
for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
LLAMA_LOG_DEBUG(" %2d ", i);
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
float val = data[i * n_kv + j];
float val = llama_cast<float>(data[i * n_kv + j]);
if (val == -INFINITY) {
LLAMA_LOG_DEBUG("");
} else {
@ -420,7 +424,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
using T = std::remove_reference_t<decltype(*data)>;
std::fill(data, data + ne, llama_cast<T>(-INFINITY));
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
@ -446,38 +453,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
continue;
}
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
}
}
if (debug) {
print_mask(data, n_tokens, n_kv, n_swa, swa_type);
}
};
{
GGML_ASSERT(self_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
float * data = (float *) self_kq_mask->data;
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
if (debug) {
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
}
GGML_ASSERT(self_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
if (self_kq_mask->type == GGML_TYPE_F16) {
fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
} else {
fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(self_kq_mask_swa);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
float * data = (float *) self_kq_mask_swa->data;
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
fill_mask(data, hparams.n_swa, hparams.swa_type);
if (debug) {
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
if (self_kq_mask_swa->type == GGML_TYPE_F16) {
fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
} else {
fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
}
}
}
@ -601,23 +600,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
float * data = (float *) cross_kq_mask->data;
const auto fill_mask = [&](auto * data) {
using T = std::remove_reference_t<decltype(*data)>;
for (int i = 0; i < n_tokens; ++i) {
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY;
for (int i = 0; i < n_tokens; ++i) {
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f;
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f;
}
}
}
data[i*n_enc + j] = f;
data[i*n_enc + j] = llama_cast<T>(f);
}
}
};
if (cross_kq_mask->type == GGML_TYPE_F16) {
fill_mask((ggml_fp16_t *) cross_kq_mask->data);
} else {
fill_mask((float *) cross_kq_mask->data);
}
}
@ -2121,17 +2127,20 @@ ggml_tensor * llm_graph_context::build_attn_mha(
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// flash attention requires an f16 mask
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
} else {
inp->self_kq_mask_swa = nullptr;
inp->self_kq_mask_swa_cnv = nullptr;
@ -2208,7 +2217,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
}
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
@ -2315,7 +2324,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
}
return inp;
@ -2479,10 +2488,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
// flash attention requires an f16 mask
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1);
ggml_set_input(inp->cross_kq_mask);
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
inp->cross_kq_mask_cnv = inp->cross_kq_mask;
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
}
@ -2543,7 +2555,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
inp->self_kq_mask_cnv = inp->self_kq_mask;
}
{
@ -2553,7 +2565,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
}
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
@ -2722,7 +2734,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask;
}
{
@ -2730,7 +2742,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa;
}
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);

View file

@ -291,10 +291,10 @@ public:
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
// n_tokens == n_batch
ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
const llama_hparams hparams;
const llama_cparams cparams;
@ -324,8 +324,8 @@ public:
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: assumes v_rot^2 == I
ggml_tensor * self_k_rot = nullptr;
@ -364,8 +364,8 @@ public:
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams hparams;
const llama_cparams cparams;
@ -402,10 +402,10 @@ public:
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
@ -428,8 +428,8 @@ public:
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
ggml_tensor * cross_kq_mask = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1]
const llama_cross * cross = nullptr;
};

View file

@ -3,6 +3,7 @@
#include "ggml.h" // for ggml_log_level
#include <string>
#include <type_traits>
#include <vector>
#ifdef __GNUC__
@ -40,6 +41,19 @@ struct no_init {
no_init() = default;
};
template <typename dst_t, typename src_t>
static inline dst_t llama_cast(src_t v) {
if constexpr (std::is_same_v<src_t, dst_t>) {
return v;
} else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) {
return ggml_fp16_to_fp32(v);
} else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) {
return ggml_fp32_to_fp16(v);
} else {
static_assert(std::is_same_v<dst_t, void>, "unsupported type combination");
}
}
struct time_meas {
time_meas(int64_t & t_acc, bool disable = false);
~time_meas();

View file

@ -1430,8 +1430,8 @@ struct args_set_input_kq_mask {
int64_t n_tps;
};
template<bool causal, bool swa, bool is_2d, bool alibi>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
template<typename T, bool causal, bool swa, bool is_2d, bool alibi>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
//const auto & hparams = args.hparams;
const auto & ubatch = args.ubatch;
@ -1445,6 +1445,9 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
const int64_t n_stream = args.n_stream;
const int64_t n_tps = args.n_tps;
const T mask_keep = llama_cast<T>(0.0f);
const T mask_drop = llama_cast<T>(-INFINITY);
// the min position in the batch for each sequence
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
@ -1563,46 +1566,55 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
}
if (alibi) {
data[idst + j] = -std::abs(p0 - p1);
data[idst + j] = llama_cast<T>(static_cast<float>(-std::abs(p0 - p1)));
} else {
data[idst + j] = 0.0f;
data[idst + j] = mask_keep;
}
continue;
skip:
data[idst + j] = -INFINITY;
data[idst + j] = mask_drop;
}
}
}
}
template<bool causal, bool swa, bool is_2d>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
template<typename T, bool causal, bool swa, bool is_2d>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
const bool alibi = args.hparams.use_alibi;
if (alibi) {
set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
set_input_kq_mask_impl<T, causal, swa, is_2d, true> (args, data);
} else {
set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
set_input_kq_mask_impl<T, causal, swa, is_2d, false>(args, data);
}
}
template<bool causal, bool swa>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
template<typename T, bool causal, bool swa>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
const bool is_2d = args.ubatch->is_pos_2d();
if (is_2d) {
set_input_kq_mask_impl<causal, swa, true> (args, data);
set_input_kq_mask_impl<T, causal, swa, true> (args, data);
} else {
set_input_kq_mask_impl<causal, swa, false>(args, data);
set_input_kq_mask_impl<T, causal, swa, false>(args, data);
}
}
template<bool causal>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
template<typename T, bool causal>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
if (swa) {
set_input_kq_mask_impl<causal, true> (args, data);
set_input_kq_mask_impl<T, causal, true> (args, data);
} else {
set_input_kq_mask_impl<causal, false>(args, data);
set_input_kq_mask_impl<T, causal, false>(args, data);
}
}
template<typename T>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data, bool causal_attn) {
if (causal_attn) {
set_input_kq_mask_impl<T, true> (args, data);
} else {
set_input_kq_mask_impl<T, false>(args, data);
}
}
@ -1610,7 +1622,6 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;
const int64_t n_kv = dst->ne[0];
const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
@ -1634,10 +1645,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
/*.n_tps =*/ n_tps,
};
if (causal_attn) {
set_input_kq_mask_impl<true> (args, data);
if (dst->type == GGML_TYPE_F16) {
set_input_kq_mask_impl<ggml_fp16_t>(args, (ggml_fp16_t *) dst->data, causal_attn);
} else {
set_input_kq_mask_impl<false>(args, data);
set_input_kq_mask_impl<float>(args, (float *) dst->data, causal_attn);
}
//const int64_t t_end = ggml_time_us();