diff --git a/common/arg.cpp b/common/arg.cpp index 49c1f5f6c..bdcc875bb 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -624,10 +624,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { string_process_escapes(seq_breaker); } - for (auto & pair : params.speculative.draft.replacements) { - string_process_escapes(pair.first); - string_process_escapes(pair.second); - } } if (!params.kv_overrides.empty()) { @@ -3520,13 +3516,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.draft.p_min = std::stof(value); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN")); - add_opt(common_arg( - {"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N", - string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx), - [](common_params & params, int value) { - params.speculative.draft.n_ctx = value; - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE")); add_opt(common_arg( {"--spec-draft-device", "-devd", "--device-draft"}, "", "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" @@ -3563,32 +3552,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL")); add_opt(common_arg( - {"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT", - "translate the string in TARGET into DRAFT if the draft model and main model are not compatible", - [](common_params & params, const std::string & tgt, const std::string & dft) { - params.speculative.draft.replacements.push_back({ tgt, dft }); - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); - add_opt(common_arg( - {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", + {"--spec-type"}, common_speculative_all_types_str(), string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", - common_speculative_type_to_str(params.speculative.type).c_str()), + common_speculative_type_name_str(params.speculative.types).c_str()), [](common_params & params, const std::string & value) { - if (value == "none") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; - } else if (value == "ngram-cache") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE; - } else if (value == "ngram-simple") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE; - } else if (value == "ngram-map-k") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; - } else if (value == "ngram-map-k4v") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; - } else if (value == "ngram-mod") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; - } else { - throw std::invalid_argument("unknown speculative decoding type without draft model"); - } + const auto enabled_types = string_split(value, ','); + params.speculative.types = common_speculative_types_from_names(enabled_types); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE")); add_opt(common_arg( @@ -4077,7 +4046,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--spec-default"}, string_format("enable default speculative decoding config"), [](common_params & params) { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD }; params.speculative.ngram_mod.n_match = 24; params.speculative.ngram_mod.n_min = 48; params.speculative.ngram_mod.n_max = 64; diff --git a/common/common.cpp b/common/common.cpp index 12236bb2c..5ce1c4f52 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1428,7 +1428,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + LOG_WRN("%s: the context does not support partial sequence removal\n", __func__); res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL; goto done; } @@ -1966,3 +1966,102 @@ bool common_prompt_batch_decode( return true; } + +size_t common_prompt_checkpoint::size() const { + return data_tgt.size() + data_dft.size(); +} + +bool common_prompt_checkpoint::empty() const { + return data_tgt.empty(); +} + +void common_prompt_checkpoint::clear() { + n_tokens = 0; + + pos_min = 0; + pos_max = 0; + + data_tgt.clear(); + data_dft.clear(); +} + +void common_prompt_checkpoint::update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max) { + this->n_tokens = n_tokens; + this->pos_min = pos_min; + this->pos_max = pos_max; +} + +void common_prompt_checkpoint::update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_tgt.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_dft.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_tgt.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags); + if (n != data_tgt.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n); + } +} + +void common_prompt_checkpoint::load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_dft.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags); + if (n != data_dft.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n); + } +} diff --git a/common/common.h b/common/common.h index 3cc27a40a..a635d17c6 100644 --- a/common/common.h +++ b/common/common.h @@ -296,8 +296,6 @@ struct common_params_model { std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; -struct common_ngram_mod; - // draft-model-based speculative decoding parameters struct common_params_speculative_draft { int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding @@ -308,11 +306,9 @@ struct common_params_speculative_draft { common_params_model mparams; - llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; - llama_context_params cparams; // these are the parameters for the draft llama_context - - int32_t n_ctx = 0; // draft context size int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K @@ -323,7 +319,6 @@ struct common_params_speculative_draft { std::vector devices; // devices to use for offloading - std::vector> replacements; // main to speculative model replacements std::vector tensor_buft_overrides; }; @@ -332,9 +327,6 @@ struct common_params_speculative_ngram_mod { int32_t n_max = 64; int32_t n_min = 48; - - // shared instance of the ngram container for all speculative decoding contexts - std::shared_ptr obj; }; struct common_params_speculative_ngram_map { @@ -349,8 +341,7 @@ struct common_params_speculative_ngram_cache { }; struct common_params_speculative { - // TODO: become a vector in order to support "chains of speculators" - common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; + std::vector types = { COMMON_SPECULATIVE_TYPE_NONE }; common_params_speculative_draft draft; @@ -1027,3 +1018,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std // "adamw" or "sgd" (case insensitive) enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); + +// +// prompt utils +// + +struct common_prompt_checkpoint { + int64_t n_tokens; + + llama_pos pos_min; + llama_pos pos_max; + + std::vector data_tgt; + std::vector data_dft; + + size_t size() const; + + bool empty() const; + void clear(); + + void update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max); + + void update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; + + void load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; +}; diff --git a/common/speculative.cpp b/common/speculative.cpp index e9fa751e2..e487e003d 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -10,6 +10,7 @@ #include "sampling.h" #include +#include #include #include #include @@ -18,26 +19,15 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -const std::vector common_speculative_types = { - COMMON_SPECULATIVE_TYPE_NONE, - COMMON_SPECULATIVE_TYPE_DRAFT, - COMMON_SPECULATIVE_TYPE_EAGLE3, - COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, - COMMON_SPECULATIVE_TYPE_NGRAM_MOD, - COMMON_SPECULATIVE_TYPE_NGRAM_CACHE -}; - -const std::map common_speculative_type_from_name_map = { +const std::map common_speculative_type_from_name_map = { {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, - {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, - {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, - {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, - {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, - {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} + {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, + {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, + {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, + {"ngram-mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, + {"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} }; struct common_speculative_config { @@ -115,12 +105,16 @@ static bool common_speculative_are_compatible( return true; } +using common_speculative_draft_params_vec = std::vector; + // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific -// in a subclass of common_speculative_state -struct common_speculative_state { - const enum common_speculative_type type; +// in a subclass of common_speculative_impl +struct common_speculative_impl { + const common_speculative_type type; + + uint32_t n_seq; size_t n_call_begin = 0; // number of times this implementation was called for refresh. size_t n_call_draft = 0; // number of times this implementation was called for generation. @@ -138,65 +132,34 @@ struct common_speculative_state { int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. - common_speculative_state(enum common_speculative_type type) : type(type) {} + common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {} - virtual ~common_speculative_state() = default; + virtual ~common_speculative_impl() = default; - virtual void begin(const llama_tokens & prompt) = 0; + virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; - virtual void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) = 0; + virtual bool process(const llama_batch & batch) = 0; - virtual void accept(uint16_t n_accepted) = 0; + virtual void draft(common_speculative_draft_params_vec & dparams) = 0; - virtual int32_t n_max(const common_params_speculative & params) const = 0; - virtual int32_t n_min(const common_params_speculative & params) const = 0; + virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; -struct common_speculative_checkpoint { - llama_pos pos_min = 0; - llama_pos pos_max = 0; +struct common_speculative_state_draft : public common_speculative_impl { + common_params_speculative_draft params; - int64_t n_tokens = 0; + llama_batch batch; - std::vector data; + std::vector smpls; - size_t size() const { - return data.size(); - } -}; - -struct common_speculative_state_draft : public common_speculative_state { - llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - llama_context * ctx_dft; - - bool use_ckpt = false; - common_speculative_checkpoint ckpt; - - common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - - bool vocab_cmpt = true; // whether retokenization is needed - std::unordered_map vocab_map; - - common_speculative_state_draft( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - const std::vector> & replacements, - bool use_ckpt) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - , use_ckpt(use_ckpt) + common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq) + , params(params.draft) { + auto * ctx_dft = this->params.ctx_dft; + auto * ctx_tgt = this->params.ctx_tgt; + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - smpl = nullptr; // TODO: optimize or pass from outside? // { @@ -214,7 +177,9 @@ struct common_speculative_state_draft : public common_speculative_state { // // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); // } - { + + smpls.resize(n_seq); + for (auto & smpl : smpls) { common_params_sampling params; params.no_perf = false; params.top_k = 10; @@ -222,482 +187,321 @@ struct common_speculative_state_draft : public common_speculative_state { COMMON_SAMPLER_TYPE_TOP_K, }; - smpl = common_sampler_init(llama_get_model(ctx_dft), params); + smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params)); } - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt); if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__); - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } + throw std::runtime_error("draft model vocab type must match target model to use speculation"); + } + + if (n_seq != llama_n_seq_max(ctx_dft)) { + LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft)); + + throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq"); } } ~common_speculative_state_draft() override { - llama_perf_context_print(ctx_dft); - - llama_free(ctx_dft); - - common_sampler_free(smpl); - llama_batch_free(batch); } - void begin(const llama_tokens & /*prompt*/) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - size_t create_checkpoint(int n_tokens_prompt) { - int slot_id = 0; - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + bool process(const llama_batch & batch) override { + auto * ctx_dft = params.ctx_dft; - ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); - ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); - ckpt.n_tokens = n_tokens_prompt; - ckpt.data.resize(checkpoint_size); + const int ret = llama_decode(ctx_dft, batch); - const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + if (ret != 0) { + LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret); + + return false; } - LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, - ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); - return n; + return true; } - size_t restore_checkpoint() { - int slot_id = 0; - LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); - const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size()); - } - llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1); + void draft(common_speculative_draft_params_vec & dparams) override { + auto & ctx_dft = params.ctx_dft; - return n; - } + common_batch_clear(batch); - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - const auto & sparams = params.draft; + // keep track of which sequences are still drafting + int n_drafting = 0; + std::vector drafting(n_seq); - auto * spec = this; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; + if (!dp.drafting) { + continue; + } - auto * mem_dft = llama_get_memory(ctx_dft); + n_drafting++; + drafting[seq_id] = true; + common_sampler_reset(smpls[seq_id].get()); - int reuse_i = 0; // index of part to be reused in prompt_dft - int reuse_n = 0; // length of part to be reused in prompt_dft - - const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max; - - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); - - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - - prompt_cnv = common_tokenize(ctx_dft, text, false, true); - - // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); - - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; + common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); } - const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; - - const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); - - if (use_ckpt && i_start > 0) { - LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { - cur++; - } + int i = 0; - if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; - } + while (n_drafting > 0) { + int i_batch = 0; - if (use_ckpt) { - break; - } - } - - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", - __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); - if (use_ckpt && ckpt.n_tokens > reuse_n) { - LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens); - - reuse_i = 0; - reuse_n = 0; - - ckpt = {}; - } - - result.clear(); - result.reserve(sparams.n_max); - - if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) { - llama_memory_clear(mem_dft, false); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (sparams.n_max <= (int) result.size()) { - break; - } - } - - return; - } - - if (reuse_i > 0) { - GGML_ASSERT(!use_ckpt); - - bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); - return; - } - llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); - - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); - } - - if (reuse_n < (int) prompt_dft.size()) { - if (use_ckpt) { - if (ckpt.n_tokens > 0) { - LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - restore_checkpoint(); - reuse_n = ckpt.n_tokens; - prompt_dft.resize(reuse_n); - } - } else { - const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - return; - } - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); - } - } - } - - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); - common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_cur[i]); - } - - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens); - - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", - __func__, ret, prompt_cur.size()); - } - - if (use_ckpt) { - create_checkpoint(prompt_dft.size()); - } - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, ret, prompt_cur.size(), prompt_dft.size()); - } - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < sparams.n_max; ++i) { common_batch_clear(batch); - common_sampler_sample(smpl, ctx_dft, 0, true); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (!drafting[seq_id]) { + continue; + } - const auto * cur_p = common_sampler_get_candidates(smpl, true); + auto * smpl = smpls[seq_id].get(); - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + common_sampler_sample(smpl, ctx_dft, i_batch, true); + ++i_batch; + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, + common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + drafting[seq_id] = false; + n_drafting--; + + continue; + } + + common_sampler_accept(smpl, id, true); + + auto & dp = dparams.at(seq_id); + auto & result = *dp.result; + + result.push_back(id); + + if ((params.n_max <= (int) result.size()) || + (dp.n_max > 0 && dp.n_max <= (int) result.size())) { + drafting[seq_id] = false; + n_drafting--; + continue; + } + + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); } - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl, id, true); - - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < sparams.p_min) { + if (batch.n_tokens == 0) { break; } - result.push_back(id); - - if (sparams.n_max <= (int) result.size()) { - break; - } - - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); - // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { - LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; } - prompt_dft.push_back(id); + ++i; } - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t) sparams.n_max) { - result.resize(sparams.n_max); + for (auto & dp : dparams) { + if (!dp.drafting) { + continue; } - } - if (result.size() < (size_t) sparams.n_min) { - result.clear(); + if (dp.result->size() < (size_t) params.n_min) { + dp.result->clear(); + } } } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; - } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; - } - - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - - return result; } }; -struct common_speculative_state_eagle3 : public common_speculative_state { - common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} +struct common_speculative_state_eagle3 : public common_speculative_impl { + //common_params_speculative_eagle3 params; - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } + common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {} - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & draft_tokens) override { - // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); - } - - void accept(uint16_t n_accepted) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop - GGML_UNUSED(n_accepted); } - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; + void draft(common_speculative_draft_params_vec & /*dparams*/) override { + // TODO: implement + } + + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; // state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { +struct common_speculative_state_ngram_simple : public common_speculative_impl { + common_params_speculative_ngram_map params; + + // shared across all sequences common_ngram_simple_config config; common_speculative_state_ngram_simple( - enum common_speculative_type type, + const common_params_speculative & params, uint32_t n_seq, common_ngram_simple_config config) - : common_speculative_state(type), config(config) {} + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) + , params(params.ngram_simple) + , config(config) {} - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - result = common_ngram_simple_draft(config, prompt_tgt, id_last); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop - GGML_UNUSED(n_accepted); } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + *dp.result = common_ngram_simple_draft(config, *dp.prompt, dp.id_last); + } + } + + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; -struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map config; +struct common_speculative_state_ngram_map_k : public common_speculative_impl { + common_params_speculative_ngram_map params; + + // n_seq configs + std::vector config; common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map config) - : common_speculative_state(type), config(std::move(config)) {} - - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(config, prompt); - } - - void draft( const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - common_ngram_map_draft(config, prompt_tgt, id_last, result); - GGML_UNUSED(params); + const common_ngram_map & config, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) + , params(params.ngram_map_k) { + for (uint32_t i = 0; i < n_seq; i++) { + this->config.push_back(config); + } } - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(config, n_accepted); + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + GGML_ASSERT(seq_id < (llama_seq_id) n_seq); + + common_ngram_map_begin(config[seq_id], prompt); } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_value; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_value; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + common_ngram_map_draft(config[seq_id], *dp.prompt, dp.id_last, *dp.result); + } + } + + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + GGML_ASSERT((seq_id < (llama_seq_id) config.size())); + + common_ngram_map_accept(config[seq_id], n_accepted); } }; -struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; +struct common_speculative_state_ngram_mod : public common_speculative_impl { + common_params_speculative_ngram_mod params; - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; - - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; - - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; + // shared across all sequences + common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + struct seq_info { + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + }; + + std::vector sinfos; + + common_speculative_state_ngram_mod( + const common_params_speculative & params, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) + , params(params.ngram_mod) + , mod(params.ngram_mod.n_match, 4*1024*1024) + , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + + LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, + this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); + + if (this->params.n_match < 16) { + LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " + "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match); + } + + sinfos.resize(n_seq); } - void begin(const llama_tokens & prompt) override { - i_last = 0; + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + auto & sinfo = sinfos[seq_id]; - n_draft_last = 0; + sinfo.i_last = 0; + sinfo.n_draft_last = 0; const size_t n = mod.get_n(); - if (prompt.size() < n) { return; } @@ -706,7 +510,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { mod.add(prompt.data() + i); } - i_last = prompt.size() - n; + sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); @@ -719,16 +523,17 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - const auto & sparams = params.ngram_mod; + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; - n_draft_last = 0; + const auto & prompt = *dparams.prompt; - const size_t cur_len = prompt_tgt.size(); + sinfo.n_draft_last = 0; + + const size_t cur_len = prompt.size(); if (cur_len < mod.get_n()) { return; } @@ -736,24 +541,24 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { const size_t n = mod.get_n(); // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = i_last; i < cur_len - n; ++i) { - mod.add(prompt_tgt.data() + i); + if (sinfo.i_last + 32 < cur_len) { + for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { + mod.add(prompt.data() + i); } - i_last = cur_len - n; + sinfo.i_last = cur_len - n; } - result.resize(n + sparams.n_max); + result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { - result[i] = prompt_tgt[cur_len - n + 1 + i]; + result[i] = prompt.at(cur_len - n + 1 + i); } - result[n - 1] = id_last; + result[n - 1] = dparams.id_last; - for (int i = 0; i < sparams.n_max; ++i) { + for (int i = 0; i < params.n_max; ++i) { const llama_token token = mod.get(result.data() + i); if (token == common_ngram_mod::EMPTY) { - if (i < sparams.n_min) { + if (i < params.n_min) { result.clear(); return; } @@ -771,65 +576,92 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { result.resize(result.size() - n); // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); + sinfo.n_draft_last = result.size(); } - void accept(uint16_t n_accepted) override { - // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { - if (verbose) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); - } + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; + } - mod.reset(); - n_low = 0; - i_last = 0; - } - } else { - n_low = 0; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; } + + draft_one(seq_id, dp); } } - int32_t n_max(const common_params_speculative & params) const override { - return params.ngram_mod.n_max; - } + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + auto & sinfo = sinfos[seq_id]; - int32_t n_min(const common_params_speculative & params) const override { - return params.ngram_mod.n_min; + // compute acceptance fraction if we have a recorded draft length + if (sinfo.n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; + if (f_acc < 0.5) { + sinfo.n_low++; + if (sinfo.n_low >= 3) { + if (verbose) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); + } + + mod.reset(); + sinfo.n_low = 0; + sinfo.i_last = 0; + } + } else { + sinfo.n_low = 0; + } + } } }; -struct common_speculative_state_ngram_cache : public common_speculative_state { +struct common_speculative_state_ngram_cache : public common_speculative_impl { + common_params_speculative_ngram_cache params; + uint16_t n_draft; + bool save_dynamic; bool save_static; - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; + struct seq_info { + size_t cache_size = 0; // number of tokens in n-gram cache - size_t cache_size = 0; // number of tokens in n-gram cache + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + }; + + std::vector sinfos; common_speculative_state_ngram_cache( - const enum common_speculative_type type, + const common_params_speculative & params, + uint32_t n_seq, + uint16_t n_draft, const std::string & path_static, const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) - : common_speculative_state(type) + bool save_dynamic, + bool save_static) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) + , params(params.ngram_cache) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { + sinfos.resize(n_seq); + if (!path_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(path_static); + auto ngram_cache_static = common_ngram_cache_load(path_static); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_static = ngram_cache_static; + } } catch (...) { LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); @@ -838,7 +670,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { if (!path_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_dynamic = ngram_cache_dynamic; + } } catch (...) { LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); @@ -846,44 +682,48 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; - if (cache_size < prompt_tgt.size() + 1) { + const auto & prompt = *dparams.prompt; + + if (sinfo.cache_size < prompt.size() + 1) { llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); + tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size); + for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) { + tokens_new.push_back(prompt[j]); } - tokens_new.push_back(id_last); // add the last token + tokens_new.push_back(dparams.id_last); // add the last token - // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + // Update context ngram cache with new dparams.prompt: + common_ngram_cache_update( + sinfo.ngram_cache_context, + LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; + sinfo.cache_size = prompt.size() + 1; } llama_tokens inp; - inp.reserve(prompt_tgt.size() + 1); - for (size_t j = 0; j < prompt_tgt.size(); ++j) { - inp.push_back(prompt_tgt[j]); + inp.reserve(prompt.size() + 1); + for (size_t j = 0; j < prompt.size(); ++j) { + inp.push_back(prompt[j]); } - inp.push_back(id_last); + inp.push_back(dparams.id_last); - result.push_back(id_last); + result.push_back(dparams.id_last); - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); + common_ngram_cache_draft( + inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + sinfo.ngram_cache_context, + sinfo.ngram_cache_dynamic, + sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) @@ -891,24 +731,37 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return n_draft; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + draft_one(seq_id, dp); + } } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return 0; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; struct common_speculative { - std::vector> impls; // list of implementations to use and their states + common_speculative_draft_params_vec dparams; - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + // list of implementations to use and their states + std::vector> impls; + + // which implementaion was used for a given seq_id + std::vector impl_last; }; static common_ngram_map get_common_ngram_map( @@ -923,45 +776,79 @@ static common_ngram_map get_common_ngram_map( } static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & path_static, const std::string & path_dynamic, - const common_speculative_config & config) { + const common_speculative_config & config, + uint32_t n_seq, + const std::string & path_static, + const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic? bool save_static = false; bool save_dynamic = false; - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + common_speculative_state_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } -std::string common_speculative_type_name_str() { +std::string common_speculative_type_name_str(const std::vector & types) { std::string result; - for (size_t i = 0; i < common_speculative_types.size(); i++) { + + for (size_t i = 0; i < types.size(); i++) { if (i > 0) { - result += ", "; + result += ","; } - result += common_speculative_type_to_str(common_speculative_types[i]); + result += common_speculative_type_to_str(types[i]); } return result; } -std::string common_speculative_type_to_str(enum common_speculative_type type) { +const char * common_speculative_all_types_str() { + static std::string all_types_str = []() { + std::vector types; + types.reserve(COMMON_SPECULATIVE_TYPE_COUNT); + for (int i = 0; i < COMMON_SPECULATIVE_TYPE_COUNT; i++) { + types.push_back((common_speculative_type) i); + } + return common_speculative_type_name_str(types); + }(); + return all_types_str.c_str(); +} + +std::string common_speculative_type_to_str(common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram-mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram-cache"; default: return "unknown"; } } -enum common_speculative_type common_speculative_type_from_name(const std::string & name) { +std::vector common_speculative_types_from_names(const std::vector & names) { + std::vector types; + types.reserve(names.size()); + + for (const auto & name : names) { + auto type = common_speculative_type_from_name_map.find(name); + if (type != common_speculative_type_from_name_map.end()) { + if (type->second == COMMON_SPECULATIVE_TYPE_NONE) { + return std::vector { COMMON_SPECULATIVE_TYPE_NONE }; + } + types.push_back(type->second); + continue; + } + throw std::invalid_argument("unknown speculative type: " + name); + } + + return types; +} + +common_speculative_type common_speculative_type_from_name(const std::string & name) { const auto it = common_speculative_type_from_name_map.find(name); if (it == common_speculative_type_from_name_map.end()) { return COMMON_SPECULATIVE_TYPE_COUNT; @@ -969,34 +856,39 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } +static uint32_t common_get_enabled_speculative_configs(const std::vector & configs) { + uint32_t result = 0; + for (size_t i = 0; i < configs.size(); i++) { + result |= (1u << configs[i]); + } + return result; +} + // initialization of the speculative decoding system // -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt) { - llama_context * ctx_dft = nullptr; - if (params.draft.model) { - ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; - } - } - +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { - bool has_draft = !params.draft.mparams.path.empty(); + uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types); + + bool has_draft = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT)); + bool has_draft_model = !params.draft.mparams.path.empty(); + + // bool has_mtp = false; // TODO: add MTP here bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 - bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); - bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); - bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); - bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); + bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); + bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); + bool has_ngram_map_k = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K)); + bool has_ngram_map_k4v = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V)); + bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); - // In a more complex implementation we could use the same implementation but with different parameters. - // This was initially used in PR-18471 but removed to simplify the code. + // when adding a new type - update here the logic above + static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8); + + // this list here defines the priority of the speculators + // the one with highest priority are listed first if (has_ngram_simple) { // This implementation can guess a lot of tokens without any draft model. configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); @@ -1009,53 +901,43 @@ common_speculative * common_speculative_init( configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { - auto & sparams = params.ngram_mod; - - if (!sparams.obj) { - sparams.obj = std::make_shared(sparams.n_match, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, - sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024); - - if (sparams.n_match < 16) { - LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " - "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match); - } - } - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_draft) { + if (!has_draft_model) { + LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__); + has_draft = false; + } + } else if (has_draft_model) { + LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__); + has_draft = true; + } + if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } + // TODO: add MTP here if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); } } - std::vector> impls = {}; + std::vector> impls = {}; for (const common_speculative_config & config : configs) { - LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str()); switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { - const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - - impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.draft.replacements, - /* .use_ckpt = */ use_ckpt - )); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { @@ -1069,27 +951,30 @@ common_speculative * common_speculative_init( /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( - /* .type = */ config.type, - /* .state = */ config_simple + /* .params = */ config.params, + /* .n_seq = */ n_seq, + /* .state = */ config_simple ); impls.push_back(std::move(state)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config.type, config.params.ngram_map_k) - )); + impls.push_back( + std::make_unique( + config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod.obj); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod.obj)); + impls.push_back( + std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config); + auto state = create_state_ngram_cache( + config, n_seq, + params.ngram_cache.lookup_cache_static, + params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } @@ -1104,8 +989,9 @@ common_speculative * common_speculative_init( } auto * result = new common_speculative { + /* .dparams = */ common_speculative_draft_params_vec(n_seq), /* .impls = */ std::move(impls), - /* .curr_impl = */ nullptr, + /* .impl_last = */ std::vector(n_seq, nullptr) }; return result; @@ -1119,65 +1005,128 @@ void common_speculative_free(common_speculative * spec) { delete spec; } -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { +common_speculative_draft_params & common_speculative_get_draft_params( + common_speculative * spec, + llama_seq_id seq_id) { + GGML_ASSERT(spec); + GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size()); + + return spec->dparams[seq_id]; +} + +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); + impl->begin(seq_id, prompt); impl->n_call_begin++; } } -llama_tokens common_speculative_draft( - common_speculative * spec, - const common_params_speculative & params, - const llama_tokens & prompt_tgt, // specified in target model vocab - llama_token id_last) { - llama_tokens result; +bool common_speculative_process(common_speculative * spec, const llama_batch & batch) { + bool result = true; - spec->curr_impl = nullptr; // reset current implementation + if (spec == nullptr) { + return result; + } for (auto & impl : spec->impls) { - { - common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(params, prompt_tgt, id_last, result); - impl->n_call_draft++; - } - - { - const int n_min = impl->n_min(params); - - if (!result.empty() && (int) result.size() < n_min) { - LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min); - result.clear(); - } - } - - if (!result.empty()) { - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); // set current implementation for stats - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // we have a draft, so break out of the loop and return it. - } + result = result && impl->process(batch); } return result; } -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { +void common_speculative_draft(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + auto & dparams = spec->dparams; + + { + int n_drafting = 0; + + for (auto & dp : dparams) { + GGML_ASSERT(!dp.drafting || dp.result->empty()); + + if (dp.drafting) { + n_drafting++; + } + } + + if (n_drafting == 0) { + return; + } + } + + for (auto & impl : spec->impls) { + { + common_time_meas tm(impl->t_draft_us, !impl->gen_perf); + impl->draft(dparams); + impl->n_call_draft++; + } + + int n_drafting = 0; + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp = dparams[seq_id]; + + auto & result = *dp.result; + + // a new draft has been sampled + if (dp.drafting && !result.empty()) { + dp.drafting = false; + + if (dp.n_max > 0) { + if (!result.empty() && (int) result.size() > dp.n_max) { + LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max); + result.resize(dp.n_max); + } + } + + if (!result.empty()) { + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(), + impl.get()->n_call_draft, result.size()); + + // remember which implementation was used + spec->impl_last[seq_id] = impl.get(); + + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + } + } + + if (dp.drafting) { + n_drafting++; + } + } + + if (n_drafting == 0) { + break; + } + } + + // these sequences failed to generate a draft + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp = dparams[seq_id]; + + if (dp.drafting) { + dp.drafting = false; + } + } +} + +void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { if (n_accepted == 0) { return; } - common_speculative_state * impl = spec->curr_impl; + common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); @@ -1188,37 +1137,11 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { impl->n_acc_tokens += n_accepted; } - impl->accept(n_accepted); + impl->accept(seq_id, n_accepted); impl->n_call_accept++; } } -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_max = 0; - for (const auto & impl : spec->impls) { - n_max = std::max(n_max, impl->n_max(params)); - } - - return n_max; -} - -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_min = 0; - for (const auto & impl : spec->impls) { - n_min = std::max(n_min, impl->n_min(params)); - } - - return n_min; -} - void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 147447631..51f0b059f 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -5,8 +5,14 @@ struct common_speculative; +// comma separated list the provided types +std::string common_speculative_type_name_str(const std::vector & types); + // comma separated list of all types -std::string common_speculative_type_name_str(); +const char * common_speculative_all_types_str(); + +// parse user provided types +std::vector common_speculative_types_from_names(const std::vector & names); // convert string to type enum common_speculative_type common_speculative_type_from_name(const std::string & name); @@ -14,27 +20,44 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt); +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq); void common_speculative_free(common_speculative * spec); +struct common_speculative_draft_params { + // this flag is used to chain the drafts through all the available implementations + // after the first successful draft from an implementation, we set it + // to false to prevent further drafts for that sequence + // at the end of the draft() call, all drafting flags will be reset to false + bool drafting = false; + + // overrides individual configurations (-1 disabled) + // can be used to constraint the max draft based on the remaining context size + int32_t n_max = -1; + + llama_pos n_past; + llama_token id_last; + + // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls + const llama_tokens * prompt; + + // the generated draft from the last _draft() call + llama_tokens * result; +}; + +common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id); + // optionally call once at the beginning of a new generation -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); -// sample up to n_draft tokens and add them to the batch using the draft model -llama_tokens common_speculative_draft( - common_speculative * spec, - const common_params_speculative & params, - const llama_tokens & prompt, - llama_token id_last); +// process the batch and update the internal state of the speculative context +bool common_speculative_process(common_speculative * spec, const llama_batch & batch); -// informs the speculative decoder that n_accepted tokens were accepted by the target model -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); +// generate drafts for the sequences specified with `common_speculative_get_draft_params` +void common_speculative_draft(common_speculative * spec); -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params); -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params); +// informs the speculative context that n_accepted tokens were accepted by the target model +void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e5dea18ae..d79372cea 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2176,7 +2176,8 @@ class MmprojModel(ModelBase): text_config = { k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"] } - self.n_embd_text = text_config.get("hidden_dim", 0) + # mistral native params.json: "dim" is the text hidden size ("hidden_dim" is the FFN intermediate size) + self.n_embd_text = text_config.get("dim", 0) assert self.n_embd_text > 0, "n_embd not found in hparams" @@ -3137,6 +3138,11 @@ class LlavaVisionModel(MmprojModel): assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" if self.use_break_tok: self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) + + # params.json may ship -1 placeholders (Mistral Medium 3.5) + # resolve the real id from the bundled tokenizer in that case + if self.img_break_tok_id < 0: + self.img_break_tok_id = self.get_mistral_token_id("[IMG_BREAK]") else: raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") logger.info(f"Image break token id: {self.img_break_tok_id}") @@ -3156,6 +3162,24 @@ class LlavaVisionModel(MmprojModel): return int(token_data["id"]) raise ValueError(f"Token '{token}' not found in tokenizer config.") + def get_mistral_token_id(self, token: str) -> int: + # mistral native format ships tekken.json or a versioned spm tokenizer + tekken_file = self.dir_model / "tekken.json" + if tekken_file.is_file(): + with open(tekken_file, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data.get("special_tokens", []): + if entry.get("token_str") == token: + return int(entry["rank"]) + tokenizer_json_file = self.dir_model / "tokenizer.json" + if tokenizer_json_file.is_file(): + with open(tokenizer_json_file, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data.get("added_tokens", []): + if entry.get("content") == token: + return int(entry["id"]) + raise ValueError(f"Token '{token}' not found in mistral tokenizer files.") + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -9736,6 +9760,73 @@ class MimoV2Model(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("MiMoV2ForCausalLM") +class MiMoV2VisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + hp = self.hparams_vision + + hp["image_size"] = hp.get("image_size", 560) + hp["num_attention_heads"] = hp.get("num_heads", 32) + hp["num_hidden_layers"] = hp.get("depth", 28) + + self.n_q_heads = int(hp["num_heads"]) + self.num_kv_heads = int(hp.get("num_key_value_heads", 8)) + self.head_dim = int(hp.get("qk_channels", 64)) + self.spatial_merge_size = int(hp["spatial_merge_size"]) + # MiMoV2 vision RMSNorm: HF uses getattr(config, "rms_norm_eps", 1e-6) and the + # field is absent from MiMo-V2.5's vision_config + self.rms_norm_eps = float(hp.get("rms_norm_eps", 1e-6)) + + # fullatt_block_indexes are also reflected in vit_window_attn_types as -1 + self.fullatt_block_indexes = list(hp.get("fullatt_block_indexes") or []) + self.vit_window_attn_types = list(hp.get("vit_window_attn_types") or []) + self.visual_token_window_size = int(hp.get("visual_token_window_size", -1)) + self.use_sink = bool(hp.get("use_sink", False)) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MIMOVL) + self.gguf_writer.add_vision_use_silu(True) + self.gguf_writer.add_vision_head_count_kv(self.num_kv_heads) + self.gguf_writer.add_vision_spatial_merge_size(self.spatial_merge_size) + self.gguf_writer.add_uint32(gguf.Keys.ClipVision.WINDOW_SIZE, self.visual_token_window_size) + self.gguf_writer.add_vision_wa_pattern_mode(self.vit_window_attn_types) + self.gguf_writer.add_vision_attention_layernorm_eps(self.rms_norm_eps) + self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"])) + self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"])) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + # Sinks must be F32: any sink-style softmax/mask add in ggml requires + # F32, and we fold sinks into a host-built F32 mask at encode time. + if new_name.endswith(".attn_sinks"): + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + @classmethod + def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: + name, _ = item + if not name.startswith("visual."): + return None + return super().filter_tensors(item) + + def modify_tensors(self, data_torch, name, bid): + # Conv3D patch embed: split along the temporal axis (kt=2) into two Conv2D + # weights that the existing qwen2vl-style two-Conv2D path consumes. + if name == "visual.patch_embed.proj.weight": + _, _, kt, _, _ = data_torch.shape + if kt != 2: + raise ValueError(f"unexpected temporal_patch_size: {kt}") + embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + yield (embd_name + ".weight", data_torch[:, :, 0, ...]) + yield (embd_name + ".weight.1", data_torch[:, :, 1, ...]) + return + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Step3p5ForCausalLM") class Step35Model(TextModel): model_arch = gguf.MODEL_ARCH.STEP35 diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index d58334205..ad4751bb9 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -188,6 +188,24 @@ class LoraTorchTensor: def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: return self.transpose(axis0, axis1) + def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]: + shape = self.shape + ndim = len(shape) + if dim < 0: + dim += ndim + if dim == ndim - 1: + A_chunks = self._lora_A.split(split_size, dim=-1) + return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks) + elif dim == ndim - 2: + B_chunks = self._lora_B.split(split_size, dim=-2) + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + else: + B_chunks = self._lora_B.split(split_size, dim=dim) + if self._lora_A.shape[dim] == 1: + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + A_chunks = self._lora_A.split(split_size, dim=dim) + return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks)) + def to(self, *args, **kwargs): return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) @@ -230,6 +248,11 @@ class LoraTorchTensor: ) else: raise NotImplementedError + elif func is torch.split: + assert len(args) and len(args) >= 2 + tensor, split_size = args[0], args[1] + dim = args[2] if len(args) > 2 else kwargs.get("dim", 0) + return tensor.split(split_size, dim=dim) else: raise NotImplementedError diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md new file mode 100644 index 000000000..3c5c35f78 --- /dev/null +++ b/examples/llama-eval/README.md @@ -0,0 +1,26 @@ +# llama-eval + +Simple evaluation tool for llama.cpp with support for multiple datasets. + +For a full description, usage examples, and sample results, see: + +- [PR 21152](https://github.com/ggml-org/llama.cpp/pull/21152) + +## Quick start + +```bash +# Single server +python3 llama-eval.py \ + --server http://localhost:8033 \ + --model my-model \ + --dataset gsm8k --n_cases 100 \ + --grader-type regex --threads 32 + +# Multiple servers (comma-separated URLs and thread counts) +python3 llama-eval.py \ + --server http://server1:8033,http://server2:8033 \ + --server-name server1,server2 \ + --threads 16,16 \ + --dataset aime2025 --n_cases 240 \ + --grader-type regex +``` diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py new file mode 100644 index 000000000..b33a3615b --- /dev/null +++ b/examples/llama-eval/llama-eval.py @@ -0,0 +1,1416 @@ +#!/usr/bin/env python3 +# type: ignore + +import argparse +import json +import os +import re +import subprocess +import sys +import threading +import time +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, asdict, field +from pathlib import Path +from queue import Queue +from typing import Dict, List, Optional, Any, Tuple +import requests +from tqdm import tqdm +import random +from math import sqrt + + +@dataclass +class ServerConfig: + url: str + threads: int + name: str = "" + +def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]: + """Wilson score confidence interval for a proportion.""" + if total == 0: + return (0.0, 1.0) + p = correct / total + z2 = z * z / total + center = (p + z2 / 2) / (1 + z2) + margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2) + return (center - margin, center + margin) + +cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" +cache_dir.mkdir(parents=True, exist_ok=True) +os.environ["HF_DATASETS_CACHE"] = str(cache_dir) +os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" + +GRADER_PATTERNS = { + "aime": r'\boxed{(\d+)}|\b(\d+)\b', + "aime2025": r'\boxed{(\d+)}|\b(\d+)\b', + "gsm8k": r'\b(\d+)\b', +} + +SAMPLE_ANSWERS = { + "aime": [ + "42", + "-123", + "999" + ], + "aime2025": [ + "42", + "-123", + "999" + ], + "gsm8k": [ + "42", + "-123", + "999" + ], + "gpqa": [ + "A", + "D", + "C" + ], +} + +TEMPLATE_REGISTRY = { + "aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}. + +{question} + +Remember to put your answer inside \\boxed{{}}. +""", + "aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}. + +{question} + +Remember to put your answer inside \\boxed{{}}. +""", + "gsm8k": """{question} +Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters. +""", + "gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A'). + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""", +} + + +class BaseDataset(ABC): + @abstractmethod + def get_question(self, index: int) -> Dict: + pass + + @abstractmethod + def get_question_text(self, question: Dict) -> str: + pass + + @abstractmethod + def get_answer(self, question: Dict) -> str: + pass + + @abstractmethod + def get_prompt(self, question: Dict) -> str: + pass + + def __len__(self) -> int: + return len(self.questions) + + +@dataclass +class TaskState: + task_id: str + prompt: str + expected: str + question_text: str = "" + response: Optional[str] = None + answer: Optional[str] = None + grader_log: Dict[str, Any] = field(default_factory=dict) + correct: bool = False + status: str = "pending" + tokens: Optional[int] = None + tps_gen: Optional[float] = None + t_gen_ms: Optional[float] = None + reasoning_content: Optional[str] = None + server_name: Optional[str] = None + + +class EvalState: + def __init__( + self, + dataset_type: str, + sampling_config: Dict[str, Any], + output_file: Path = Path("llama-eval-state.json"), + model_name: Optional[str] = None + ): + self.dataset_type = dataset_type + self.sampling_config = sampling_config + self.output_file = output_file + self.model_name = model_name + self.dataset: Optional[BaseDataset] = None + self.tasks: List[Tuple[int, str]] = [] + self.all_tasks: List[Tuple[int, str]] = [] + self.task_states: Dict[str, Any] = {} + self.total = 0 + self.correct = 0 + self.processed = 0 + self.total_time: float = 0.0 + self._lock = threading.Lock() + + def load_dataset(self, seed: int = 1234): + if self.dataset_type == "aime": + self.dataset = AimeDataset() + elif self.dataset_type == "aime2025": + self.dataset = Aime2025Dataset() + elif self.dataset_type == "gsm8k": + self.dataset = Gsm8kDataset() + elif self.dataset_type == "gpqa": + self.dataset = GpqaDataset(variant="diamond", seed=seed) + else: + raise ValueError(f"Unknown dataset type: {self.dataset_type}") + + def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234): + if self.dataset is None: + raise ValueError("Dataset not loaded. Call load_dataset() first.") + + if n_cases is None: + n_cases = len(self.dataset) + + dataset_size = len(self.dataset) + rng = random.Random(seed) + + self.tasks = [] + for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size): + chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size) + indices = list(range(dataset_size)) + rng.shuffle(indices) + chunk_indices = indices[:chunk_size] + + for i in chunk_indices: + task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}" + self.tasks.append((i, task_id)) + + self.all_tasks = list(self.tasks) + + def get_case(self, index: int) -> Tuple[str, str, str]: + if self.dataset is None: + raise ValueError("Dataset not loaded.") + question = self.dataset.get_question(index) + question_text = self.dataset.get_question_text(question) + prompt = self.dataset.get_prompt(question) + expected = self.dataset.get_answer(question) + return question_text, prompt, expected + + def add_result( + self, + task_id: str, + prompt: str, + expected: str, + response: Optional[str], + answer: Optional[str], + grader_log: Dict[str, Any], + correct: bool, + status: str, + tokens: Optional[int] = None, + tps_gen: Optional[float] = None, + t_gen_ms: Optional[float] = None, + reasoning_content: Optional[str] = None, + server_name: Optional[str] = None + ): + with self._lock: + if "cases" not in self.task_states: + self.task_states["cases"] = {} + + self.task_states["cases"][task_id] = { + "task_id": task_id, + "prompt": prompt, + "expected": expected, + "response": response, + "answer": answer, + "grader_log": grader_log, + "correct": correct, + "status": status, + "tokens": tokens, + "tps_gen": tps_gen, + "t_gen_ms": t_gen_ms, + "reasoning_content": reasoning_content, + "server_name": server_name + } + + self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False)) + + def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0): + display_answer = task_state.answer if task_state.answer else "N/A" + display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A" + display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A" + display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A" + display_server = task_state.server_name if task_state.server_name else "N/A" + success_ratio = n_correct / self.processed if self.processed > 0 else 0.0 + first_line = task_state.question_text.split('\n')[0] + truncated_question = first_line[:43] + if len(first_line) > 43: + truncated_question += "..." + else: + truncated_question = truncated_question.ljust(43) + "..." + print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}") + + def print_summary(self): + if self.total == 0: + print(f"\n{'='*60}") + print(f"Results: 0/0 correct (0.0%)") + print(f"{'='*60}") + else: + ci_lower, ci_upper = self.accuracy_ci() + print(f"\n{'='*60}") + print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]") + print(f"{'='*60}") + + def dump(self): + with self._lock: + tasks_to_save = self.all_tasks if self.all_tasks else self.tasks + all_cases = {} + for i, task_id in tasks_to_save: + question_text, prompt, expected = self.get_case(i) + if task_id in self.task_states.get("cases", {}): + all_cases[task_id] = self.task_states["cases"][task_id] + else: + all_cases[task_id] = { + "task_id": task_id, + "prompt": prompt, + "expected": expected, + "question_text": question_text, + "response": None, + "answer": None, + "grader_log": {}, + "correct": False, + "status": "pending", + "tokens": None, + "tps_gen": None, + "t_gen_ms": None, + "reasoning_content": None, + "server_name": None + } + + ci_lower, ci_upper = self.accuracy_ci() + data = { + "id": self.dataset_type, + "model_name": self.model_name, + "tasks": [tid for _, tid in tasks_to_save], + "task_states": { + "total": self.total, + "correct": self.correct, + "total_time": self.total_time, + "ci_lower": ci_lower, + "ci_upper": ci_upper, + "cases": all_cases, + }, + "sampling_config": self.sampling_config + } + with open(self.output_file, "w") as f: + json.dump(data, f, indent=2) + + self.dump_html(tasks_to_save, all_cases) + + def dump_html(self, tasks_to_save: List[Tuple[int, str]], all_cases: Dict[str, Any]): + html_file = Path(str(self.output_file) + ".html") + + cases = all_cases + completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + n_correct = sum(1 for c in completed.values() if c.get("correct", False)) + n_incorrect = len(completed) - n_correct + n_pending = len(tasks_to_save) - len(completed) + accuracy = n_correct / len(completed) * 100 if completed else 0.0 + ci_lower, ci_upper = wilson_interval(n_correct, len(completed)) if completed else (0.0, 1.0) + + sampling_parts = [] + for k, v in self.sampling_config.items(): + if v is not None: + sampling_parts.append(f"{k}={v}") + sampling_str = ", ".join(sampling_parts) if sampling_parts else "default" + + rows = [] + for i, task_id in tasks_to_save: + case = cases.get(task_id, {}) + status = case.get("status", "pending") + expected = case.get("expected", "") + answer = case.get("answer", "") if status == "ok" else "" + is_correct = case.get("correct", False) if status == "ok" else False + response = case.get("response", "") or "" + prompt = case.get("prompt", "") or "" + grader_log = case.get("grader_log", {}) + + if status == "ok": + status_class = "correct" if is_correct else "incorrect" + status_text = "✓" if is_correct else "✗" + elif status == "pending": + status_class = "pending" + status_text = "–" + else: + status_class = "error" + status_text = "!" + + tokens = case.get("tokens") + tokens_str = str(tokens) if tokens is not None else "" + tps_gen = case.get("tps_gen") + tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "" + t_gen_ms = case.get("t_gen_ms") + t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "" + reasoning_content = case.get("reasoning_content", "") or "" + server_name = case.get("server_name", "") or "" + + escaped_response = self._escape_html(response) + escaped_prompt = self._escape_html(prompt) + escaped_reasoning = self._escape_html(reasoning_content) + grader_log_str = self._escape_html(json.dumps(grader_log, indent=2)) + escaped_server = self._escape_html(server_name) + + rows.append(f""" + {task_id} + {status_text} + {self._escape_html(expected)} + {self._escape_html(answer)} + {tokens_str} + {tps_str} + {t_gen_str} + {escaped_server} + + + +
+ Prompt
{escaped_prompt}
+ Response
{escaped_response}
+ {f'Reasoning
{escaped_reasoning}
' if escaped_reasoning else ''} + Grader
{grader_log_str}
+
+ + """) + + rows_html = "\n".join(rows) + + html_content = f""" + + + +{self.dataset_type.upper()} Eval + + + +
+ {self.dataset_type.upper()} + Model: {self.model_name or 'N/A'} + Accuracy: {accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%] + Correct: {n_correct} / {len(completed)} + Pending: {n_pending} + Time: {self.total_time:.1f}s + Sampling: {sampling_str} +
+ + + + + + + + + + + + + + + {rows_html} + +
IDGoldAnswerTokensT/sGen sServer
+ + +""" + + with open(html_file, "w") as f: + f.write(html_content) + + def _escape_html(self, s: str) -> str: + return (s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'")) + + @classmethod + def load(cls, path: Path) -> "EvalState": + with open(path, "r") as f: + data = json.load(f) + + eval_state = cls( + dataset_type=data["id"], + sampling_config=data["sampling_config"], + output_file=path, + model_name=data.get("model_name") + ) + eval_state.load_dataset() + + eval_state.tasks = [] + eval_state.all_tasks = [] + for task_id in data.get("tasks", []): + parts = task_id.rsplit("_", 2) + if len(parts) >= 3: + idx = int(parts[-1]) + else: + idx = 0 + eval_state.tasks.append((idx, task_id)) + eval_state.all_tasks.append((idx, task_id)) + + eval_state.task_states = data.get("task_states", {}) + + cases = eval_state.task_states.get("cases", {}) + eval_state.total = eval_state.task_states.get("total", 0) + eval_state.correct = eval_state.task_states.get("correct", 0) + eval_state.total_time = eval_state.task_states.get("total_time", 0.0) + + if eval_state.total == 0: + eval_state.total = len(cases) + eval_state.correct = sum(1 for c in cases.values() if c.get("correct", False)) + + return eval_state + + def is_complete(self) -> bool: + if not self.all_tasks: + return False + cases = self.task_states.get("cases", {}) + completed = {tid for tid in self.task_states.get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"} + return len(completed) == len(self.all_tasks) + + def get_pending_tasks(self) -> List[Tuple[int, str]]: + cases = self.task_states.get("cases", {}) + pending = [] + for i, task_id in self.all_tasks: + status = cases.get(task_id, {}).get("status", "pending") + if status != "ok": + pending.append((i, task_id)) + return pending + + def print_all_tasks(self): + cases = self.task_states.get("cases", {}) + tasks_to_show = self.all_tasks if self.all_tasks else self.tasks + print() + print("Tasks:") + print(" Task ID Dataset Prompt (first 40 chars) Expected Answer Tokens T/s Gen s Status") + for i, task_id in tasks_to_show: + question, prompt, expected = self.get_case(i) + case = cases.get(task_id, {}) + status = case.get("status", "pending") + answer = case.get("answer", "N/A") if status == "ok" else "N/A" + tokens = case.get("tokens") + tokens_str = str(tokens) if tokens is not None else "N/A" + tps_gen = case.get("tps_gen") + tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "N/A" + t_gen_ms = case.get("t_gen_ms") + t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "N/A" + server_name = case.get("server_name", "") or "" + is_correct = case.get("correct", False) if status == "ok" else False + symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "") + first_line = question.split('\n')[0] + question_trunc = first_line[:43] + if len(first_line) > 43: + question_trunc += "..." + else: + question_trunc = question_trunc.ljust(43) + "..." + print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status} {server_name}") + print() + + def print_existing_summary(self): + cases = self.task_states.get("cases", {}) + completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + correct = sum(1 for c in completed_cases.values() if c.get("correct", False)) + total = len(completed_cases) + if total == 0: + print(f"{'='*60}") + print(f"Results: 0/0 correct (0.0%)") + print(f"{'='*60}") + else: + ci_lower, ci_upper = self.accuracy_ci() + print(f"{'='*60}") + print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]") + print(f"{'='*60}") + + def accuracy_ci(self) -> Tuple[float, float]: + """Compute Wilson score confidence interval from completed cases.""" + cases = self.task_states.get("cases", {}) + completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + correct = sum(1 for c in completed.values() if c.get("correct", False)) + total = len(completed) + return wilson_interval(correct, total) + +def normalize_number(s: str) -> Optional[int]: + match = re.match(r"\d+", s) # match digits from the start + if not match: + return None + return int(match.group(0)) + +class AimeDataset(BaseDataset): + def __init__(self, split: str = "train"): + self.split = split + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading AIME dataset (split: {self.split})...") + from datasets import load_dataset + + cache_path = cache_dir / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + else: + ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split) + + self.questions = [] + for row in ds: + question = dict(row) + question["dataset_type"] = "aime" + self.questions.append(question) + + print(f"AIME dataset loaded: {len(self.questions)} questions") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["problem"] if "problem" in question else question["question"] + + def get_answer(self, question: Dict) -> str: + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY[question["dataset_type"]].format( + question=self.get_question_text(question), + ) + +class Aime2025Dataset(BaseDataset): + def __init__(self): + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading AIME2025 dataset...") + from datasets import load_dataset + + config_name = "AIME2025-I" + cache_path = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path)) + else: + ds = load_dataset("opencompass/AIME2025", config_name, split="test") + + self.questions = [] + for row in ds: + question = dict(row) + question["dataset_type"] = "aime2025" + self.questions.append(question) + + print(f"AIME2025 dataset loaded: {len(self.questions)} questions") + + print(f"Loading AIME2025 dataset (part 2)...") + config_name_2 = "AIME2025-II" + cache_path_2 = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0" + if cache_path_2.exists(): + print(f"Using cached dataset from {cache_path_2}") + ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test", cache_dir=str(cache_path_2)) + else: + ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test") + + for row in ds_2: + question = dict(row) + question["dataset_type"] = "aime2025" + self.questions.append(question) + + print(f"AIME2025 dataset loaded: {len(self.questions)} questions (total)") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["question"] + + def get_answer(self, question: Dict) -> str: + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY["aime2025"].format( + question=self.get_question_text(question), + ) + +class Gsm8kDataset(BaseDataset): + def __init__(self, split: str = "test"): + self.split = split + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading GSM8K dataset (split: {self.split})...") + from datasets import load_dataset + + cache_path = cache_dir / "openai___gsm8k" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = load_dataset("openai/gsm8k", "main", split=self.split, cache_dir=str(cache_path)) + else: + ds = load_dataset("openai/gsm8k", "main", split=self.split) + + self.questions = [] + for row in ds: + question = dict(row) + question["dataset_type"] = "gsm8k" + + # Extract numeric answer from the answer field (already has #### prefix) + gold = question["answer"] + # Split by #### and take the last part + parts = gold.split("####") + if len(parts) > 1: + gold = parts[-1].strip() + # Extract the first number from the remaining text + normalized = normalize_number(gold) + question["gold"] = str(normalized) if normalized is not None else gold + + self.questions.append(question) + + print(f"GSM8K dataset loaded: {len(self.questions)} questions") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["problem"] if "problem" in question else question["question"] + + def get_answer(self, question: Dict) -> str: + # GSM8K has pre-extracted gold field, AIME uses answer field + if "gold" in question: + return question["gold"] + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY[question["dataset_type"]].format( + question=self.get_question_text(question), + ) + +class GpqaDataset(BaseDataset): + def __init__(self, variant: str = "diamond", seed: int = 1234): + self.variant = variant + self.seed = seed + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading GPQA dataset (variant: {self.variant})...") + import pandas as pd + + url = f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{self.variant}.csv" + df = pd.read_csv(url) + + rng = random.Random(self.seed) + + self.questions = [] + for _, row in df.iterrows(): + question = row.to_dict() + question["dataset_type"] = "gpqa" + + # Shuffle the answer options + correct_answer = question["Correct Answer"] + incorrect_answers = [ + question["Incorrect Answer 1"], + question["Incorrect Answer 2"], + question["Incorrect Answer 3"] + ] + + # Create list of (answer, is_correct) tuples + options = [(ans, ans == correct_answer) for ans in incorrect_answers] + options.append((correct_answer, True)) + + # Shuffle the options + rng.shuffle(options) + + # Extract shuffled answers and determine correct letter + shuffled_answers = [ans for ans, _ in options] + correct_letter = chr(ord('A') + options.index((correct_answer, True))) + + # Store shuffled answers and correct letter + question["shuffled_answers"] = shuffled_answers + question["correct_letter"] = correct_letter + + self.questions.append(question) + + print(f"GPQA dataset loaded: {len(self.questions)} questions") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["Question"] + + def get_answer(self, question: Dict) -> str: + # GPQA returns the correct letter (A, B, C, or D) + return question["correct_letter"] + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY["gpqa"].format( + Question=self.get_question_text(question), + A=question["shuffled_answers"][0], + B=question["shuffled_answers"][1], + C=question["shuffled_answers"][2], + D=question["shuffled_answers"][3] + ) + +class Grader: + def __init__( + self, + grader_type: str = "llm", + grader_script: Optional[str] = None, + grader_model_name: Optional[str] = None, + grader_server_url: str = "", + dataset_type: str = "aime" + ): + self.grader_type = grader_type + self.grader_script = grader_script + self.grader_model_name = grader_model_name + self.grader_server_url = grader_server_url + self.dataset_type = dataset_type + self.pattern = self._get_pattern() + + def _get_pattern(self) -> Optional[str]: + if self.grader_type == "regex": + return GRADER_PATTERNS.get(self.dataset_type) # Use dataset_type as key + return None + + def _extract_answer_regex(self, pred: str) -> Optional[str]: + """Extract answer using regex pattern""" + if not self.pattern: + return None + + # For AIME datasets, prioritize boxed answers + if self.dataset_type in ["aime", "aime2025"]: + boxed_pattern = r'\\boxed{([^}]+)}' + boxed_matches = re.findall(boxed_pattern, pred, re.IGNORECASE) + if boxed_matches: + # Return the last boxed answer found (most likely the final answer) + return boxed_matches[-1].strip() + + # For other datasets, search for numbers from the end of the text + # This prioritizes numbers that appear later in the response + matches = re.findall(self.pattern, pred, re.IGNORECASE) + if not matches: + return None + + # Process matches from end to start + for match in reversed(matches): + if isinstance(match, tuple): + match = match[0] if match[0] else match[1] + answer = match.strip() + if answer: + return answer + return None + + def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]: + """Grade using regex pattern matching""" + answer = self._extract_answer_regex(pred) + if answer is None: + return False, None + is_correct = answer.strip() == gold.strip() + return is_correct, answer + + def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]: + """Grade using external CLI script""" + if not self.grader_script: + raise ValueError("CLI grader requires --grader-script") + + script_path = Path(self.grader_script) + if not script_path.exists(): + raise FileNotFoundError(f"Grader script not found: {self.grader_script}") + + try: + result = subprocess.run( + [str(script_path), "--answer", pred, "--expected", gold], + capture_output=True, + text=True, + timeout=30 + ) + is_correct = result.returncode == 0 + answer = pred if is_correct else None + return is_correct, answer + except subprocess.TimeoutExpired: + return False, None + except Exception as e: + return False, None + + def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]: + """Grade using LLM-based extraction with few-shot examples""" + sample_answers = SAMPLE_ANSWERS.get(self.dataset_type, []) + sample_examples = "\n".join([ + f"Example {i+1}: {ans}" for i, ans in enumerate(sample_answers) + ]) + + system_prompt = f"""You are an answer extraction system. Your task is to extract the answer from the model's response. + +Here are some examples of extracted answers to demonstrate what you are supposed to output: + +{sample_examples} + +When extracting the answer, provide only the extracted answer itself, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'.""" + + user_prompt = f"""Extract the answer from the following response: + +"{pred}" + +Please provide only the extracted answer, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'.""" + + url = f"{self.grader_server_url}/v1/chat/completions" + headers = {"Content-Type": "application/json"} + data = { + "model": self.grader_model_name, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + "temperature": 0, + } + #print(json.dumps(data, indent=2)) + + try: + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + answer = response.json()["choices"][0]["message"]["content"].strip() + is_correct = answer.strip().lower() == gold.strip().lower() + return is_correct, answer + except Exception as e: + return False, None + + def _truncate_response(self, response: str, max_lines: int = 6) -> str: + """Keep only last N lines of response""" + lines = response.split('\n') + return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response + + def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]: + """Grade the response""" + if self.grader_type == "regex": + return self._grade_regex(gold, pred) + elif self.grader_type == "cli": + return self._grade_cli(gold, pred) + elif self.grader_type == "llm": + return self._grade_llm(gold, pred, problem) + else: + raise ValueError(f"Unknown grader type: {self.grader_type}") + +class Processor: + def __init__( + self, + server_configs: List[ServerConfig], + grader: Grader, + model_name: Optional[str] = None, + n_predict: int = -1 + ): + self.server_configs = server_configs + self.grader = grader + self.model_name = model_name + self.n_predict = n_predict + + @staticmethod + def _check_server(server_config: ServerConfig) -> List[str]: + url = f"{server_config.url}/v1/models" + try: + response = requests.get(url) + response.raise_for_status() + models = [m["id"] for m in response.json().get("data", [])] + return models + except Exception as e: + print(f"Error: Cannot reach server {server_config.name} ({server_config.url}): {e}", file=sys.stderr) + sys.exit(1) + + def _make_request( + self, server_config: ServerConfig, eval_state: EvalState, prompt: str + ) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]: + url = f"{server_config.url}/v1/chat/completions" + headers = {"Content-Type": "application/json"} + data = { + "model": self.model_name if self.model_name else "llama", + "messages": [{"role": "user", "content": prompt}], + "n_predict": self.n_predict + } + if eval_state.sampling_config.get("temperature") is not None: + data["temperature"] = eval_state.sampling_config["temperature"] + if eval_state.sampling_config.get("top_k") is not None: + data["top_k"] = eval_state.sampling_config["top_k"] + if eval_state.sampling_config.get("top_p") is not None: + data["top_p"] = eval_state.sampling_config["top_p"] + if eval_state.sampling_config.get("min_p") is not None: + data["min_p"] = eval_state.sampling_config["min_p"] + + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + result = response.json() + tokens = result.get("usage", {}).get("completion_tokens", 0) + timings = result.get("timings", {}) + tps_gen = timings.get("predicted_per_second") if timings else None + t_gen_ms = timings.get("predicted_ms") if timings else None + finish_reason = result.get("choices", [{}])[0].get("finish_reason", "stop") + return result, tokens, tps_gen, t_gen_ms, finish_reason + + def _process_single_case( + self, server_config: ServerConfig, eval_state: EvalState, i: int, task_id: str + ) -> TaskState: + question_text, prompt, expected = eval_state.get_case(i) + + task_state = TaskState( + task_id=task_id, + prompt=prompt, + expected=expected, + question_text=question_text, + server_name=server_config.name + ) + + try: + response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(server_config, eval_state, prompt) + result = response["choices"][0]["message"]["content"] + reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content") + task_state.response = result + task_state.tokens = tokens + task_state.tps_gen = tps_gen + task_state.t_gen_ms = t_gen_ms + task_state.reasoning_content = reasoning_content + + if finish_reason != "stop": + task_state.status = f"error: finish_reason={finish_reason}" + eval_state.add_result( + task_id, prompt, expected, result, None, + {"finish_reason": finish_reason}, False, task_state.status, + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + ) + eval_state.dump() + return task_state + + result_truncated = self.grader._truncate_response(result, max_lines=10) + is_correct, answer = self.grader.grade(expected, result_truncated, prompt) + + grader_log = { + "pred": result_truncated, + "grader_type": self.grader.grader_type + } + if self.grader.grader_type == "regex" and self.grader.pattern: + grader_log["pattern"] = self.grader.pattern + + task_state.correct = is_correct + task_state.answer = answer + task_state.grader_log = grader_log + task_state.status = "ok" + + eval_state.add_result( + task_id, prompt, expected, result, answer, + grader_log, is_correct, "ok", + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + ) + + eval_state.dump() + + except Exception as e: + task_state.status = f"error: {str(e)}" + + return task_state + + @staticmethod + def _worker( + server_config: ServerConfig, + processor: "Processor", + eval_state: EvalState, + task_queue: Queue, + results_queue: Queue, + ): + """Worker that pulls tasks from a shared queue and sends them to its server.""" + while True: + task = task_queue.get() + if task is None: # sentinel + task_queue.task_done() + break + try: + i, task_id = task + result = processor._process_single_case(server_config, eval_state, i, task_id) + results_queue.put(result) + finally: + task_queue.task_done() + + def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False): + total_tasks = len(eval_state.tasks) + eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks + eval_state.processed = 0 + start_time = time.time() + + # Check servers and list models + server_models = [self._check_server(sc) for sc in self.server_configs] + + # Print server info + print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...") + print(f"Servers ({len(self.server_configs)}):") + for i, sc in enumerate(self.server_configs): + models_str = ", ".join(server_models[i]) if server_models[i] else "(none)" + print(f" {i+1}. {sc.name} — {sc.url} ({sc.threads} threads) [{models_str}]") + print(f"Model: {self.model_name}") + print(f"Grader: {self.grader.grader_type}") + print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}") + print() + + # Shared task queue: all workers compete for tasks + task_queue: Queue = Queue() + for i, task_id in eval_state.tasks: + task_queue.put((i, task_id)) + + # Results queue: workers push completed TaskStates here + results_queue: Queue = Queue() + + # Total worker threads across all servers + total_threads = sum(sc.threads for sc in self.server_configs) + + # Add one sentinel per worker so every worker exits cleanly + for _ in range(total_threads): + task_queue.put(None) + + # Launch workers: one ThreadPoolExecutor per server + executors: List[ThreadPoolExecutor] = [] + worker_futures: List[Any] = [] + for server_config in self.server_configs: + executor = ThreadPoolExecutor(max_workers=server_config.threads) + executors.append(executor) + for _ in range(server_config.threads): + future = executor.submit( + self._worker, server_config, self, eval_state, + task_queue, results_queue + ) + worker_futures.append(future) + + # Drain results as they complete + n_correct = 0 + session_time = 0.0 + completed_count = 0 + + while completed_count < total_tasks: + task_state = results_queue.get() + eval_state.processed += 1 + completed_count += 1 + if task_state.correct: + n_correct += 1 + elapsed = time.time() - start_time + eval_state.total_time += elapsed + session_time += elapsed + start_time = time.time() + eval_state.print_progress(task_state, total_tasks, n_correct) + + if verbose: + print(f"\nCase {eval_state.processed}: {task_state.correct}") + print(f" Expected: {task_state.expected}") + if task_state.response: + print(f" Response: {task_state.response}") + if task_state.answer: + print(f" Answer: {task_state.answer}") + print(f" Status: {task_state.status}") + + # Wait for all workers to finish and shut down executors + for future in worker_futures: + future.result() + for executor in executors: + executor.shutdown(wait=True) + + print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s") + eval_state.print_summary() + eval_state.dump() + +def main(): + parser = argparse.ArgumentParser( + description="Simplified evaluation tool for llama.cpp" + ) + parser.add_argument( + "--server", + type=str, + default="http://localhost:8033", + help="Comma-separated llama-server URLs (default: http://localhost:8033)" + ) + parser.add_argument( + "--server-name", + type=str, + default="", + help="Comma-separated display names for servers (default: use URLs)" + ) + parser.add_argument( + "--dataset", + type=str, + default="aime", + choices=["aime", "aime2025", "gsm8k", "gpqa"], + help="Dataset type (default: aime)" + ) + parser.add_argument( + "--n_cases", + type=int, + default=None, + help="Number of cases to evaluate (default: all)" + ) + parser.add_argument( + "--seed", + type=int, + default=1234, + help="Random seed for shuffling (default: 1234)" + ) + parser.add_argument( + "--n_predict", + type=int, + default=-1, + help="Max tokens to predict per prompt (default: -1, infinite)" + ) + parser.add_argument( + "--temperature", + type=float, + default=None, + help="Sampling temperature (default: not passed)" + ) + parser.add_argument( + "--top-k", + type=int, + default=None, + help="Top K sampling (default: not passed)" + ) + parser.add_argument( + "--top-p", + type=float, + default=None, + help="Top P sampling (default: not passed)" + ) + parser.add_argument( + "--min-p", + type=float, + default=None, + help="Min P sampling (default: not passed)" + ) + parser.add_argument( + "--threads", + type=str, + default="32", + help="Comma-separated thread counts per server (default: 32)" + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model name to append as query parameter (e.g., gpt-oss-20b-hf)" + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Show detailed output for each case" + ) + parser.add_argument( + "--output", + type=Path, + default=Path("llama-eval-state.json"), + help="Output file for eval state (default: llama-eval-state.json)" + ) + parser.add_argument( + "--grader-type", + type=str, + default="llm", + choices=["regex", "cli", "llm"], + help="Grader type: regex, cli, or llm (default: llm)" + ) + parser.add_argument( + "--grader-script", + type=str, + default=None, + help="CLI grader script path (required for --grader-type cli)" + ) + parser.add_argument( + "--grader-server", + type=str, + default="", + help="Server URL for LLM grader (default: same as main server)" + ) + parser.add_argument( + "--grader-model", + type=str, + default="", + help="Model name for LLM grader (default: same as main model)" + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from existing eval state" + ) + + args = parser.parse_args() + + # Parse server URLs and thread counts + server_urls = [u.strip() for u in args.server.split(",") if u.strip()] + thread_counts = [int(t.strip()) for t in args.threads.split(",") if t.strip()] + + if len(server_urls) != len(thread_counts): + print(f"Error: --server ({len(server_urls)} URLs) and --threads ({len(thread_counts)} values) must have the same count") + sys.exit(1) + + # Parse server names (optional, defaults to URLs) + if args.server_name: + server_names = [n.strip() for n in args.server_name.split(",") if n.strip()] + if len(server_names) != len(server_urls): + print(f"Error: --server-name ({len(server_names)} names) and --server ({len(server_urls)} URLs) must have the same count") + sys.exit(1) + else: + server_names = server_urls # fallback to URLs + + server_configs = [ + ServerConfig(url=url, threads=threads, name=name) + for url, threads, name in zip(server_urls, thread_counts, server_names) + ] + + if args.dataset == "gpqa" and args.grader_type != "llm": + print("Error: GPQA dataset requires --grader-type llm") + parser.print_help() + sys.exit(1) + + if args.output.exists(): + print(f"Loading existing eval state from {args.output}") + eval_state = EvalState.load(args.output) + + # Verify model matches + if eval_state.model_name is not None and args.model != eval_state.model_name: + print(f"Error: Model mismatch. State has '{eval_state.model_name}', but --model is '{args.model}'") + sys.exit(1) + + eval_state.print_all_tasks() + eval_state.print_existing_summary() + + if eval_state.is_complete(): + return + + print() + + if not args.resume: + print(f"Evaluation incomplete. Run with --resume to continue.") + return + + pending_tasks = eval_state.get_pending_tasks() + print(f"Resuming from {len(pending_tasks)} pending tasks") + + existing_cases = eval_state.task_states.get("cases", {}) + + eval_state.tasks = pending_tasks + eval_state.task_states["cases"] = existing_cases + + grader_server_url = args.grader_server if args.grader_server else server_configs[0].url + grader_model_name = args.grader_model if args.grader_model else args.model + if args.grader_type == "llm" and not grader_model_name: + print("Error: --grader-type llm requires --grader-model or --model") + sys.exit(1) + grader = Grader( + grader_type=args.grader_type, + grader_script=args.grader_script, + grader_model_name=grader_model_name, + grader_server_url=grader_server_url, + dataset_type=eval_state.dataset_type + ) + resume = True + else: + if args.resume: + print("Error: No existing eval state found to resume") + sys.exit(1) + + grader_server_url = args.grader_server if args.grader_server else server_configs[0].url + grader_model_name = args.grader_model if args.grader_model else args.model + if args.grader_type == "llm" and not grader_model_name: + print("Error: --grader-type llm requires --grader-model or --model") + sys.exit(1) + + grader = Grader( + grader_type=args.grader_type, + grader_script=args.grader_script, + grader_model_name=grader_model_name, + grader_server_url=grader_server_url, + dataset_type=args.dataset + ) + + if args.grader_type == "llm" and not args.grader_server: + print("Warning: Using same server for LLM grader (no --grader-server specified)") + + sampling_config = {} + if args.temperature is not None: + sampling_config["temperature"] = args.temperature + if args.top_k is not None: + sampling_config["top_k"] = args.top_k + if args.top_p is not None: + sampling_config["top_p"] = args.top_p + if args.min_p is not None: + sampling_config["min_p"] = args.min_p + + eval_state = EvalState( + dataset_type=args.dataset, + sampling_config=sampling_config, + output_file=args.output, + model_name=args.model + ) + eval_state.load_dataset(seed=args.seed) + eval_state.setup_tasks(n_cases=args.n_cases, seed=args.seed) + eval_state.dump() + resume = False + + eval_state.print_all_tasks() + + processor = Processor( + server_configs=server_configs, + grader=grader, + model_name=args.model, + n_predict=args.n_predict + ) + + processor.evaluate(eval_state, verbose=args.verbose, resume=resume) + print(f"\nEval state dumped to {args.output}") + +if __name__ == "__main__": + main() diff --git a/examples/llama-eval/llama-server-simulator.py b/examples/llama-eval/llama-server-simulator.py new file mode 100644 index 000000000..2f9cdc545 --- /dev/null +++ b/examples/llama-eval/llama-server-simulator.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 + +import argparse +import json +import random +import re +import time +import sys +import os +import threading +from http.server import HTTPServer, BaseHTTPRequestHandler +from typing import Dict, List, Optional +from dataclasses import dataclass +from pathlib import Path + +import datasets + +# Set cache directory for HuggingFace datasets +cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" +cache_dir.mkdir(parents=True, exist_ok=True) +os.environ["HF_DATASETS_CACHE"] = str(cache_dir) + +def dice(s1: str, s2: str) -> float: + """Calculate Dice coefficient between two strings based on bigram overlap.""" + if not s1 and not s2: + return 1.0 + + def _bigrams(s: str): + return [s[i : i + 2] for i in range(len(s) - 1)] + + bigrams1 = _bigrams(s1) + bigrams2 = _bigrams(s2) + + if not bigrams1 and not bigrams2: + return 1.0 + + from collections import Counter + + freq1 = Counter(bigrams1) + freq2 = Counter(bigrams2) + + intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1) + dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2)) + return dice_coeff + +def debug_log(message: str): + """Log debug messages to both stdout and a file""" + print(message, file=sys.stderr) + with open("/tmp/simulator-debug.log", "a") as f: + f.write(message + "\n") + +simulator: Optional["Simulator"] = None + +@dataclass +class EvalState: + id: str + tasks: List[str] + task_states: Dict[str, Dict] + sampling_config: Dict + +def normalize_number(s: str) -> Optional[int]: + match = re.match(r"\d+", s) # match digits from the start + if not match: + return None + return int(match.group(0)) + +class AimeDataset: + def __init__(self, split: str = "train"): + self.split = split + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading AIME dataset (split: {self.split})...") + + cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + else: + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) + + self.questions = list(ds) + print(f"AIME dataset loaded: {len(self.questions)} questions") + + def find_question(self, request_text: str) -> Optional[Dict]: + best_match = None + best_distance = -1 + best_index = -1 + + for i, question in enumerate(self.questions): + question_text = question["problem"] + request_lower = request_text.lower() + question_lower = question_text.lower() + + # Exact match + if question_lower == request_lower: + debug_log(f"DEBUG: Found exact match at index {i}") + return question + + # Remove LaTeX formatting for more flexible matching + question_no_latex = re.sub(r'\$[^$]+\$', '', question_text) + if question_no_latex.lower() == request_lower: + debug_log(f"DEBUG: Found match (no LaTeX) at index {i}") + return question + + # Calculate Dice coefficient for partial matches + # Only consider if request is at least 50% of question length + if len(request_lower) >= len(question_lower) * 0.5: + distance = dice(question_lower, request_lower) + + if distance > best_distance: + best_distance = distance + best_match = question + best_index = i + + if best_match and best_distance > 0.3: # Threshold for partial match + debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}") + return best_match + + debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...") + return None + + def get_answer(self, question: Dict) -> str: + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + +class Simulator: + def __init__( + self, + port: int = 8033, + host: str = "localhost", + success_rate: float = 0.8, + dataset_split: str = "train" + ): + self.port = port + self.host = host + self.success_rate = success_rate + self.dataset = AimeDataset(dataset_split) + self.eval_state = EvalState( + id="aime-2025", + tasks=["aime"], + task_states={}, + sampling_config={"temperature": 0, "max_tokens": 2048} + ) + + def _generate_response( + self, + question: Dict, + should_be_correct: bool + ) -> Dict: + expected_answer = self.dataset.get_answer(question) + + if should_be_correct: + response_text = expected_answer + else: + response_text = self._generate_wrong_answer(question) + + return { + "id": f"chatcmpl-{int(time.time())}", + "object": "chat.completion", + "created": int(time.time()), + "model": "llama", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response_text + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150 + } + } + + def _generate_wrong_answer(self, question: Dict) -> str: + expected_answer = self.dataset.get_answer(question) + + if expected_answer.isdigit(): + wrong_answer = str(int(expected_answer) + 1) + else: + wrong_answer = expected_answer + " (wrong)" + + return wrong_answer + + def _process_request(self, request_data: Dict) -> Dict: + messages = request_data.get("messages", []) + if not messages: + return {"error": "No messages in request"} + + request_text = messages[0].get("content", "") + debug_log(f"DEBUG: Received request with content: {request_text[:150]}...") + + question = self.dataset.find_question(request_text) + if not question: + debug_log(f"DEBUG: find_question returned None") + return {"error": "No matching question found"} + + should_be_correct = random.random() < self.success_rate + + response = self._generate_response(question, should_be_correct) + + task_id = "aime" + self.eval_state.task_states[task_id] = { + "correct": should_be_correct, + "expected": self.dataset.get_answer(question), + "predicted": response["choices"][0]["message"]["content"] + } + + return response + +class RequestHandler(BaseHTTPRequestHandler): + def do_POST(self): + if self.path != "/v1/chat/completions": + self._send_json({"error": "Not found"}, 404) + return + + try: + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + request_data = json.loads(body) if body else None + + if not request_data: + self._send_json({"error": "Invalid JSON"}, 400) + return + + if simulator is None: + self._send_json({"error": "Simulator not initialized"}, 500) + return + + response = simulator._process_request(request_data) + self._send_json(response, 200) + + except json.JSONDecodeError: + self._send_json({"error": "Invalid JSON"}, 400) + except Exception as e: + print(f"Error processing request: {e}") + self._send_json({"error": str(e)}, 500) + + def _send_json(self, data: dict, status: int = 200): + body = json.dumps(data).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, format, *args): + # Suppress default request logging + pass + + +def main(): + parser = argparse.ArgumentParser( + description="llama-server simulator for testing eval scripts" + ) + parser.add_argument( + "--port", + type=int, + default=8033, + help="Server port (default: 8033)" + ) + parser.add_argument( + "--host", + type=str, + default="localhost", + help="Server host (default: localhost)" + ) + parser.add_argument( + "--success-rate", + type=float, + default=0.8, + help="Success rate 0-1 (default: 0.8)" + ) + parser.add_argument( + "--dataset-split", + type=str, + default="train", + help="AIME dataset split to use (default: train)" + ) + + args = parser.parse_args() + + global simulator + simulator = Simulator( + port=args.port, + host=args.host, + success_rate=args.success_rate, + dataset_split=args.dataset_split + ) + + server = HTTPServer((args.host, args.port), RequestHandler) + server_thread = threading.Thread(target=server.serve_forever, daemon=True) + server_thread.start() + + print("\n=== llama-server-simulator ===") + print(f"Server running on http://{args.host}:{args.port}") + print(f"Success rate: {args.success_rate}") + print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions") + print("\nPress Ctrl+C to stop\n") + + try: + server_thread.join() + except KeyboardInterrupt: + print("\nShutting down...") + server.shutdown() + +if __name__ == "__main__": + main() diff --git a/examples/llama-eval/test-simulator.sh b/examples/llama-eval/test-simulator.sh new file mode 100644 index 000000000..f3ddf3e95 --- /dev/null +++ b/examples/llama-eval/test-simulator.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +set -e + +# Get the directory where this script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +echo "=== llama-server-simulator Test Script ===" +echo "" + +PORT=8033 +SUCCESS_RATE=0.8 +TEST_PORT=8034 + +echo "Starting simulator on port $PORT with success rate $SUCCESS_RATE..." +source "$SCRIPT_DIR/venv/bin/activate" +python3 "$SCRIPT_DIR/llama-server-simulator.py" --port $PORT --success-rate $SUCCESS_RATE > /tmp/simulator-test.log 2>&1 & +SIMULATOR_PID=$! + +echo "Waiting for simulator to start..." +sleep 5 + +# Helper function to make a request and extract the answer +make_request() { + local question="$1" + curl -s -X POST http://localhost:$PORT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"llama\", + \"messages\": [ + {\"role\": \"user\", \"content\": \"$question\"} + ], + \"temperature\": 0, + \"max_tokens\": 2048 + }" | python3 -c "import sys, json; data = json.load(sys.stdin); print(data.get('choices', [{}])[0].get('message', {}).get('content', data.get('error', 'No response')))" +} + +# Test question (repeated in multiple tests) +TEST_QUESTION="Quadratic polynomials P(x) and Q(x) have leading coefficients 2 and -2, respectively. The graphs of both polynomials pass through the two points (16,54) and (20,53). Find P(0) + Q(0)." + +echo "" +echo "=== Test 1: Correct Answer ===" +echo "Sending request with known question..." +answer=$(make_request "$TEST_QUESTION") +echo "Answer: $answer" +echo "Expected: 116" +echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")" + +echo "" +echo "=== Test 2: Wrong Answer ===" +echo "Sending request with known question (success rate 0.0)..." +answer=$(make_request "$TEST_QUESTION") +echo "Answer: $answer" +echo "Expected: 116" +echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")" + +echo "" +echo "=== Test 3: No Matching Question ===" +echo "Sending request with non-matching text..." +response=$(make_request "What is the capital of France?") +echo "Response: $response" +echo "Expected: No matching question found" +echo "Correct: $([ "$response" == "No matching question found" ] && echo "Yes" || echo "No")" + +echo "" +echo "=== Test 4: Success Rate Verification ===" +echo "Sending 10 requests to test success rate..." +correct_count=0 +for i in {1..10}; do + answer=$(make_request "$TEST_QUESTION") + if [ "$answer" == "116" ]; then + correct_count=$((correct_count + 1)) + fi + echo " Request $i: Answer = $answer" +done +echo "Correct answers: $correct_count/10" +echo "Expected: ~8/10 (80% success rate)" +echo "Success rate: $(echo "scale=1; $correct_count * 10" | bc)%" + +echo "" +echo "=== Test Complete ===" +echo "Stopping simulator..." +kill $SIMULATOR_PID 2>/dev/null +wait $SIMULATOR_PID 2>/dev/null || true + +echo "Simulator stopped." diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 0f3f017b5..c4f08091e 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -4,6 +4,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) # define STRIDED_ITERATOR_AVAILABLE +# include # endif using namespace cub; #endif // GGML_CUDA_USE_CUB diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1888faf3e..25bcbda57 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3945,10 +3945,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph // closure check: the trailing add must read the same x as the leading mul const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; - const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + // Kernel iterates over total = T * C, so x and add must be 2D and + // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled. + const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) && + (add->ne[2] == 1 && add->ne[3] == 1) && + (a->ne[2] == 1 && a->ne[3] == 1); const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; - if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + // x must be in the supported whitelist and every operand / intermediate + // result must share x's type, since launch_snake casts a / inv_b as + // float and templates the kernel on a single T. Mixed precision chains + // fall back to the naive path. + const ggml_tensor * sin1 = cgraph->nodes[i + 1]; + const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) && + (a->type == x->type) && (inv_b->type == x->type) && + (mul0->type == x->type) && (sin1->type == x->type) && + (sqr->type == x->type) && (mul1->type == x->type) && + (add->type == x->type); + + if (types_ok && shape_ok && dim_ok && x_in_add == x) { ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); return 4; } @@ -5312,12 +5327,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -5326,6 +5337,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index c1b223727..28c79ab46 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -1,5 +1,6 @@ #include "im2col.cuh" +#define MAX_GRIDDIM_Y 65535 #define MAX_GRIDDIM_Z 65535 template @@ -18,25 +19,25 @@ static __global__ void im2col_kernel( const int64_t ikh = rem / KW; const int64_t ikw = rem - ikh * KW; - for (int64_t iow = blockIdx.y; iow < OW; iow += gridDim.y) { - for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OH; - const int64_t ioh = iz - in * OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t offset_dst = - ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } } } - } GGML_UNUSED(IC); GGML_UNUSED(KH); @@ -52,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst, const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; const int64_t N_OH = N * OH; const int64_t KH_KW = KW*KH; - dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Z), MIN(N_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z)); im2col_kernel<<>>(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, s0, s1, p0, p1, d0, d1); @@ -137,23 +138,24 @@ static __global__ void im2col_3d_kernel( const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; const int64_t ikw = i % KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OD_OH; - const int64_t iod = (iz - in*OD_OH) / OH; - const int64_t ioh = iz % OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t iid = iod * s2 + ikd * d2 - p2; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); - dst[offset_dst] = src[offset_src]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } } } } @@ -179,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst, const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z)); im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d211bf79f..f0147af84 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m return res; } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + const int ne12 = op->src[1]->ne[2]; + const int r2 = ne12 / op->src[0]->ne[2]; + const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3]; + + GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX); + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); - snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); + GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX); + const int16_t ne12 = (int16_t) op->src[1]->ne[2]; + const int16_t ne13 = (int16_t) op->src[1]->ne[3]; + const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]); + const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]); + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d", + base, bc_inp, bc_out, ne12, ne13, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2); + ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; + GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX); + const int16_t r2 = (int16_t) (ne12 / ne02); + const int16_t r3 = (int16_t) (ne13 / ne03); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); - snprintf(name, 256, "%s_nsg=%d", base, nsg); + snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4718ca083..1f212a92f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c87..a114391c2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaede..3882b9558 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write( constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; +constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]]; +constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]]; +constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]]; template void mul_vec_q_n_f32_impl( @@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); @@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[NR0]; FOR_UNROLL (int row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; @@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl( device const block_q1_0 * ax[nr0]; for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); } @@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); @@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const T0 * x = (device const T0 *) (src0 + offset0); @@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl( // pointers to src0 rows device const T0 * ax [NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const T0 *) ((device char *) src0 + offset0); } @@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl( device const T0 * ax [NR0]; device const T04 * ax4[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax [row] = (device const T0 *) ((device char *) src0 + offset0); ax4[row] = (device const T04 *) ((device char *) src0 + offset0); @@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl( return; } - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -7462,10 +7465,10 @@ void kernel_mul_mv_q2_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); @@ -7567,10 +7570,10 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); @@ -7741,10 +7744,10 @@ void kernel_mul_mv_q4_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); @@ -7853,10 +7856,10 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); @@ -7989,10 +7992,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); @@ -8094,10 +8097,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); @@ -8202,10 +8205,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); @@ -8321,10 +8324,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); @@ -8433,10 +8436,10 @@ void kernel_mul_mv_iq3_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); @@ -8545,10 +8548,10 @@ void kernel_mul_mv_iq2_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); @@ -8658,10 +8661,10 @@ void kernel_mul_mv_iq1_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); @@ -8757,10 +8760,10 @@ void kernel_mul_mv_iq1_m_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); @@ -8866,10 +8869,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); @@ -8975,10 +8978,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); @@ -9086,10 +9089,10 @@ void kernel_mul_mv_mxfp4_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); @@ -9304,6 +9307,10 @@ kernel void kernel_diag_f32( constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]]; +constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]]; +constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]]; +constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]]; // each block_q contains 16*nl weights #ifdef GGML_METAL_HAS_TENSOR @@ -9330,11 +9337,11 @@ kernel void kernel_mul_mm( // Batch dimension handling const int im = tgpig.z; - const int i12 = im % args.ne12; - const int i13 = im / args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; // Batch offsets for srcA and srcB - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; // Tile dimensions constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; @@ -9473,10 +9480,10 @@ kernel void kernel_mul_mm( short il = il0; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; const short offset1 = il0/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ea3f74bca..54ce2c3aa 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -686,6 +686,15 @@ struct vk_device_struct { bool mul_mat_id_m[GGML_TYPE_COUNT]; bool mul_mat_id_s[GGML_TYPE_COUNT]; + // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses + // a different shared-memory layout than the float matmul shaders. + bool mul_mat_l_int[GGML_TYPE_COUNT]; + bool mul_mat_m_int[GGML_TYPE_COUNT]; + bool mul_mat_s_int[GGML_TYPE_COUNT]; + bool mul_mat_id_l_int[GGML_TYPE_COUNT]; + bool mul_mat_id_m_int[GGML_TYPE_COUNT]; + bool mul_mat_id_s_int[GGML_TYPE_COUNT]; + vk::DescriptorSetLayout dsl; vk_matmul_pipeline pipeline_matmul_f32 {}; @@ -860,7 +869,7 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; - std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + std::map pipeline_flash_attn_f32_f16; std::map, vk_pipeline> pipeline_fa_mask_opt; @@ -2938,10 +2947,10 @@ struct vk_fa_tuning_params { } }; -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type); +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); -static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { vk_fa_tuning_params result{}; result.path = FA_SCALAR; @@ -2993,7 +3002,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; - if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) { + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) { result.block_rows /= 2; } @@ -3016,10 +3025,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_rows); GGML_UNUSED(n_kv); - GGML_UNUSED(kv_type); + GGML_UNUSED(k_type); + GGML_UNUSED(v_type); GGML_UNUSED(f32acc); vk_fa_tuning_params result{}; @@ -3075,12 +3085,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device } static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { - // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. - if (k_type != v_type) { - GGML_ASSERT(device->coopmat2); - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - } - FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -3092,7 +3096,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -3112,9 +3116,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT2: return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: @@ -3217,6 +3221,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec return supported; } +// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses +// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather +// than the float load buffers checked by ggml_vk_matmul_shmem_support. +// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline. +static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float. + const uint32_t fp_size = device->fp16 ? 2u : 4u; + const uint32_t fp_align = fp_size; + const uint32_t fp2_size = 2u * fp_size; + const uint32_t fp2_align = device->fp16 ? 4u : 8u; + + struct member { uint32_t size, align; }; + auto std430_size = [](std::initializer_list members) { + uint32_t off = 0, struct_align = 1; + for (const auto &m : members) { + off = (off + m.align - 1) & ~(m.align - 1); + off += m.size; + struct_align = std::max(struct_align, m.align); + } + return (off + struct_align - 1) & ~(struct_align - 1); + }; + + uint32_t block_a_size = 0; + switch (src0_type) { + case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm + case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2) + case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm + case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2) + case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm + case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d + case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2) + case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2) + case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2) + case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2) + case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2) + default: + return false; + } + + // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; } + const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); + + const uint32_t BM = warptile[1]; + const uint32_t BN = warptile[2]; + // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise. + const uint32_t BK_STEP = mul_mat_id ? 1u : 4u; + + const uint32_t buf_a_size = BM * BK_STEP * block_a_size; + const uint32_t buf_b_size = BN * BK_STEP * block_b_size; + const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u; + + const uint32_t warps = warptile[0] / warptile[10]; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u; + + const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported); + + return supported; +} + struct GpuPipelineConfig { // GPU architecture identifier. // Example: vk_device_architecture::AMD_GCN @@ -3284,6 +3352,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev return 0; // If no matching configuration is found } +// Whether scalar flash attention will use the MMQ path for the given k_type. +static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) { +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + return device->integer_dot_product && device->subgroup_clustered && + (k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 || + k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 || + k_type == GGML_TYPE_Q8_0); +#else + GGML_UNUSED(device); + GGML_UNUSED(k_type); + return false; +#endif +} + static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); @@ -3449,6 +3531,40 @@ static void ggml_vk_load_shaders(vk_device& device) { } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } + + // The q8_1 mmq path has its own (larger) shmem layout, check it separately. + // K-quants use the _int_k warptiles, others use _int. + const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K || + t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K || + t == GGML_TYPE_Q6_K); + const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int; + const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int; + const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int; + const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int; + const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int; + const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int; + + if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) { + device->mul_mat_s_int[i] = false; + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) { + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) { + device->mul_mat_l_int[i] = false; + } + + if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) { + device->mul_mat_id_s_int[i] = false; + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) { + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) { + device->mul_mat_id_l_int[i] = false; + } } } @@ -3530,121 +3646,96 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; -#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - uint32_t fa_sgs = fa.first.subgroup_size; \ - bool fa_ds = fa.first.subgroup_size == 0; \ - if (path == FAPATH) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } \ - } \ - } + // FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V + // quant type is selected at runtime via the FaTypeK / FaTypeV spec constants. - if (device->fp16) { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_SCALAR) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); + const void * spv_data = nullptr; + size_t spv_size = 0; + if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8) - } else + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } + else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_int8_data; + spv_size = flash_attn_f32_f16_fp32_int8_len; + } #endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) - } - } else { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) - -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8) - } else -#endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + } else { + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + } } + const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); } + #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) - } -#endif -#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -#define CREATE_FA_CM2_MIXED() \ - for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - if (path == FA_COOPMAT2) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } \ - } \ - } \ - } \ + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT1) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const void * spv_data; + size_t spv_size; + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); + } + } +#endif + +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT2) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + + const void * spv_data; + size_t spv_size; + const char * name; + if (aligned) { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; } + } + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0); } - if (device->coopmat2) { - CREATE_FA_CM2_MIXED(); } -#undef CREATE_FA_CM2_MIXED #endif -#undef CREATE_FA const int mul_mat_id_param_count = 5; @@ -4174,11 +4265,6 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; - if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { - // Xe2/Xe3 - bf16 warptile performance tuning - l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; - } - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -5617,19 +5703,19 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; - case VK_VENDOR_ID_INTEL: - if (!device->coopmat_support || device->architecture != INTEL_XE2) { - device->mul_mat_l[i] = false; - device->mul_mat_id_l[i] = false; - } else { - device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel - device->mul_mat_id_l[i] = true; - } + case VK_VENDOR_ID_INTEL: { + // Current Windows driver does not expose BF16 support. + // We only want to use l_warptile if coopmat is available and is Xe2+ + const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2; + const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat; + device->mul_mat_l[i] = use_l_warptile; + device->mul_mat_id_l[i] = use_l_warptile; device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; + } case VK_VENDOR_ID_APPLE: device->mul_mat_l[i] = false; device->mul_mat_m[i] = true; @@ -5648,6 +5734,13 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_s[i] = true; break; } + + device->mul_mat_l_int[i] = true; + device->mul_mat_m_int[i] = true; + device->mul_mat_s_int[i] = true; + device->mul_mat_id_l_int[i] = true; + device->mul_mat_id_m_int[i] = true; + device->mul_mat_id_s_int[i] = true; } @@ -7263,6 +7356,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type]; + if (ctx->device->coopmat2) { const uint32_t shader_core_count = ctx->device->shader_core_count; const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); @@ -7279,26 +7379,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, // split_k==3 with large tiles likely better than medium tiles with no split_k. (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; - - GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -7355,35 +7453,42 @@ static void ggml_vk_matmul( ctx->prealloc_split_k_need_sync = true; } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type]; if (ctx->device->coopmat2) { // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul_id( @@ -7679,10 +7784,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -8514,10 +8621,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -8967,8 +9076,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) { GGML_UNUSED(f32acc); + GGML_UNUSED(v_type); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; const uint32_t Br = params.block_rows; @@ -8976,10 +9086,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); - const bool mmq = device->integer_dot_product && device->subgroup_clustered && - (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 || - kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 || - kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL); + const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); @@ -8996,17 +9103,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // kvsh uses D = HSV (K goes through kblocksh instead) kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - // block_a_cache size depends on quant type - uint32_t block_a_size; - switch (kv_type) { - case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q8_0: - case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break; - default: block_a_size = 0; break; - } + // The mixed MMQ shader uses a superset block_a_cache that fits every + // FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm. + // Single-scale types leave dm.y unused; non-Q5_* leave qh unused. + const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size; kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; } else { Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; @@ -9144,10 +9244,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); - if (tuning_params.path != FA_COOPMAT2) { - GGML_ASSERT(k->type == v->type); - } - const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); @@ -9191,7 +9287,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { std::lock_guard guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { pipeline = it->second; @@ -15669,10 +15765,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // mismatching K/V type is currently supported for coopmat2 only. - if (op->src[1]->type != op->src[2]->type && !coopmat2) { - return false; - } auto fa_kv_ok = [coopmat2](ggml_type t) { switch (t) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6e6bdabc9..6ac095489 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -22,6 +22,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; @@ -128,18 +129,20 @@ void main() { Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); -#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); - } -#else // Q4_0, Q4_1, Q5_0, Q5_1 - const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; - const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + // Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores + // the row-sum scaled by qd, used in k_dot_correction. + if (FaTypeK == FA_TYPE_Q8_0) { + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } + } else { + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } } -#endif #endif } barrier(); @@ -177,13 +180,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -257,21 +256,21 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; } } #else // MMQ - const uint ints_per_block = 8 / QUANT_R_MMQ; + const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK); const uint quant_iters = Bc * HSK / 32 * ints_per_block; [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { const uint32_t iqs = (idx + tid) % ints_per_block; @@ -310,15 +309,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); @@ -335,15 +332,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); @@ -366,72 +361,47 @@ void main() { int32_t k_quants[d_per_step]; ACC_TYPEV2 k_dm; + // Q4_*/Q5_* take the block-8 fast path when one step covers a full + // block; Q8_0 always goes through the per-int get_k_qs* helpers + // (its qs is byte-packed, not nibble-packed). + const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0); + if (SHMEM_STAGING != 0) { const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0); -#else k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { + if (block8_fast) { + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); [[unroll]] for (uint32_t d = 0; d < 4; d++) { uint vui = kblocksh[buf_ib].qs[d]; k_quants[d ] = int32_t( vui & 0x0F0F0F0F); k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; - uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + if (has_qh) { + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); } } } else { - const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block); - const uint ib = coord / BLOCK_SIZE; - const uint iqs = (coord % BLOCK_SIZE); + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE_K; + const uint iqs = (coord % BLOCK_SIZE_K); -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0); -#else - k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset)); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { -#if defined(DATA_A_Q5_0) - uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0], - k_packed.k_data_packed16[k_offset + ib].qh[1])); -#elif defined(DATA_A_Q5_1) - uint qh = k_packed.k_data_packed16[k_offset + ib].qh; -#endif - [[unroll]] for (uint32_t d = 0; d < 4; d++) { -#if defined(A_TYPE_PACKED32) - uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d]; -#else - uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], - k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); -#endif - k_quants[d ] = int32_t( vui & 0x0F0F0F0F); - k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (qh >> (d * 4)) & 0xF; - uint qh_hi = (qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset)); + + if (block8_fast) { + fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset); + [[unroll]] for (uint32_t d = 0; d < 8; d++) { + k_quants[d] = blk.qs[d]; } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); } @@ -516,14 +486,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -547,15 +517,13 @@ void main() { FLOAT_TYPEV4 Vf; if (SHMEM_STAGING != 0) { Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_V) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else + } else { Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index efed3a73e..9a7957da9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -#if defined(DATA_A_F32) -layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; -#elif defined(A_TYPE_PACKED16) -layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; -#endif -#if defined(A_TYPE_PACKED32) -layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32; -layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32; -#endif +// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the +// host can pass the type directly. Keep in sync with ggml.h. +#define FA_TYPE_F32 0u +#define FA_TYPE_F16 1u +#define FA_TYPE_Q4_0 2u +#define FA_TYPE_Q4_1 3u +#define FA_TYPE_Q5_0 6u +#define FA_TYPE_Q5_1 7u +#define FA_TYPE_Q8_0 8u +#define FA_TYPE_Q1_0 41u -#ifndef BLOCK_SIZE -#define BLOCK_SIZE 1 -#endif - -#if defined(DATA_A_F32) -#undef BLOCK_SIZE -#define BLOCK_SIZE 4 -#define BLOCK_BYTE_SIZE 16 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - // iqs is currently always zero in the flash attention shaders - if (binding_idx == BINDING_IDX_K) { - return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); - } else { - return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); +// Number of matrix elements per buffer block, derived from the K/V type spec +// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 +// and bypasses the dequant path entirely. Quants follow their ggml block sizes. +uint fa_block_elems(uint ty) { + switch (ty) { + case FA_TYPE_F32: return 4u; + case FA_TYPE_F16: return 1u; + case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere + default: return 1u; } } -#endif -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 -#elif defined(DATA_A_Q4_1) -#define BLOCK_BYTE_SIZE 20 -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif +// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte +// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number +// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R. +uint fa_quant_r_mmq(uint ty) { + switch (ty) { + case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0); + default: return 1u; } } -#endif -#if defined(DATA_A_Q5_0) -#define BLOCK_BYTE_SIZE 22 -#elif defined(DATA_A_Q5_1) -#define BLOCK_BYTE_SIZE 24 -#endif - -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = v_packed.v_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } -} -#endif - - -#if defined(DATA_A_IQ4_NL) -#define BLOCK_BYTE_SIZE 18 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); - } -} -#endif -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); - } else { - const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); - } -} -#endif +// These can't be `const` globals because GLSL forbids function calls in global +// const initializers, even when the spec constants would let the driver fold +// them. Macros expand at the use site and fold after specialization. +#define BLOCK_SIZE_K fa_block_elems(FaTypeK) +#define BLOCK_SIZE_V fa_block_elems(FaTypeV) +// F16 reads f16 elements directly from the binding; everything else routes +// through dequantize4 / the MMQ helpers to unpack from the packed block layout. +#define USE_DECODE_K (FaTypeK != FA_TYPE_F16) +#define USE_DECODE_V (FaTypeV != FA_TYPE_F16) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 526e8da38..bffcc095b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -14,6 +14,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -127,13 +128,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -227,14 +224,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { f16vec4 K_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; @@ -256,47 +253,40 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If K is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) { -#endif - barrier(); - [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { - uint32_t col_vec = (idx + tid) % (MatBr / 4); - uint32_t row = (idx + tid) / (MatBr / 4); - if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); - if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); -#endif + // For quants we always need to dequant into kvsh; for f16 we can load + // directly from global memory when alignment / bounds allow it. + const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; + if (stage_k) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + f16vec4 K_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { + if (USE_DECODE_K) { + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } + } + + kvsh[row * kvsh_stride + col_vec] = K_Tf; } - - kvsh[row * kvsh_stride + col_vec] = K_Tf; } + barrier(); } - barrier(); -#if BLOCK_SIZE == 1 - } -#endif -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) -#endif - { + if (stage_k) { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); - } -#if BLOCK_SIZE == 1 - else { + } else { const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); } -#endif } else { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); @@ -397,14 +387,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { f16vec4 V_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -431,36 +421,33 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If V is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - // For f16, only preload if not aligned - if (KV_bounds_check) { -#endif - [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { - const uint idx = i * gl_WorkGroupSize.x + tid; - const uint row = idx / v_cols; - const uint col = idx % v_cols; + // For quants we always preload via kvsh. For f16 we only preload when + // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). + const bool stage_v = USE_DECODE_V || KV_bounds_check; + if (stage_v) { + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; - const uint v_row = j * Bc + row; - const uint v_col = hsv_tile * MatBc * row_split + col * 4; + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; - const uint coord = v_row * v_stride * BLOCK_SIZE + v_col; - const uint ib = coord / BLOCK_SIZE; - const uint iqs = coord % BLOCK_SIZE; + const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col; + const uint ib = coord / BLOCK_SIZE_V; + const uint iqs = coord % BLOCK_SIZE_V; - if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { -#if BLOCK_SIZE > 1 - kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; -#endif - } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { + if (USE_DECODE_V) { + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + } + } else { + kvsh[row * vsh_stride + col] = f16vec4(0.0f); + } } } - -#if BLOCK_SIZE == 1 - } -#endif } barrier(); @@ -471,15 +458,12 @@ void main() { coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (!KV_bounds_check) { + if (!USE_DECODE_V && !KV_bounds_check) { // F16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); - } else -#endif - { + } else { const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 8a7bbaeb9..141bb8708 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; -uint fa_block_elems(uint ty) { - switch (ty) { - case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) - case 1u: return 1u; // GGML_TYPE_F16 - case 2u: return uint(QUANT_K_Q4_0); - case 3u: return uint(QUANT_K_Q4_1); - case 6u: return uint(QUANT_K_Q5_0); - case 7u: return uint(QUANT_K_Q5_1); - case 8u: return uint(QUANT_K_Q8_0); - case 41u: return uint(QUANT_K_Q1_0); - default: - return 1u; - } -} - float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeV) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl new file mode 100644 index 000000000..02106f33c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -0,0 +1,123 @@ +// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V) +// covering every supported FA element type, plus an uber dequantize4() that +// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver +// folds away every path except the one matching the K/V type for this pipeline. +// +// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by +// flash_attn_cm2.comp, which has its own buffer_reference-based decode path. +// +// We use macros (rather than per-quant decode functions taking a struct) on +// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16 +// when FLOAT16 isn't defined, which makes float16-containing struct values +// illegal to return from / pass to functions. Macros expand inline where the +// float16 stays in storage and is converted to FLOAT_TYPE at use. + +// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl +// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32. +layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32; +layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32; + +layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0; +layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0; +layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1; +layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1; +layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0; +layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0; +layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1; +layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1; +layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; +layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; + +// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 +// views, used by the MMQ K-side hot path for fast 4-uint loads. +layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; +layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32; + +// Per-quant decode bodies are expanded once for the K view set and once for +// the V view set. The macros take the buffer name as a parameter. +#define FA_DEQUANT4_F32(BUF) \ + return FLOAT_TYPEV4(BUF.data[a_offset + ib]); + +#define FA_DEQUANT4_Q4_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \ +} + +#define FA_DEQUANT4_Q4_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q5_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = uint(BUF.data[a_offset + ib].qh[0]) \ + | (uint(BUF.data[a_offset + ib].qh[1]) << 16); \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \ +} + +#define FA_DEQUANT4_Q5_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = BUF.data[a_offset + ib].qh; \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q8_0(BUF) { \ + const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \ + const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ +} + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + switch (FaTypeK) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + } + } else { + switch (FaTypeV) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + } + } + return FLOAT_TYPEV4(0); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl index e14e62d54..6bf10a7cf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -1,149 +1,203 @@ -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and +// reads from the matching aliased K binding declared in flash_attn_dequant.glsl. +// Spec-constant specialization folds the unused paths. + int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q4_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - return int32_t(vui & 0x0F0F0F0F); + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q4_1: { // uses packed32 alias + uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q5_0: { + uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16 + uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed_q5_1.data[a_offset + ib].qh; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2], + k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1])); + } + default: return 0; + } } -#endif -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q5_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0], - k_packed.k_data_packed16[a_offset + ib].qh[1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - uint qh_bits = (qh >> iqs) & 0xF; - return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0) +// return (d, 0) so call sites always see the same shape. +FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0); + default: return FLOAT_TYPEV2(0); + } } -#endif - -#if defined(DATA_A_Q8_0) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])); -} -#endif - -#if defined(DATA_A_IQ4_NL) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - u8vec4 idx = unpack8(vui & 0x0F0F0F0F); - return pack32(i8vec4(kvalues_iq4nl_const[idx.x], - kvalues_iq4nl_const[idx.y], - kvalues_iq4nl_const[idx.z], - kvalues_iq4nl_const[idx.w])); -} -#endif - -#if QUANT_AUXF == 1 -FLOAT_TYPE get_k_d(uint ib, uint a_offset) { - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d); -} -#else -FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) { - return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm); -} -#endif void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { -#if defined(DATA_A_Q4_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_Q4_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; -#elif defined(DATA_A_Q5_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - if (iqs == 0) { - kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0], - k_packed.k_data_packed16[a_offset + global_ib].qh[1])); + // kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need + // explicit casts. The bit pattern is what we care about here -- the actual + // signed/unsigned interpretation happens downstream in the dot product. + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + break; + } + case FA_TYPE_Q4_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]); + break; + } + case FA_TYPE_Q5_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0], + k_packed_q5_0.data[a_offset + global_ib].qh[1])); + } + break; + } + case FA_TYPE_Q5_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]); + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh; + } + break; + } + case FA_TYPE_Q8_0: { + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1])); + break; + } } -#elif defined(DATA_A_Q5_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; - if (iqs == 0) { - kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh; - } -#elif defined(DATA_A_Q8_0) - kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_IQ4_NL) - const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); - const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); - kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y], - kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w])); - kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y], - kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w])); -#endif if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d); -#else - kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm); -#endif + // Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair. + switch (FaTypeK) { + case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break; + } } } +// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed +// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads +// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble +// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned). +// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't +// share this nibble shape. +// +// Returned via a struct so the caller's k_quants array (sized from spec +// constants) doesn't need to match a fixed[8] out-parameter type. +struct fa_k_qs_block8 { + int32_t qs[8]; +}; + +fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) { + fa_k_qs_block8 r; + uint qh = 0; + if (FaTypeK == FA_TYPE_Q5_0) { + qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + } else if (FaTypeK == FA_TYPE_Q5_1) { + qh = k_packed_q5_1.data[a_offset + ib].qh; + } + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = 0; + switch (FaTypeK) { + case FA_TYPE_Q4_0: { // packed16 + vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q4_1: { // packed32 alias + vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d]; + break; + } + case FA_TYPE_Q5_0: { // packed16 + vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q5_1: { // packed32 alias + vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d]; + break; + } + } + r.qs[d ] = int32_t( vui & 0x0F0F0F0F); + r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); + if (has_qh) { + uint qh_lo = (qh >> (d * 4)) & 0xFu; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu; + r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } + } + return r; +} + int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); - uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF; - return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); -#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - return kblocksh[buf_ib].qs[pos]; -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: + case FA_TYPE_Q4_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + } + case FA_TYPE_Q5_0: + case FA_TYPE_Q5_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return kblocksh[buf_ib].qs[pos]; + } + default: return 0; + } } ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { -#if defined(DATA_A_Q4_0) - return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q5_0) - return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) - return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; -#else - return ACC_TYPE(0.0); -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q4_1: + case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; + default: return ACC_TYPE(0.0); + } } void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { kblocksh[buf_ib].qs[iqs] = 0; -#if defined(DATA_A_IQ4_NL) - kblocksh[buf_ib].qs[iqs + 4] = 0; -#endif if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f); -#else kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); -#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 10552d013..79c933f40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -1,4 +1,13 @@ -#if defined(DATA_A_Q4_0) +#if defined(FA_MMQ_MIXED) +// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0. +// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale +// types (Q4_0/Q5_0/Q8_0) leave dm.y unused. +struct block_a_cache { + int32_t qs[8]; + uint32_t qh; + FLOAT_TYPEV2 dm; +}; +#elif defined(DATA_A_Q4_0) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index db34e8b40..8210e221f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -660,42 +660,22 @@ void process_shaders() { if (fp16) { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); #endif - } - for (const auto& tname : type_names) { - if (tname == "bf16") continue; - - if (fp16) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } + string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); #endif - } - - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (tname != "f32") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); - } -#endif - } } + + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); +#endif } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 617cbc49d..4055ec287 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -299,30 +299,32 @@ class Keys: HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" class ClipVision: - PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models - IMAGE_SIZE = "clip.vision.image_size" - IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" - IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" - PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles" - PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles" - PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" - PATCH_SIZE = "clip.vision.patch_size" - EMBEDDING_LENGTH = "clip.vision.embedding_length" - FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" - PROJECTION_DIM = "clip.vision.projection_dim" - BLOCK_COUNT = "clip.vision.block_count" - IMAGE_MEAN = "clip.vision.image_mean" - IMAGE_STD = "clip.vision.image_std" - SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" - USE_GELU = "clip.use_gelu" - USE_SILU = "clip.use_silu" - N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl - WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl - IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" - WINDOW_SIZE = "clip.vision.window_size" + PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models + IMAGE_SIZE = "clip.vision.image_size" + IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" + IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" + PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles" + PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles" + PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" + PATCH_SIZE = "clip.vision.patch_size" + EMBEDDING_LENGTH = "clip.vision.embedding_length" + FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" + PROJECTION_DIM = "clip.vision.projection_dim" + BLOCK_COUNT = "clip.vision.block_count" + IMAGE_MEAN = "clip.vision.image_mean" + IMAGE_STD = "clip.vision.image_std" + SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" + USE_GELU = "clip.use_gelu" + USE_SILU = "clip.use_silu" + N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl + WA_PATTERN_MODE = "clip.vision.wa_pattern_mode" # used by mimovl, per-layer -1/0/1 + IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + WINDOW_SIZE = "clip.vision.window_size" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" + HEAD_COUNT_KV = "clip.vision.attention.head_count_kv" # used by mimovl (GQA) LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon" class Projector: @@ -733,6 +735,7 @@ class MODEL_TENSOR(IntEnum): V_ENC_ATTN_V = auto() V_ENC_ATTN_O = auto() V_ENC_ATTN_O_NORM = auto() + V_ENC_ATTN_SINKS = auto() # mimovl V_ENC_POST_ATTN_NORM = auto() V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() @@ -1246,6 +1249,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", + MODEL_TENSOR.V_ENC_ATTN_SINKS: "v.blk.{bid}.attn_sinks", MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", @@ -1426,6 +1430,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_ENC_ATTN_V, MODEL_TENSOR.V_ENC_ATTN_O, MODEL_TENSOR.V_ENC_ATTN_O_NORM, + MODEL_TENSOR.V_ENC_ATTN_SINKS, MODEL_TENSOR.V_ENC_POST_ATTN_NORM, MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, @@ -4258,6 +4263,7 @@ class VisionProjectorType: HUNYUANVL = "hunyuanvl" MINICPMV4_6 = "minicpmv4_6" GRANITE_SPEECH = "granite_speech" # audio + MIMOVL = "mimovl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 35fb01470..a10138271 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1151,6 +1151,9 @@ class GGUFWriter: def add_vision_head_count(self, value: int) -> None: self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value) + def add_vision_head_count_kv(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT_KV, value) + def add_vision_attention_layernorm_eps(self, value: float) -> None: self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value) @@ -1222,6 +1225,9 @@ class GGUFWriter: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_wa_pattern_mode(self, modes: Sequence[int]) -> None: + self.add_array(Keys.ClipVision.WA_PATTERN_MODE, modes) + def add_vision_window_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index f27f0e4c9..f40cb8282 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1569,6 +1569,10 @@ class TensorNameMap: "vision_model.transformer.resblocks.{bid}.attn.out_proj", # Step3-VL ), + MODEL_TENSOR.V_ENC_ATTN_SINKS: ( + "visual.blocks.{bid}.attn.sinks", # mimovl + ), + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", "model.vision_tower.encoder.layers.{bid}.layer_norm2", # minicpmv4_6 diff --git a/include/llama.h b/include/llama.h index 5d43e571e..702d64d66 100644 --- a/include/llama.h +++ b/include/llama.h @@ -861,6 +861,8 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 + // for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 diff --git a/src/llama-context.cpp b/src/llama-context.cpp index dc54ba252..f27602a99 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2485,11 +2485,29 @@ public: } if (need_alloc) { - mbuf_cur = std::move(mbuf); + if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) { + mbuf_cur = std::move(mbuf); - mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); - LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } else { + //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + + // save the old buffer and allocate the new tensors in it + auto buf = std::move(mbuf_cur.buf); + + mbuf_cur = std::move(mbuf); + + ggml_tallocr talloc = ggml_tallocr_new(buf.get()); + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_view_init(mbuf_cur.org[i]); + ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]); + } + + mbuf_cur.buf = std::move(buf); + } } for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { @@ -2569,8 +2587,7 @@ public: mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); - auto & view = mbuf.org.back(); - view->buffer = rinfo.tensor->buffer; + ggml_backend_view_init(mbuf.org.back()); } for (auto & [buft, mbuf] : mbufs_new) { diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index d3e7b1ed0..39f069501 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -98,7 +98,8 @@ struct clip_graph { ggml_tensor * v_cur, ggml_tensor * kq_mask, float kq_scale, - int il) const; + int il, + ggml_tensor * sinks = nullptr) const; // implementation of the 2D RoPE without adding a new op in ggml // this is not efficient (use double the memory), but works on all backends diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index cc6a1daf4..35a78a599 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -31,6 +31,7 @@ #define KEY_N_BLOCK "clip.%s.block_count" #define KEY_PROJ_DIM "clip.%s.projection_dim" #define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv" #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" // vision-specific @@ -53,6 +54,7 @@ #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" #define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes" +#define KEY_WA_PATTERN_MODE "clip.vision.wa_pattern_mode" #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" @@ -86,6 +88,7 @@ #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_ATTN_SINKS "%s.blk.%d.attn_sinks" #define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s" #define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" @@ -344,6 +347,7 @@ enum projector_type { PROJECTOR_TYPE_HUNYUANVL, PROJECTOR_TYPE_MINICPMV4_6, PROJECTOR_TYPE_GRANITE_SPEECH, + PROJECTOR_TYPE_MIMOVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -393,6 +397,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"}, { PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"}, { PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"}, + { PROJECTOR_TYPE_MIMOVL, "mimovl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 48f8b1a19..ce15dbcd1 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -42,6 +42,7 @@ struct clip_hparams { int32_t n_ff = 0; int32_t projection_dim = 0; int32_t n_head = 0; + int32_t n_head_kv = 0; int32_t n_layer = 0; // idefics3 int32_t n_merge = 0; // number of patch merges **per-side** @@ -83,6 +84,7 @@ struct clip_hparams { int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) + std::vector wa_pattern_mode; // mimovl: per-layer window-attention mode // deepseek-ocr (sam) int32_t sam_n_layer = 0; @@ -166,6 +168,8 @@ struct clip_layer { ggml_tensor * o_w = nullptr; ggml_tensor * o_b = nullptr; + ggml_tensor * attn_sinks = nullptr; + ggml_tensor * k_norm = nullptr; ggml_tensor * q_norm = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 562dd5785..d8d442ab0 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -70,6 +70,7 @@ #include "models/pixtral.cpp" #include "models/qwen2vl.cpp" #include "models/qwen3vl.cpp" +#include "models/mimovl.cpp" #include "models/qwen3a.cpp" #include "models/step3vl.cpp" #include "models/siglip.cpp" @@ -701,7 +702,8 @@ ggml_tensor * clip_graph::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_mask, float kq_scale, - int il) const { + int il, + ggml_tensor * sinks) const { // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -724,6 +726,9 @@ ggml_tensor * clip_graph::build_attn( cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + if (sinks != nullptr) { + ggml_flash_attn_ext_add_sinks(cur, sinks); + } cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); @@ -736,6 +741,9 @@ ggml_tensor * clip_graph::build_attn( // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + if (sinks != nullptr) { + ggml_soft_max_add_sinks(kq, sinks); + } ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); @@ -925,6 +933,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_MIMOVL: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_STEP3VL: { builder = std::make_unique(ctx, img); @@ -1476,6 +1488,22 @@ struct clip_model_loader { // LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); // } } break; + case PROJECTOR_TYPE_MIMOVL: + { + hparams.n_merge = 2; // spatial_merge_size + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(string_format(KEY_N_HEAD_KV, "vision"), hparams.n_head_kv); + // 1D banded sliding-window radius (visual_token_window_size); required + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size); + std::vector pat; + get_arr_int(KEY_WA_PATTERN_MODE, pat, true); + GGML_ASSERT((int) pat.size() == hparams.n_layer && "mimovl wa_pattern_mode length must equal n_layer"); + hparams.wa_pattern_mode.assign(pat.begin(), pat.end()); + get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels); + get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels); + hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_STEP3VL: { hparams.n_merge = 4; // two stride-2 downsamplers after patching @@ -1821,6 +1849,8 @@ struct clip_model_loader { layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); + // mimovl per-head attention sink bias + layer.attn_sinks = get_tensor(string_format(TN_ATTN_SINKS, prefix, il), false); // qwen3vl deepstack layer layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false); @@ -2005,6 +2035,13 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_MIMOVL: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + } break; case PROJECTOR_TYPE_STEP3VL: { model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); @@ -3280,6 +3317,7 @@ void setup_init_vision_shim_kcpp(struct clip_ctx * ctx_v) { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; @@ -3529,6 +3567,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANOCR: @@ -3550,6 +3589,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANVL: @@ -3628,6 +3668,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_YOUTUVL: { @@ -4199,6 +4240,89 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; + case PROJECTOR_TYPE_MIMOVL: + { + const int merge = hparams.n_merge; // 2 + const int merge_unit = merge * merge; // 4 + const int patch = hparams.patch_size; // 16 + const int H = image_size_height / patch; + const int W = image_size_width / patch; + const int n_pos_full = H * W; + const int llm_h = H / merge; + const int llm_w = W / merge; + const int n_units = llm_h * llm_w; // n_pos / merge_unit + + // Row-major merge-tile-ordered (h, w) positions + std::vector pos_h_row(n_pos_full); + std::vector pos_w_row(n_pos_full); + { + int idx = 0; + for (int ty = 0; ty < llm_h; ty++) { + for (int tx = 0; tx < llm_w; tx++) { + for (int dy = 0; dy < merge; dy++) { + for (int dx = 0; dx < merge; dx++) { + pos_h_row[idx] = ty * merge + dy; + pos_w_row[idx] = tx * merge + dx; + idx++; + } + } + } + } + } + + // Col-major merge-unit permutation + std::vector idx_col(n_units); + for (int r = 0; r < llm_h; r++) { + for (int c = 0; c < llm_w; c++) { + int u_row = r * llm_w + c; + int u_col = c * llm_h + r; + idx_col[u_col] = (float) u_row; + } + } + + // Col-mode positions: permute pos_*_row by idx_col + std::vector pos_h_col(n_pos_full); + std::vector pos_w_col(n_pos_full); + for (int u = 0; u < n_units; u++) { + int src = (int) idx_col[u]; + for (int k = 0; k < merge_unit; k++) { + pos_h_col[u * merge_unit + k] = pos_h_row[src * merge_unit + k]; + pos_w_col[u * merge_unit + k] = pos_w_row[src * merge_unit + k]; + } + } + + // Pack into ggml_rope_multi VISION-mode layout. The non-CPU kernels + // only read slots 0 and 1, so pack h in slot 0, w in slot 1: + // positions[0..n_pos) = h + // positions[n_pos..2*n_pos) = w + // positions[2*n_pos..3*n_pos) = 0 + // positions[3*n_pos..4*n_pos) = 0 + std::vector positions_row(static_cast(n_pos_full) * 4, 0); + std::vector positions_col(static_cast(n_pos_full) * 4, 0); + for (int i = 0; i < n_pos_full; i++) { + positions_row[0 * n_pos_full + i] = pos_h_row[i]; + positions_row[1 * n_pos_full + i] = pos_w_row[i]; + positions_col[0 * n_pos_full + i] = pos_h_col[i]; + positions_col[1 * n_pos_full + i] = pos_w_col[i]; + } + + // Banded 1D sliding-window mask + const int window = hparams.attn_window_size; + GGML_ASSERT(window > 0); + std::vector mask(static_cast(n_pos_full) * n_pos_full, std::numeric_limits::lowest()); + for (int q = 0; q < n_pos_full; q++) { + int lo = std::max(0, q - window); + int hi = std::min(n_pos_full - 1, q + window); + for (int k = lo; k <= hi; k++) { + mask[static_cast(q) * n_pos_full + k] = 0.0f; + } + } + + set_input_i32("mimovl_positions_row", positions_row); + set_input_i32("mimovl_positions_col", positions_col); + set_input_f32("mimovl_idx_col", idx_col); + set_input_f32("mimovl_window_mask", mask); + } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_KIMIK25: @@ -4796,6 +4920,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers); + case PROJECTOR_TYPE_MIMOVL: + return ctx->model.mm_1_w->ne[1]; case PROJECTOR_TYPE_STEP3VL: return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_GEMMA3: diff --git a/tools/mtmd/models/mimovl.cpp b/tools/mtmd/models/mimovl.cpp new file mode 100644 index 000000000..19db88f13 --- /dev/null +++ b/tools/mtmd/models/mimovl.cpp @@ -0,0 +1,209 @@ +#include "models.h" + +ggml_tensor * clip_graph_mimovl::build_mm(ggml_tensor * w, ggml_tensor * x) const { + ggml_tensor * cur = ggml_mul_mat(ctx0, w, x); + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + return cur; +} + +// MiMoVL vision tower for MiMo-V2.5 (non-Pro). Qwen2.5-VL-shaped ViT, except: +// 1. GQA in attention (32 Q / 8 KV heads, head_dim 64). +// 2. Per-head attention sinks on every windowed layer. The sinks adjust +// the softmax denominator (equivalently, a virtual extra K column with V=0), +// so they decay attention weight without contributing to the output. +// 3. Per-layer window-attention mode in hparams.wa_pattern_mode: +// -1 -> full, 0 -> row-window+sinks, 1 -> col-window+sinks. +// Col mode transposes the merge-unit grid on entry and restores +// it on exit. Both patch and rotary orderings are pre-computed +// host-side. +// 4. 1D banded sliding window (|q-k| > window_size -> -inf) as a +// single 2D mask broadcast across heads. +// 5. Per-block MLP biases. +ggml_cgraph * clip_graph_mimovl::build() { + GGML_ASSERT(model.patch_embeddings_0 != nullptr); + GGML_ASSERT(model.patch_embeddings_1 != nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + GGML_ASSERT(hparams.n_head_kv > 0); + GGML_ASSERT(n_head % hparams.n_head_kv == 0); + GGML_ASSERT((int) hparams.wa_pattern_mode.size() == n_layer); + + const int batch_size = 1; + const int n_pos = n_patches; + const int n_head_kv = hparams.n_head_kv; + const int merge = hparams.n_merge > 0 ? hparams.n_merge : 2; + const int merge_unit = merge * merge; + const int n_units = n_pos / merge_unit; + GGML_ASSERT(n_units * merge_unit == n_pos); + + // MiMoVL has head_dim=64 with n_embd=1280, so n_embd is NOT n_head*head_dim + // (the base class's d_head = n_embd/n_head = 40 is wrong here). Derive + // head_dim from the fused QKV projection: rows = (n_head + 2*n_head_kv)*head_dim. + GGML_ASSERT(model.layers[0].qkv_w != nullptr); + const int qkv_rows = model.layers[0].qkv_w->ne[1]; + const int head_dim = qkv_rows / (n_head + 2 * n_head_kv); + GGML_ASSERT(head_dim * (n_head + 2 * n_head_kv) == qkv_rows); + const float attn_scale = 1.0f / std::sqrt((float) head_dim); + const int rope_n_dims = head_dim / 2; + int mrope_sections[4] = {rope_n_dims/2, rope_n_dims/2, 0, 0}; + + // Patch embed: Conv3D(kt=2) split into two Conv2D, then interleave-merge + // along the height axis to match the merge-tile token order. + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, + patch_size, patch_size, 0, 0, 1, 1); + { + ggml_tensor * inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, + patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + GGML_ASSERT(img.nx % (patch_size * 2) == 0); + GGML_ASSERT(img.ny % (patch_size * 2) == 0); + + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w,h,c,b] -> [c,w,h,b] + inp = ggml_cont_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + inp = ggml_reshape_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d(ctx0, inp, n_embd, n_patches_x * n_patches_y, batch_size); + } + cb(inp, "patch_embed", -1); + + ggml_tensor * positions_row = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos * 4); + ggml_set_name(positions_row, "mimovl_positions_row"); + ggml_set_input(positions_row); + + ggml_tensor * positions_col = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos * 4); + ggml_set_name(positions_col, "mimovl_positions_col"); + ggml_set_input(positions_col); + + // idx_col is the col-major merge-unit permutation. Take it as F32 so we can + // derive the inverse permutation in-graph via ggml_argsort; + // ggml_get_rows requires its index tensor to be I32, so cast back as well. + ggml_tensor * idx_col_f = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_units); + ggml_set_name(idx_col_f, "mimovl_idx_col"); + ggml_set_input(idx_col_f); + ggml_tensor * idx_col = ggml_cast(ctx0, idx_col_f, GGML_TYPE_I32); + ggml_tensor * idx_col_inv = ggml_argsort(ctx0, idx_col_f, GGML_SORT_ORDER_ASC); + + ggml_tensor * window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "mimovl_window_mask"); + ggml_set_input(window_mask); + + ggml_tensor * window_mask_attn = (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) + ? ggml_cast(ctx0, window_mask, GGML_TYPE_F16) + : window_mask; + + // Reorder helper: permute patches at merge-unit granularity. The patch + // sequence is laid out as n_units groups of merge_unit (=4) consecutive + // patches; the row<->col transpose only permutes whole groups. We keep + // the per-group (h,w) ordering intact by reshaping to + // [n_embd*merge_unit, n_units] before ggml_get_rows. + auto reorder = [&](ggml_tensor * x, ggml_tensor * idx) { + ggml_tensor * y = ggml_reshape_2d(ctx0, x, n_embd * merge_unit, n_units); + y = ggml_get_rows(ctx0, y, idx); + return ggml_reshape_3d(ctx0, y, n_embd, n_pos, batch_size); + }; + + ggml_tensor * inpL = inp; + int prev_mode = -1; + + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const int mode = hparams.wa_pattern_mode[il]; + const bool is_full = (mode == -1); + const bool is_col = (mode == 1); + + // Reorder transitions on entry/exit of a col-mode run. + if (is_col && prev_mode != 1) { + inpL = reorder(inpL, idx_col); + cb(inpL, "reorder_to_col", il); + } else if (!is_col && prev_mode == 1) { + inpL = reorder(inpL, idx_col_inv); + cb(inpL, "reorder_to_row", il); + } + + ggml_tensor * cur = inpL; + + // Pre-attention RMSNorm. + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_RMS, eps, il); + cb(cur, "ln1", il); + + // Fused QKV with GQA. + ggml_tensor * qkv = build_mm(layer.qkv_w, cur); + qkv = ggml_add(ctx0, qkv, layer.qkv_b); + + const size_t row = ggml_row_size(qkv->type, head_dim); + const size_t off_k = ggml_row_size(qkv->type, n_head * head_dim); + const size_t off_v = ggml_row_size(qkv->type, (n_head + n_head_kv) * head_dim); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim, n_head, n_pos, row, qkv->nb[1], 0); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_pos, row, qkv->nb[1], off_k); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_pos, row, qkv->nb[1], off_v); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // 2D RoPE + ggml_tensor * pos = is_col ? positions_col : positions_row; + Qcur = ggml_rope_multi(ctx0, Qcur, pos, nullptr, rope_n_dims, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); + Kcur = ggml_rope_multi(ctx0, Kcur, pos, nullptr, rope_n_dims, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + // Full layers: plain attention. Windowed layers: banded mask and per-head sinks. + ggml_tensor * mask = is_full ? nullptr : window_mask_attn; + ggml_tensor * sinks = is_full ? nullptr : layer.attn_sinks; + if (!is_full) { + GGML_ASSERT(layer.attn_sinks != nullptr); + } + ggml_tensor * attn_out = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, mask, attn_scale, il, sinks); + cb(attn_out, "attn_out", il); + + // Residual 1. + cur = ggml_add(ctx0, attn_out, inpL); + inpL = cur; + cb(cur, "ffn_inp", il); + + // Pre-FFN RMSNorm. + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_RMS, eps, il); + cb(cur, "ffn_inp_normed", il); + + // SwiGLU MLP with biases + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + cb(cur, "ffn_out", il); + + // Residual 2. + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + inpL = cur; + prev_mode = mode; + } + + // If the last block was col-mode, undo the transpose so the merger sees patches in row order. + if (prev_mode == 1) { + inpL = reorder(inpL, idx_col_inv); + cb(inpL, "reorder_to_row_final", -1); + } + + // Merger: post-LayerNorm + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, 1e-6f, n_layer); + cb(inpL, "post_ln", -1); + + // Spatial merge: pack each merge_unit (=4) of patches into a single + // (n_embd*merge_unit)-wide row, then run the 2-layer MLP. + ggml_tensor * embeddings = ggml_reshape_3d(ctx0, inpL, n_embd * merge_unit, n_units, batch_size); + embeddings = build_ffn(embeddings, + model.mm_0_w, nullptr, + nullptr, nullptr, + model.mm_1_w, nullptr, + FFN_GELU, -1); + cb(embeddings, "vit_out", -1); + + ggml_build_forward_expand(gf, embeddings); + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index dbba233b1..955daa6d6 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -33,6 +33,15 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_mimovl : clip_graph { + clip_graph_mimovl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + // Force F32 mat-mul accumulation to avoid F16 overflow in the FFN down-proj + // when the mmproj is stored in F16 (the source weights are BF16; downcasting + // to F16 reduces dynamic range below the SwiGLU output magnitude on the last few layers). + ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override; +}; + struct clip_graph_step3vl : clip_graph { clip_graph_step3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 87da6876f..22092f6a6 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -325,6 +325,7 @@ struct mtmd_context { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 637f8d216..ce743e665 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -36,32 +36,6 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; -static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) { - if (pos_min == -1) { - pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id); - } - if (pos_max == -1) { - pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id); - } - - auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY; - if (on_device) { - flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE; - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags); - - ckpt.pos_min = pos_min; - ckpt.pos_max = pos_max; - ckpt.n_tokens = n_tokens; - ckpt.data.resize(checkpoint_size); - - const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); - } -} - // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, @@ -80,18 +54,19 @@ enum server_state { struct server_slot { int id; - llama_context * ctx = nullptr; - - common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; // multimodal mtmd_context * mctx = nullptr; // speculative decoding + common_speculative * spec; + llama_tokens spec_draft; + llama_tokens spec_prompt; std::vector spec_i_batch; - server_prompt_checkpoint spec_ckpt; - common_speculative_ptr spec; + common_prompt_checkpoint spec_ckpt; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -135,21 +110,27 @@ struct server_slot { void prompt_save(server_prompt_cache & prompt_cache) const { GGML_ASSERT(prompt.data.size() == 0); - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0; - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + const size_t cur_size = cur_size_tgt + cur_size_dft; - auto * cur = prompt_cache.alloc(prompt, cur_size); + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft); if (cur == nullptr) { return; } - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + if (ctx_dft) { + llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE); + } } bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); + bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id); if (!res) { SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); } @@ -164,7 +145,11 @@ struct server_slot { SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); - llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1); + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1); + } + prompt.tokens.clear(); } @@ -222,7 +207,7 @@ struct server_slot { task_prev = std::move(task); task.reset(); - llama_set_sampler(ctx, id, nullptr); + llama_set_sampler(ctx_tgt, id, nullptr); // clear alora start alora_invocation_start = -1; @@ -259,7 +244,7 @@ struct server_slot { return !task->need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST); } bool can_batch_with(server_slot & other_slot) const { @@ -310,14 +295,10 @@ struct server_slot { return 0; } - const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative); - // determine the max draft that fits the current slot state - int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative); - // note: slot.prompt is not yet expanded with the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); + int n_draft_max = n_ctx - prompt.n_tokens() - 2; if (n_remaining > 0) { n_draft_max = std::min(n_draft_max, n_remaining - 1); @@ -325,61 +306,10 @@ struct server_slot { SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < n_draft_min) { - SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min); - n_draft_max = 0; - } - return n_draft_max; } void update_batch(llama_batch & batch) { - const int n_draft_max = get_n_draft_max(); - if (n_draft_max > 0) { - GGML_ASSERT(can_speculate()); - - // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - const llama_tokens & tokens = prompt.tokens.get_text_tokens(); - - const auto & params_spec = task->params.speculative; - - if (!spec_draft.empty()) { - // we have a previous (partial) draft to reuse - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - GGML_ASSERT(!spec_ckpt.empty()); - } - } else { - GGML_ASSERT(spec_i_batch.empty()); - - // generate a new draft - spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled); - n_draft_total += spec_draft.size(); - - if (spec_draft.size() > (size_t) n_draft_max) { - SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max); - spec_draft.resize(n_draft_max); - } - - if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - const auto n_tokens = prompt.tokens.size(); - - //const int64_t t_start = ggml_time_us(); - - server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true); - - //const int64_t t_total = ggml_time_us() - t_start; - //printf("checkpoint total: %f ms\n", t_total / 1000.0); - - SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", - spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); - } - } - - GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max); - } - if (spec_draft.empty()) { // no speculative decoding i_batch = batch.n_tokens; @@ -511,7 +441,7 @@ struct server_slot { ); } - common_speculative_print_stats(spec.get()); + common_speculative_print_stats(spec); } json to_json(bool only_metrics = false) const { @@ -539,7 +469,7 @@ struct server_slot { }; if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true); res["generated"] = generated_text.empty() ? debug_generated_text : generated_text; } } @@ -550,8 +480,13 @@ struct server_slot { void copy_state_to(server_slot & other) const { GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT); - llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1); + + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1); + } other.n_decoded = n_decoded; other.n_remaining = n_remaining; @@ -642,7 +577,8 @@ public: // only use these pointers outside of this class: // - when not in sleeping state // - and, with thread-safe APIs (e.g., tokenizer calls) - llama_model * model = nullptr; + llama_model * model_tgt = nullptr; + mtmd_context * mctx = nullptr; const llama_vocab * vocab = nullptr; @@ -669,11 +605,17 @@ private: // note: keep these alive - they determine the lifetime of the model, context, etc. common_init_result_ptr llama_init; - llama_context * ctx = nullptr; + llama_context * ctx_tgt = nullptr; llama_batch batch {}; llama_model_ptr model_dft; + llama_context_ptr ctx_dft; + + common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + + common_speculative_ptr spec; bool add_bos_token = true; @@ -708,18 +650,12 @@ private: void destroy() { llama_init.reset(); - ctx = nullptr; - model = nullptr; + ctx_tgt = nullptr; + model_tgt = nullptr; mtmd_free(mctx); mctx = nullptr; - for (server_slot & slot : slots) { - if (slot.can_speculate()) { - slot.spec.reset(); - } - } - llama_batch_free(batch); } @@ -759,17 +695,17 @@ private: llama_init = common_init_from_params(params_base); - model = llama_init->model(); - ctx = llama_init->context(); + model_tgt = llama_init->model(); + ctx_tgt = llama_init->context(); - if (model == nullptr) { + if (model_tgt == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - vocab = llama_model_get_vocab(model); + vocab = llama_model_get_vocab(model_tgt); - n_ctx = llama_n_ctx(ctx); + n_ctx = llama_n_ctx(ctx_tgt); add_bos_token = llama_vocab_get_add_bos(vocab); @@ -781,9 +717,6 @@ private: auto params_dft = params_base; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx); params_dft.devices = params_spec.devices; params_dft.model = params_spec.mparams; params_dft.n_gpu_layers = params_spec.n_gpu_layers; @@ -805,8 +738,13 @@ private: return false; } - params_base.speculative.draft.model = model_dft.get(); - params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft); + auto cparams = common_context_params_to_llama(params_dft); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); } std::string & mmproj_path = params_base.mmproj.path; @@ -826,7 +764,7 @@ private: mparams.image_max_tokens = params_base.image_max_tokens; mparams.media_marker = get_media_marker(); - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); return false; @@ -844,7 +782,7 @@ private: } } - if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) { if (params_base.ctx_shift) { params_base.ctx_shift = false; SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); @@ -856,14 +794,14 @@ private: } } - if (llama_model_n_swa(model) == 0) { + if (llama_model_n_swa(model_tgt) == 0) { if (params_base.swa_full) { params_base.swa_full = false; SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled"); } } - n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model); + n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_tgt); // Necessary similarity of prompt for slot selection slot_prompt_similarity = params_base.slot_prompt_similarity; @@ -871,9 +809,9 @@ private: // setup slots SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - const int n_ctx_train = llama_model_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model_tgt); - int n_ctx_slot = llama_n_ctx_seq(ctx); + int n_ctx_slot = llama_n_ctx_seq(ctx_tgt); if (n_ctx_slot > n_ctx_train) { SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); n_ctx_slot = n_ctx_train; @@ -881,12 +819,12 @@ private: slots.clear(); - const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx); - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt); + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { SRV_WRN("%s", "speculative decoding will use checkpoints\n"); } @@ -895,27 +833,33 @@ private: slots.emplace_back(); } + // try speculative decoding + if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + try { + spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel)); + } catch (const std::exception & e) { + SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); + } + } + + if (spec) { + SRV_INF("%s", "speculative decoding context initialized\n"); + } else { + ctx_dft.reset(); + } + for (int i = 0; i < params_base.n_parallel; i++) { server_slot & slot = slots[i]; - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - - slot.ctx_seq_rm_type = ctx_seq_rm_type; + slot.id = i; + slot.ctx_tgt = ctx_tgt; + slot.ctx_dft = ctx_dft.get(); + slot.spec = spec.get(); + slot.n_ctx = n_ctx_slot; slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; - // try speculative decoding - if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { - slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); - - if (slot.spec) { - SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } - } - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); slot.callback_on_release = [this](int id_slot) { @@ -946,7 +890,7 @@ private: // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { - const int32_t n_batch = llama_n_batch(ctx); + const int32_t n_batch = llama_n_batch(ctx_tgt); batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } @@ -990,8 +934,9 @@ private: // unlike load_model(), this is only called once during initialization bool init() { - GGML_ASSERT(ctx != nullptr); - GGML_ASSERT(model != nullptr); + GGML_ASSERT(ctx_tgt != nullptr); + GGML_ASSERT(model_tgt != nullptr); + GGML_ASSERT(!sleeping); // wiring up server queues @@ -1037,7 +982,7 @@ private: common_chat_templates_ptr chat_templates; try { - chat_templates = common_chat_templates_init(model, params_base.chat_template); + chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template); LOG_INF("%s: chat template, example_format: '%s'\n", __func__, common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); @@ -1300,7 +1245,7 @@ private: } } - if (!task.tokens.validate(ctx)) { + if (!task.tokens.validate(ctx_tgt)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -1310,7 +1255,7 @@ private: // initialize samplers if (task.need_sampling()) { try { - slot.smpl.reset(common_sampler_init(model, task.params.sampling)); + slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling)); } catch (std::exception & e) { std::string err_msg = std::string("Failed to initialize samplers: ") + e.what(); send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST); @@ -1324,16 +1269,16 @@ private: backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0); + backend_sampling &= !(slot.can_speculate()); // TODO: getting pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_pre_sample_logits; // TODO: tmp until backend sampling is fully implemented if (backend_sampling) { - llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get())); + llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get())); } else { - llama_set_sampler(ctx, slot.id, nullptr); + llama_set_sampler(ctx_tgt, slot.id, nullptr); } SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); @@ -1512,13 +1457,13 @@ private: result.probs.push_back({ cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), + common_token_to_piece(ctx_tgt, cur_p->data[i].id, special), cur_p->data[i].p }); } } else { // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); + std::vector cur = get_token_probabilities(ctx_tgt, idx); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request); @@ -1536,7 +1481,7 @@ private: for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), + common_token_to_piece(ctx_tgt, cur[i].id, special), cur[i].p }); } @@ -1639,7 +1584,7 @@ private: res->tokens = std::move(slot.generated_tokens); } res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); + res->prompt = slot.task->tokens.detokenize(ctx_tgt, true); res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; @@ -1662,7 +1607,7 @@ private: // populate res.probs_output if (slot.task->params.sampling.n_probs > 0) { if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); res->probs_output = std::vector( @@ -1687,7 +1632,7 @@ private: res->n_tokens = slot.task->n_tokens(); res->res_type = slot.task->params.res_type; - const int n_embd_out = llama_model_n_embd_out(model); + const int n_embd_out = llama_model_n_embd_out(model_tgt); std::vector embd_res(n_embd_out, 0.0f); @@ -1697,10 +1642,10 @@ private: } const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); + if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(slot.ctx_tgt, i); } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]); } if (embd == nullptr) { @@ -1711,7 +1656,7 @@ private: } // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) { common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; @@ -1736,9 +1681,9 @@ private: continue; } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]); if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); + embd = llama_get_embeddings_ith(ctx_tgt, i); } if (embd == NULL) { @@ -1843,18 +1788,22 @@ private: const auto & cur = slot.prompt.checkpoints.front(); SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } auto & cur = slot.prompt.checkpoints.emplace_back(); - server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max); + + cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); + + cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); } void process_single_task(server_task && task) { @@ -2009,7 +1958,7 @@ private: std::string filepath = task.slot_action.filepath; const llama_tokens & tokens = slot->prompt.tokens.get_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2048,7 +1997,7 @@ private: llama_tokens tokens; tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { slot->prompt.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); @@ -2207,8 +2156,13 @@ private: SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); + } // add generated tokens to cache // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 @@ -2236,12 +2190,10 @@ private: // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + std::vector generating; + std::vector drafting; - // first, add sampled tokens from any ongoing sequences + // determine which slots are generating and drafting for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; @@ -2254,12 +2206,103 @@ private: continue; } + generating.push_back(&slot); + + if (spec) { + common_speculative_get_draft_params(spec.get(), slot.id).drafting = false; + + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + const int n_draft_max = slot.get_n_draft_max(); + + if (n_draft_max > 0) { + GGML_ASSERT(slot.can_speculate()); + + if (!slot.spec_draft.empty()) { + // we have a previous (partial) draft to reuse + if (use_ckpt_tgt) { + GGML_ASSERT(!slot.spec_ckpt.empty()); + } + } else { + GGML_ASSERT(slot.spec_i_batch.empty()); + + slot.spec_ckpt.update_pos( + slot.prompt.n_tokens(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); + + if (use_ckpt_dft) { + slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + + slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); + + common_speculative_get_draft_params(spec.get(), slot.id) = { + /* .drafting = */ true, + /* .n_max = */ n_draft_max, + /* .n_past = */ slot.prompt.n_tokens(), + /* .id_last = */ slot.sampled, + /* .prompt = */ &slot.spec_prompt, + /* .result = */ &slot.spec_draft, + }; + + drafting.push_back(&slot); + } + } + } + } + + // generate the actual drafts (if any) + { + common_speculative_draft(spec.get()); + } + + // make checkpoints if needed + for (auto * slot_ptr : drafting) { + auto & slot = *slot_ptr; + + auto & draft = slot.spec_draft; + auto & ckpt = slot.spec_ckpt; + + slot.n_draft_total += draft.size(); + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + if (ctx_dft) { + ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1); + } + + if (!draft.empty()) { + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + if (use_ckpt_tgt) { + //const int64_t t_start = ggml_time_us(); + + ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + //const int64_t t_total = ggml_time_us() - t_start; + //printf("checkpoint total: %f ms\n", t_total / 1000.0); + + SLT_DBG(slot, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n", + ckpt.pos_min, ckpt.pos_max, slot.prompt.n_tokens(), + (float) ckpt.size() / 1024 / 1024, + (float) ckpt.data_dft.size() / 1024 / 1024); + } + } + } + + // update the batch with the sampled/drafted tokens + for (auto * slot_ptr : generating) { + auto & slot = *slot_ptr; + slot.update_batch(batch); } // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); + int32_t n_batch = llama_n_batch(ctx_tgt); + int32_t n_ubatch = llama_n_ubatch(ctx_tgt); float alora_scale = -1.0f; size_t alora_disabled_id = 0; @@ -2303,12 +2346,12 @@ private: /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } } else { // all for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } }*/ @@ -2327,7 +2370,7 @@ private: } // TODO: support memory-less logits computation - if (slot.task->need_logits() && !llama_get_memory(ctx)) { + if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); continue; @@ -2379,7 +2422,7 @@ private: const auto n_cache_reuse = slot.task->params.n_cache_reuse; const bool can_cache_reuse = - llama_memory_can_shift(llama_get_memory(ctx)) && + llama_memory_can_shift(llama_get_memory(ctx_tgt)) && !slot.prompt.tokens.has_mtmd; if (!can_cache_reuse && n_cache_reuse > 0) { @@ -2413,13 +2456,18 @@ private: if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str()); //} const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift); + } for (size_t i = 0; i < n_match; i++) { slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); @@ -2446,7 +2494,7 @@ private: const auto pos_min_thold = std::max(0, pos_next - n_swa); if (n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); @@ -2475,14 +2523,14 @@ private: { const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss0 << piece; st0 << std::setw(8) << token; } { const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss1 << piece; st1 << std::setw(8) << token; } @@ -2514,18 +2562,13 @@ private: if (!do_reset) { // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); - n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024); - } + it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); + n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); } if (do_reset) { @@ -2542,7 +2585,7 @@ private: for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; if (cur.pos_max > pos_next) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; @@ -2582,14 +2625,18 @@ private: SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); slot.prompt_clear(true); // there is no common part left slot.n_prompt_tokens_cache = 0; - } + } else { + if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { + GGML_ABORT("failed to truncate draft context\n"); + } + } // If using an alora, there may be uncached tokens that come // before the invocation sequence. When this happens, the @@ -2615,7 +2662,7 @@ private: // - the model does not support partial sequence removal // - the model uses SWA (and we are not using `swa_full`) do_checkpoint = do_checkpoint && ( - (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || (n_swa > 0)); bool has_mtmd = false; @@ -2624,7 +2671,7 @@ private: while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); @@ -2632,6 +2679,16 @@ private: continue; } + if (ctx_dft) { + // TODO: in the future, figure out how to infuse target embeddings to the images + // for now, we skip this for simplicity + // maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above? + res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + if (res != 0) { + GGML_ABORT("failed to process multi-modal data on draft context\n"); + } + } + slot.n_prompt_tokens_processed += n_tokens_out; // add the image chunk to cache @@ -2733,8 +2790,8 @@ private: SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); // no need for empty or small checkpoints do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); @@ -2765,9 +2822,14 @@ private: SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + if (slot_batched) { // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); + common_set_adapter_lora(ctx_tgt, slot_batched->lora); // if the lora is temporarily disabled for an alora, re-enable it // for next time @@ -2776,7 +2838,7 @@ private: slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); } if (batch.n_tokens == 0) { @@ -2805,7 +2867,7 @@ private: batch.logits + i, }; - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode(ctx_tgt, batch_view); metrics.on_decoded(slots); @@ -2858,11 +2920,63 @@ private: continue; // continue loop of n_batch } + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + // + // | spec type | need re-eval | + // | --- | --- | + // | draft model | no | because the draft model does not use embeddings from the target + // | MTP (std) | yes | + // | MTP Gemma4 | no | because the KV cache is shared + // | Eagle3 | yes | + // | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982 + // + // note: this logic is now moved in `common_speculative_process()` + // keeping the sketch here until for a bit, until the logic is finalized + // + //if (ctx_dft) { + // // TODO: update as needed for MTP, Eagle3, etc. + // const bool need_tgt_embd = false; + + // if (need_tgt_embd) { + // llama_synchronize(ctx_tgt); + // } + + // // the logic here varies depending on the speculative decoding method + // // - some draft contexts require embeddings from the target context, others don't + // // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings + // // TODO: extract this in a function ? + // { + // // TODO: hook the embeddings from the last target batch here + // if (llama_model_has_encoder(model_dft.get())) { + // //llama_encode(ctx_dft, ...); + + // GGML_ABORT("not implemented yet\n"); + // } + + // const int ret = llama_decode(ctx_dft.get(), batch_view); + + // if (ret != 0) { + // SRV_ERR("failed to decode draft batch, ret = %d\n", ret); + + // // TODO: handle error + // break; + // } + // } + //} + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + break; + } + // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); + n_batch = llama_n_batch(ctx_tgt); // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { @@ -2921,7 +3035,7 @@ private: slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -2933,7 +3047,7 @@ private: const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); slot.i_batch = -1; @@ -2954,7 +3068,7 @@ private: completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.task->params.sampling.n_probs > 0) { @@ -2985,23 +3099,23 @@ private: // verify and try to accept the draft { - const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); slot.spec_i_batch.clear(); GGML_ASSERT(accepted.size() >= 1); // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt) { + if (use_ckpt_tgt) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); } @@ -3011,16 +3125,19 @@ private: const auto & ckpt = slot.spec_ckpt; - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", - ckpt.pos_min, ckpt.pos_max, ckpt.size()); + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n); + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1); } - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1); + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1); + } slot.prompt.tokens.keep_first(ckpt.n_tokens); slot.smpl = std::move(smpl_save); @@ -3033,7 +3150,7 @@ private: SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); } - common_speculative_accept(slot.spec.get(), accepted.size() - 1); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); slot.spec_draft = std::move(accepted); } @@ -3055,13 +3172,16 @@ private: slot.sampled = ids.back(); // last accepted token SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1); + } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3113,7 +3233,7 @@ void server_context::terminate() { } llama_context * server_context::get_llama_context() const { - return impl->ctx; + return impl->ctx_tgt; } server_response_reader server_context::get_response_reader() { @@ -3123,8 +3243,8 @@ server_response_reader server_context::get_response_reader() { server_context_meta server_context::get_meta() const { auto bos_id = llama_vocab_bos(impl->vocab); auto eos_id = llama_vocab_eos(impl->vocab); - auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : ""; - auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : ""; + auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : ""; + auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : ""; return server_context_meta { /* build_info */ std::string(llama_build_info()), @@ -3137,7 +3257,7 @@ server_context_meta server_context::get_meta() const { /* has_inp_audio */ impl->chat_params.allow_audio, /* json_webui_settings */ impl->json_webui_settings, /* slot_n_ctx */ impl->get_slot_n_ctx(), - /* pooling_type */ llama_pooling_type(impl->ctx), + /* pooling_type */ llama_pooling_type(impl->ctx_tgt), /* chat_params */ impl->chat_params, /* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()), @@ -3155,10 +3275,10 @@ server_context_meta server_context::get_meta() const { /* model_vocab_type */ llama_vocab_type(impl->vocab), /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), - /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model), - /* model_n_embd_inp */ llama_model_n_embd(impl->model), - /* model_n_params */ llama_model_n_params(impl->model), - /* model_size */ llama_model_size(impl->model), + /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt), + /* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt), + /* model_n_params */ llama_model_n_params(impl->model_tgt), + /* model_size */ llama_model_size(impl->model_tgt), }; } @@ -3502,10 +3622,6 @@ void server_routes::init_routes() { {"name", "n_tokens_max"}, {"help", "Largest observed n_tokens."}, {"value", res_task->n_tokens_max} - }, { - {"name", "n_busy_slots_per_decode"}, - {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, @@ -3523,6 +3639,10 @@ void server_routes::init_routes() { {"name", "requests_deferred"}, {"help", "Number of requests deferred."}, {"value", (uint64_t) res_task->n_tasks_deferred} + },{ + {"name", "n_busy_slots_per_decode"}, + {"help", "Average number of busy slots per llama_decode() call"}, + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}} }; @@ -4045,7 +4165,7 @@ void server_routes::init_routes() { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = rd.get_new_id(); task.tokens = std::move(tmp); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 45e5168fa..b9b4a704a 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl( params.speculative = defaults.speculative; + // TODO: to keep things simple, we disable speculative parameter adjustments for now +#if 0 // TODO: for now, be able to adjust only the draft-model based speculative parameters params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min); params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max); @@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl( params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0); params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0); -#if 0 // for debugging and research purposes params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type))); @@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const { return res; } -server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) { +server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) { // first check if the current state is contained fully in the cache for (auto it = states.begin(); it != states.end(); ++it) { const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); @@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t } } - std::vector state_data; + std::vector state_data_tgt; + std::vector state_data_dft; // check if we can allocate enough memory for the new state try { - state_data.resize(state_size); + state_data_tgt.resize(state_size_tgt); + state_data_dft.resize(state_size_dft); } catch (const std::bad_alloc & e) { SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); @@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t return nullptr; } - auto & cur = states.emplace_back(); - cur = { + states.push_back({ /*.tokens =*/ prompt.tokens.clone(), - /*.data =*/ std::move(state_data), + /*.data =*/ { + /*.main =*/ std::move(state_data_tgt), + /*.drft =*/ std::move(state_data_dft), + }, /*.checkpoints =*/ prompt.checkpoints, - }; + }); - return &cur; + return &states.back(); } -bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { +bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) { const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins @@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok if (it_best != states.end()) { SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); - if (n != size) { - SRV_WRN("failed to restore state with size %zu\n", size); + { + auto & data = it_best->data.main; - return false; + const size_t size = data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + data.clear(); + data.shrink_to_fit(); } - it_best->data.clear(); - it_best->data.shrink_to_fit(); + { + auto & data = it_best->data.drft; + + if (!data.empty()) { + GGML_ASSERT(ctx_dft); + + const size_t size = data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + data.clear(); + data.shrink_to_fit(); + } + } prompt = std::move(*it_best); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 289e1fb8d..64bdecd79 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result { virtual json to_json() override; }; -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - int64_t n_tokens; - - std::vector data; +struct server_prompt_data { + std::vector main; + std::vector drft; size_t size() const { - return data.size(); - } - - bool empty() const { - return data.empty(); - } - - void clear() { - pos_min = 0; - pos_max = 0; - n_tokens = 0; - data.clear(); + return main.size() + drft.size(); } }; struct server_prompt { server_tokens tokens; - std::vector data; + server_prompt_data data; - std::list checkpoints; + std::list checkpoints; size_t size() const { - size_t res = data.size(); + size_t res = 0; - for (const auto & checkpoint : checkpoints) { - res += checkpoint.size(); + res += data.size(); + + for (const auto & ckpt : checkpoints) { + res += ckpt.size(); } return res; @@ -614,7 +601,7 @@ struct server_prompt { return server_prompt { tokens.clone(), data, - checkpoints + checkpoints, }; } }; @@ -637,9 +624,9 @@ struct server_prompt_cache { size_t n_tokens() const; - server_prompt * alloc(const server_prompt & prompt, size_t state_size); + server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft); - bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot); + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot); void update(); }; diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py index eebd3cc8f..84cd77e6f 100644 --- a/tools/server/tests/unit/test_speculative.py +++ b/tools/server/tests/unit/test_speculative.py @@ -5,7 +5,7 @@ from utils import * server = ServerPreset.stories15m_moe() -MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" +MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf" def create_server(): global server