diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1a4fa6921..9ce4c4479 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -29,7 +29,10 @@ static ggml_tensor * build_attn_inp_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + // flash attention requires an f16 mask + const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(res); ggml_set_name(res, "attn_inp_kq_mask"); @@ -381,7 +384,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +template +static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); const char * swa_type_str = "unknown"; @@ -405,7 +409,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { LLAMA_LOG_DEBUG(" %2d ", i); for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { - float val = data[i * n_kv + j]; + float val = llama_cast(data[i * n_kv + j]); if (val == -INFINITY) { LLAMA_LOG_DEBUG(" ∞"); } else { @@ -420,7 +424,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { + const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) { + using T = std::remove_reference_t; + std::fill(data, data + ne, llama_cast(-INFINITY)); + for (int i1 = 0; i1 < n_tokens; ++i1) { const llama_seq_id s1 = ubatch->seq_id[i1][0]; const llama_pos p1 = ubatch->pos[i1]; @@ -446,38 +453,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { continue; } - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + data[idst + i0] = llama_cast(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f); } } + + if (debug) { + print_mask(data, n_tokens, n_kv, n_swa, swa_type); + } }; - { - GGML_ASSERT(self_kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - - float * data = (float *) self_kq_mask->data; - - std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); - - fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); - - if (debug) { - print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); - } + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + if (self_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + } else { + fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(self_kq_mask_swa); GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - - float * data = (float *) self_kq_mask_swa->data; - - std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); - - fill_mask(data, hparams.n_swa, hparams.swa_type); - - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + if (self_kq_mask_swa->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + } else { + fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); } } } @@ -601,23 +600,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing - float * data = (float *) cross_kq_mask->data; + const auto fill_mask = [&](auto * data) { + using T = std::remove_reference_t; + for (int i = 0; i < n_tokens; ++i) { + GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int i = 0; i < n_tokens; ++i) { - GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; - - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; + } } - } - data[i*n_enc + j] = f; + data[i*n_enc + j] = llama_cast(f); + } } + }; + + if (cross_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) cross_kq_mask->data); + } else { + fill_mask((float *) cross_kq_mask->data); } } @@ -2121,17 +2127,20 @@ ggml_tensor * llm_graph_context::build_attn_mha( llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const { auto inp = std::make_unique(hparams, cparams); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } else { inp->self_kq_mask_swa = nullptr; inp->self_kq_mask_swa_cnv = nullptr; @@ -2208,7 +2217,7 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); @@ -2315,7 +2324,7 @@ static std::unique_ptr build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } return inp; @@ -2479,10 +2488,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1); ggml_set_input(inp->cross_kq_mask); - inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + inp->cross_kq_mask_cnv = inp->cross_kq_mask; return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); } @@ -2543,7 +2555,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } { @@ -2553,7 +2565,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); @@ -2722,7 +2734,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask; } { @@ -2730,7 +2742,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa; } auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/src/llama-graph.h b/src/llama-graph.h index e240ade7b..9f4816959 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -291,10 +291,10 @@ public: ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } // n_tokens == n_batch - ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -324,8 +324,8 @@ public: ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] // note: assumes v_rot^2 == I ggml_tensor * self_k_rot = nullptr; @@ -364,8 +364,8 @@ public: ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -402,10 +402,10 @@ public: ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_k_rot = nullptr; ggml_tensor * self_v_rot = nullptr; @@ -428,8 +428,8 @@ public: ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; diff --git a/src/llama-impl.h b/src/llama-impl.h index e4f35c8e5..7923c3f7e 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -3,6 +3,7 @@ #include "ggml.h" // for ggml_log_level #include +#include #include #ifdef __GNUC__ @@ -40,6 +41,19 @@ struct no_init { no_init() = default; }; +template +static inline dst_t llama_cast(src_t v) { + if constexpr (std::is_same_v) { + return v; + } else if constexpr (std::is_same_v && std::is_same_v) { + return ggml_fp16_to_fp32(v); + } else if constexpr (std::is_same_v && std::is_same_v) { + return ggml_fp32_to_fp16(v); + } else { + static_assert(std::is_same_v, "unsupported type combination"); + } +} + struct time_meas { time_meas(int64_t & t_acc, bool disable = false); ~time_meas(); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a49a055a6..2356d612b 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1430,8 +1430,8 @@ struct args_set_input_kq_mask { int64_t n_tps; }; -template -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { //const auto & hparams = args.hparams; const auto & ubatch = args.ubatch; @@ -1445,6 +1445,9 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * const int64_t n_stream = args.n_stream; const int64_t n_tps = args.n_tps; + const T mask_keep = llama_cast(0.0f); + const T mask_drop = llama_cast(-INFINITY); + // the min position in the batch for each sequence llama_pos seq_pos_min[LLAMA_MAX_SEQ]; std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); @@ -1563,46 +1566,55 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * } if (alibi) { - data[idst + j] = -std::abs(p0 - p1); + data[idst + j] = llama_cast(static_cast(-std::abs(p0 - p1))); } else { - data[idst + j] = 0.0f; + data[idst + j] = mask_keep; } continue; skip: - data[idst + j] = -INFINITY; + data[idst + j] = mask_drop; } } } } -template -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool alibi = args.hparams.use_alibi; if (alibi) { - set_input_kq_mask_impl (args, data); + set_input_kq_mask_impl (args, data); } else { - set_input_kq_mask_impl(args, data); + set_input_kq_mask_impl(args, data); } } -template -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool is_2d = args.ubatch->is_pos_2d(); if (is_2d) { - set_input_kq_mask_impl (args, data); + set_input_kq_mask_impl (args, data); } else { - set_input_kq_mask_impl(args, data); + set_input_kq_mask_impl(args, data); } } -template -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; if (swa) { - set_input_kq_mask_impl (args, data); + set_input_kq_mask_impl (args, data); } else { - set_input_kq_mask_impl(args, data); + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data, bool causal_attn) { + if (causal_attn) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); } } @@ -1610,7 +1622,6 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; const int64_t n_kv = dst->ne[0]; const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch @@ -1634,10 +1645,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u /*.n_tps =*/ n_tps, }; - if (causal_attn) { - set_input_kq_mask_impl (args, data); + if (dst->type == GGML_TYPE_F16) { + set_input_kq_mask_impl(args, (ggml_fp16_t *) dst->data, causal_attn); } else { - set_input_kq_mask_impl(args, data); + set_input_kq_mask_impl(args, (float *) dst->data, causal_attn); } //const int64_t t_end = ggml_time_us();