mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-30 20:33:39 +00:00
llama: use f16 mask for FA to save VRAM (#23764)
* llama: use f16 mask for FA * review: add llama_cast + formatting * simplify
This commit is contained in:
parent
fe12e422ad
commit
031ddb2e08
4 changed files with 123 additions and 86 deletions
|
|
@ -29,7 +29,10 @@ static ggml_tensor * build_attn_inp_kq_mask(
|
|||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
// flash attention requires an f16 mask
|
||||
const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(res);
|
||||
ggml_set_name(res, "attn_inp_kq_mask");
|
||||
|
||||
|
|
@ -381,7 +384,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
template <typename T>
|
||||
static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
|
||||
const char * swa_type_str = "unknown";
|
||||
|
||||
|
|
@ -405,7 +409,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64
|
|||
for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
|
||||
LLAMA_LOG_DEBUG(" %2d ", i);
|
||||
for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
|
||||
float val = data[i * n_kv + j];
|
||||
float val = llama_cast<float>(data[i * n_kv + j]);
|
||||
if (val == -INFINITY) {
|
||||
LLAMA_LOG_DEBUG(" ∞");
|
||||
} else {
|
||||
|
|
@ -420,7 +424,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
const int64_t n_kv = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
||||
const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
|
||||
using T = std::remove_reference_t<decltype(*data)>;
|
||||
std::fill(data, data + ne, llama_cast<T>(-INFINITY));
|
||||
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
const llama_pos p1 = ubatch->pos[i1];
|
||||
|
|
@ -446,38 +453,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||
continue;
|
||||
}
|
||||
|
||||
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, n_swa, swa_type);
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
GGML_ASSERT(self_kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
|
||||
|
||||
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
|
||||
}
|
||||
GGML_ASSERT(self_kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
if (self_kq_mask->type == GGML_TYPE_F16) {
|
||||
fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
|
||||
} else {
|
||||
fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
|
||||
}
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(self_kq_mask_swa);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask_swa->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
|
||||
|
||||
fill_mask(data, hparams.n_swa, hparams.swa_type);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
||||
if (self_kq_mask_swa->type == GGML_TYPE_F16) {
|
||||
fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
|
||||
} else {
|
||||
fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -601,23 +600,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|||
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
||||
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
||||
|
||||
float * data = (float *) cross_kq_mask->data;
|
||||
const auto fill_mask = [&](auto * data) {
|
||||
using T = std::remove_reference_t<decltype(*data)>;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
|
||||
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
||||
f = 0.0f;
|
||||
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data[i*n_enc + j] = f;
|
||||
data[i*n_enc + j] = llama_cast<T>(f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (cross_kq_mask->type == GGML_TYPE_F16) {
|
||||
fill_mask((ggml_fp16_t *) cross_kq_mask->data);
|
||||
} else {
|
||||
fill_mask((float *) cross_kq_mask->data);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2121,17 +2127,20 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
|
||||
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
||||
|
||||
// flash attention requires an f16 mask
|
||||
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
|
||||
} else {
|
||||
inp->self_kq_mask_swa = nullptr;
|
||||
inp->self_kq_mask_swa_cnv = nullptr;
|
||||
|
|
@ -2208,7 +2217,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
|
||||
|
|
@ -2315,7 +2324,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
|||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
}
|
||||
|
||||
return inp;
|
||||
|
|
@ -2479,10 +2488,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|||
|
||||
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
|
||||
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
|
||||
// flash attention requires an f16 mask
|
||||
const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1);
|
||||
ggml_set_input(inp->cross_kq_mask);
|
||||
|
||||
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
||||
inp->cross_kq_mask_cnv = inp->cross_kq_mask;
|
||||
|
||||
return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
|
@ -2543,7 +2555,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask_cnv = inp->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2553,7 +2565,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
|
||||
|
|
@ -2722,7 +2734,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2730,7 +2742,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
|
|
|||
|
|
@ -291,10 +291,10 @@ public:
|
|||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
// n_tokens == n_batch
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
|
@ -324,8 +324,8 @@ public:
|
|||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
// note: assumes v_rot^2 == I
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
|
|
@ -364,8 +364,8 @@ public:
|
|||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
|
@ -402,10 +402,10 @@ public:
|
|||
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
|
|
@ -428,8 +428,8 @@ public:
|
|||
|
||||
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
|
||||
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
|
||||
ggml_tensor * cross_kq_mask = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1]
|
||||
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1]
|
||||
|
||||
const llama_cross * cross = nullptr;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include "ggml.h" // for ggml_log_level
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __GNUC__
|
||||
|
|
@ -40,6 +41,19 @@ struct no_init {
|
|||
no_init() = default;
|
||||
};
|
||||
|
||||
template <typename dst_t, typename src_t>
|
||||
static inline dst_t llama_cast(src_t v) {
|
||||
if constexpr (std::is_same_v<src_t, dst_t>) {
|
||||
return v;
|
||||
} else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) {
|
||||
return ggml_fp16_to_fp32(v);
|
||||
} else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) {
|
||||
return ggml_fp32_to_fp16(v);
|
||||
} else {
|
||||
static_assert(std::is_same_v<dst_t, void>, "unsupported type combination");
|
||||
}
|
||||
}
|
||||
|
||||
struct time_meas {
|
||||
time_meas(int64_t & t_acc, bool disable = false);
|
||||
~time_meas();
|
||||
|
|
|
|||
|
|
@ -1430,8 +1430,8 @@ struct args_set_input_kq_mask {
|
|||
int64_t n_tps;
|
||||
};
|
||||
|
||||
template<bool causal, bool swa, bool is_2d, bool alibi>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
template<typename T, bool causal, bool swa, bool is_2d, bool alibi>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
|
||||
//const auto & hparams = args.hparams;
|
||||
const auto & ubatch = args.ubatch;
|
||||
|
||||
|
|
@ -1445,6 +1445,9 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
|
|||
const int64_t n_stream = args.n_stream;
|
||||
const int64_t n_tps = args.n_tps;
|
||||
|
||||
const T mask_keep = llama_cast<T>(0.0f);
|
||||
const T mask_drop = llama_cast<T>(-INFINITY);
|
||||
|
||||
// the min position in the batch for each sequence
|
||||
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
|
||||
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
|
||||
|
|
@ -1563,46 +1566,55 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float *
|
|||
}
|
||||
|
||||
if (alibi) {
|
||||
data[idst + j] = -std::abs(p0 - p1);
|
||||
data[idst + j] = llama_cast<T>(static_cast<float>(-std::abs(p0 - p1)));
|
||||
} else {
|
||||
data[idst + j] = 0.0f;
|
||||
data[idst + j] = mask_keep;
|
||||
}
|
||||
|
||||
continue;
|
||||
skip:
|
||||
data[idst + j] = -INFINITY;
|
||||
data[idst + j] = mask_drop;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal, bool swa, bool is_2d>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
template<typename T, bool causal, bool swa, bool is_2d>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
|
||||
const bool alibi = args.hparams.use_alibi;
|
||||
if (alibi) {
|
||||
set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
|
||||
set_input_kq_mask_impl<T, causal, swa, is_2d, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
|
||||
set_input_kq_mask_impl<T, causal, swa, is_2d, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal, bool swa>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
template<typename T, bool causal, bool swa>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
|
||||
const bool is_2d = args.ubatch->is_pos_2d();
|
||||
if (is_2d) {
|
||||
set_input_kq_mask_impl<causal, swa, true> (args, data);
|
||||
set_input_kq_mask_impl<T, causal, swa, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, swa, false>(args, data);
|
||||
set_input_kq_mask_impl<T, causal, swa, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool causal>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
|
||||
template<typename T, bool causal>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) {
|
||||
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
|
||||
if (swa) {
|
||||
set_input_kq_mask_impl<causal, true> (args, data);
|
||||
set_input_kq_mask_impl<T, causal, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<causal, false>(args, data);
|
||||
set_input_kq_mask_impl<T, causal, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data, bool causal_attn) {
|
||||
if (causal_attn) {
|
||||
set_input_kq_mask_impl<T, true> (args, data);
|
||||
} else {
|
||||
set_input_kq_mask_impl<T, false>(args, data);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1610,7 +1622,6 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
|||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
||||
float * data = (float *) dst->data;
|
||||
|
||||
const int64_t n_kv = dst->ne[0];
|
||||
const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
|
||||
|
|
@ -1634,10 +1645,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
|
|||
/*.n_tps =*/ n_tps,
|
||||
};
|
||||
|
||||
if (causal_attn) {
|
||||
set_input_kq_mask_impl<true> (args, data);
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
set_input_kq_mask_impl<ggml_fp16_t>(args, (ggml_fp16_t *) dst->data, causal_attn);
|
||||
} else {
|
||||
set_input_kq_mask_impl<false>(args, data);
|
||||
set_input_kq_mask_impl<float>(args, (float *) dst->data, causal_attn);
|
||||
}
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue