From f5636f8fc77e403596a8524d789c4a16350837ac Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 12:07:17 +0200 Subject: [PATCH 01/17] convert : add image break token fallback (#22914) * convert : add image break token fallback This commit adds a image_break_token_id fallback for mistral where the config contains a image_break_token_id of -1: ```console "vision_encoder": { "image_token_id": 10, "image_break_token_id": -1, ... ``` But the tokenizer.json has this token: ```console 115 "id": 12, 116 "content": "[IMG_BREAK]", 117 "single_word": false, 118 "lstrip": false, 119 "rstrip": false, 120 "normalized": false, 121 "special": true 122 }, ``` If we look in convert_hf_to_gguf.py we have: ```python elif self.is_mistral_format: # hparams is already vision config here so norm_eps is only defined in global_config. self.hparams["norm_eps"] = self.global_config.get("norm_eps", None) assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" if self.use_break_tok: self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) ``` The motivation for this is that currently converting this models results in the following error: ```console load_hparams: model size: 5131.60 MiB load_hparams: metadata size: 0.15 MiB clip_init: failed to load model 'models/mmproj-Mistral-Medium-3.5-128B.gguf': operator(): unable to find tensor v.token_embd.img_break mtmd_init_from_file: error: Failed to load CLIP model from models/mmproj-Mistral-Medium-3.5-128B.gguf Failed to load vision model from models/mmproj-Mistral-Medium-3.5-128B.gguf ``` With this fallback the model loads successfully. Resolves: https://github.com/ggml-org/llama.cpp/issues/22901 * Revert "convert : add image break token fallback" This reverts commit 292e40cfdf9a7553863007c018236f5f554f71d8. * convert : add image break token fallback This commit adds a image_break_token_id fallback for mistral where the config contains a image_break_token_id of -1: ```console "vision_encoder": { "image_token_id": 10, "image_break_token_id": -1, ... ``` But the tokenizer.json has this token: ```console 115 "id": 12, 116 "content": "[IMG_BREAK]", 117 "single_word": false, 118 "lstrip": false, 119 "rstrip": false, 120 "normalized": false, 121 "special": true 122 }, ``` If we look in convert_hf_to_gguf.py we have: ```python elif self.is_mistral_format: # hparams is already vision config here so norm_eps is only defined in global_config. self.hparams["norm_eps"] = self.global_config.get("norm_eps", None) assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" if self.use_break_tok: self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) ``` The motivation for this is that currently converting this models results in the following error: ```console load_hparams: model size: 5131.60 MiB load_hparams: metadata size: 0.15 MiB clip_init: failed to load model 'models/mmproj-Mistral-Medium-3.5-128B.gguf': operator(): unable to find tensor v.token_embd.img_break mtmd_init_from_file: error: Failed to load CLIP model from models/mmproj-Mistral-Medium-3.5-128B.gguf Failed to load vision model from models/mmproj-Mistral-Medium-3.5-128B.gguf ``` With this fallback the model loads successfully. Co-authored-by: Pascal Resolves: https://github.com/ggml-org/llama.cpp/issues/22901 * convert : allow zero value for img_break_tok_id --- convert_hf_to_gguf.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e5dea18ae..bf76fa406 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2176,7 +2176,8 @@ class MmprojModel(ModelBase): text_config = { k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"] } - self.n_embd_text = text_config.get("hidden_dim", 0) + # mistral native params.json: "dim" is the text hidden size ("hidden_dim" is the FFN intermediate size) + self.n_embd_text = text_config.get("dim", 0) assert self.n_embd_text > 0, "n_embd not found in hparams" @@ -3137,6 +3138,11 @@ class LlavaVisionModel(MmprojModel): assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" if self.use_break_tok: self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) + + # params.json may ship -1 placeholders (Mistral Medium 3.5) + # resolve the real id from the bundled tokenizer in that case + if self.img_break_tok_id < 0: + self.img_break_tok_id = self.get_mistral_token_id("[IMG_BREAK]") else: raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") logger.info(f"Image break token id: {self.img_break_tok_id}") @@ -3156,6 +3162,24 @@ class LlavaVisionModel(MmprojModel): return int(token_data["id"]) raise ValueError(f"Token '{token}' not found in tokenizer config.") + def get_mistral_token_id(self, token: str) -> int: + # mistral native format ships tekken.json or a versioned spm tokenizer + tekken_file = self.dir_model / "tekken.json" + if tekken_file.is_file(): + with open(tekken_file, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data.get("special_tokens", []): + if entry.get("token_str") == token: + return int(entry["rank"]) + tokenizer_json_file = self.dir_model / "tokenizer.json" + if tokenizer_json_file.is_file(): + with open(tokenizer_json_file, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data.get("added_tokens", []): + if entry.get("content") == token: + return int(entry["id"]) + raise ValueError(f"Token '{token}' not found in mistral tokenizer files.") + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams From 8cef8201a1e0213662abbfcbcd3ff2eb773174df Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 11 May 2026 12:16:38 +0200 Subject: [PATCH 02/17] CUDA: directly include cuda/iterator (#22936) Before, we relied on a transient import from `cub/cub.cuh`, which is bad practice to do as cub may not always expose cuda/iterator --- ggml/src/ggml-cuda/argsort.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 0f3f017b5..c4f08091e 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -4,6 +4,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) # define STRIDED_ITERATOR_AVAILABLE +# include # endif using namespace cub; #endif // GGML_CUDA_USE_CUB From dd9280a6643d2c4931df7c9246b2f344c0a0513a Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 11 May 2026 05:49:03 -0500 Subject: [PATCH 03/17] vulkan: Support asymmetric FA in scalar/mmq/coopmat1 paths (#22589) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 255 +++++++-------- .../vulkan-shaders/flash_attn.comp | 176 +++++----- .../vulkan-shaders/flash_attn_base.glsl | 206 +++--------- .../vulkan-shaders/flash_attn_cm1.comp | 154 ++++----- .../vulkan-shaders/flash_attn_cm2.comp | 43 +-- .../vulkan-shaders/flash_attn_dequant.glsl | 123 +++++++ .../vulkan-shaders/flash_attn_mmq_funcs.glsl | 304 +++++++++++------- .../vulkan-shaders/mul_mmq_shmem_types.glsl | 11 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 40 +-- 9 files changed, 632 insertions(+), 680 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0a7931002..7e450a559 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -855,7 +855,7 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; - std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + std::map pipeline_flash_attn_f32_f16; std::map, vk_pipeline> pipeline_fa_mask_opt; @@ -2933,10 +2933,10 @@ struct vk_fa_tuning_params { } }; -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type); +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); -static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { vk_fa_tuning_params result{}; result.path = FA_SCALAR; @@ -2988,7 +2988,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; - if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) { + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) { result.block_rows /= 2; } @@ -3011,10 +3011,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_rows); GGML_UNUSED(n_kv); - GGML_UNUSED(kv_type); + GGML_UNUSED(k_type); + GGML_UNUSED(v_type); GGML_UNUSED(f32acc); vk_fa_tuning_params result{}; @@ -3070,12 +3071,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device } static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { - // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. - if (k_type != v_type) { - GGML_ASSERT(device->coopmat2); - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - } - FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -3087,7 +3082,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -3107,9 +3102,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT2: return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: @@ -3279,6 +3274,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev return 0; // If no matching configuration is found } +// Whether scalar flash attention will use the MMQ path for the given k_type. +static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) { +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + return device->integer_dot_product && device->subgroup_clustered && + (k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 || + k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 || + k_type == GGML_TYPE_Q8_0); +#else + GGML_UNUSED(device); + GGML_UNUSED(k_type); + return false; +#endif +} + static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); @@ -3525,121 +3534,96 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; -#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - uint32_t fa_sgs = fa.first.subgroup_size; \ - bool fa_ds = fa.first.subgroup_size == 0; \ - if (path == FAPATH) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } \ - } \ - } + // FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V + // quant type is selected at runtime via the FaTypeK / FaTypeV spec constants. - if (device->fp16) { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_SCALAR) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); + const void * spv_data = nullptr; + size_t spv_size = 0; + if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8) - } else + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } + else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_int8_data; + spv_size = flash_attn_f32_f16_fp32_int8_len; + } #endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) - } - } else { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) - -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8) - } else -#endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + } else { + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + } } + const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); } + #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) - } -#endif -#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -#define CREATE_FA_CM2_MIXED() \ - for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - if (path == FA_COOPMAT2) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } \ - } \ - } \ - } \ + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT1) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const void * spv_data; + size_t spv_size; + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); + } + } +#endif + +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT2) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + + const void * spv_data; + size_t spv_size; + const char * name; + if (aligned) { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; } + } + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0); } - if (device->coopmat2) { - CREATE_FA_CM2_MIXED(); } -#undef CREATE_FA_CM2_MIXED #endif -#undef CREATE_FA const int mul_mat_id_param_count = 5; @@ -8940,8 +8924,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) { GGML_UNUSED(f32acc); + GGML_UNUSED(v_type); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; const uint32_t Br = params.block_rows; @@ -8949,10 +8934,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); - const bool mmq = device->integer_dot_product && device->subgroup_clustered && - (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 || - kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 || - kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL); + const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); @@ -8969,17 +8951,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // kvsh uses D = HSV (K goes through kblocksh instead) kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - // block_a_cache size depends on quant type - uint32_t block_a_size; - switch (kv_type) { - case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q8_0: - case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break; - default: block_a_size = 0; break; - } + // The mixed MMQ shader uses a superset block_a_cache that fits every + // FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm. + // Single-scale types leave dm.y unused; non-Q5_* leave qh unused. + const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size; kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; } else { Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; @@ -9117,10 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); - if (tuning_params.path != FA_COOPMAT2) { - GGML_ASSERT(k->type == v->type); - } - const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); @@ -9164,7 +9135,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { std::lock_guard guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { pipeline = it->second; @@ -15642,10 +15613,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // mismatching K/V type is currently supported for coopmat2 only. - if (op->src[1]->type != op->src[2]->type && !coopmat2) { - return false; - } auto fa_kv_ok = [coopmat2](ggml_type t) { switch (t) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6e6bdabc9..6ac095489 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -22,6 +22,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; @@ -128,18 +129,20 @@ void main() { Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); -#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); - } -#else // Q4_0, Q4_1, Q5_0, Q5_1 - const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; - const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + // Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores + // the row-sum scaled by qd, used in k_dot_correction. + if (FaTypeK == FA_TYPE_Q8_0) { + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } + } else { + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } } -#endif #endif } barrier(); @@ -177,13 +180,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -257,21 +256,21 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; } } #else // MMQ - const uint ints_per_block = 8 / QUANT_R_MMQ; + const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK); const uint quant_iters = Bc * HSK / 32 * ints_per_block; [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { const uint32_t iqs = (idx + tid) % ints_per_block; @@ -310,15 +309,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); @@ -335,15 +332,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); @@ -366,72 +361,47 @@ void main() { int32_t k_quants[d_per_step]; ACC_TYPEV2 k_dm; + // Q4_*/Q5_* take the block-8 fast path when one step covers a full + // block; Q8_0 always goes through the per-int get_k_qs* helpers + // (its qs is byte-packed, not nibble-packed). + const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0); + if (SHMEM_STAGING != 0) { const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0); -#else k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { + if (block8_fast) { + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); [[unroll]] for (uint32_t d = 0; d < 4; d++) { uint vui = kblocksh[buf_ib].qs[d]; k_quants[d ] = int32_t( vui & 0x0F0F0F0F); k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; - uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + if (has_qh) { + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); } } } else { - const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block); - const uint ib = coord / BLOCK_SIZE; - const uint iqs = (coord % BLOCK_SIZE); + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE_K; + const uint iqs = (coord % BLOCK_SIZE_K); -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0); -#else - k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset)); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { -#if defined(DATA_A_Q5_0) - uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0], - k_packed.k_data_packed16[k_offset + ib].qh[1])); -#elif defined(DATA_A_Q5_1) - uint qh = k_packed.k_data_packed16[k_offset + ib].qh; -#endif - [[unroll]] for (uint32_t d = 0; d < 4; d++) { -#if defined(A_TYPE_PACKED32) - uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d]; -#else - uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], - k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); -#endif - k_quants[d ] = int32_t( vui & 0x0F0F0F0F); - k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (qh >> (d * 4)) & 0xF; - uint qh_hi = (qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset)); + + if (block8_fast) { + fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset); + [[unroll]] for (uint32_t d = 0; d < 8; d++) { + k_quants[d] = blk.qs[d]; } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); } @@ -516,14 +486,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -547,15 +517,13 @@ void main() { FLOAT_TYPEV4 Vf; if (SHMEM_STAGING != 0) { Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_V) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else + } else { Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index efed3a73e..9a7957da9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -#if defined(DATA_A_F32) -layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; -#elif defined(A_TYPE_PACKED16) -layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; -#endif -#if defined(A_TYPE_PACKED32) -layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32; -layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32; -#endif +// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the +// host can pass the type directly. Keep in sync with ggml.h. +#define FA_TYPE_F32 0u +#define FA_TYPE_F16 1u +#define FA_TYPE_Q4_0 2u +#define FA_TYPE_Q4_1 3u +#define FA_TYPE_Q5_0 6u +#define FA_TYPE_Q5_1 7u +#define FA_TYPE_Q8_0 8u +#define FA_TYPE_Q1_0 41u -#ifndef BLOCK_SIZE -#define BLOCK_SIZE 1 -#endif - -#if defined(DATA_A_F32) -#undef BLOCK_SIZE -#define BLOCK_SIZE 4 -#define BLOCK_BYTE_SIZE 16 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - // iqs is currently always zero in the flash attention shaders - if (binding_idx == BINDING_IDX_K) { - return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); - } else { - return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); +// Number of matrix elements per buffer block, derived from the K/V type spec +// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 +// and bypasses the dequant path entirely. Quants follow their ggml block sizes. +uint fa_block_elems(uint ty) { + switch (ty) { + case FA_TYPE_F32: return 4u; + case FA_TYPE_F16: return 1u; + case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere + default: return 1u; } } -#endif -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 -#elif defined(DATA_A_Q4_1) -#define BLOCK_BYTE_SIZE 20 -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif +// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte +// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number +// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R. +uint fa_quant_r_mmq(uint ty) { + switch (ty) { + case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0); + default: return 1u; } } -#endif -#if defined(DATA_A_Q5_0) -#define BLOCK_BYTE_SIZE 22 -#elif defined(DATA_A_Q5_1) -#define BLOCK_BYTE_SIZE 24 -#endif - -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = v_packed.v_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } -} -#endif - - -#if defined(DATA_A_IQ4_NL) -#define BLOCK_BYTE_SIZE 18 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); - } -} -#endif -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); - } else { - const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); - } -} -#endif +// These can't be `const` globals because GLSL forbids function calls in global +// const initializers, even when the spec constants would let the driver fold +// them. Macros expand at the use site and fold after specialization. +#define BLOCK_SIZE_K fa_block_elems(FaTypeK) +#define BLOCK_SIZE_V fa_block_elems(FaTypeV) +// F16 reads f16 elements directly from the binding; everything else routes +// through dequantize4 / the MMQ helpers to unpack from the packed block layout. +#define USE_DECODE_K (FaTypeK != FA_TYPE_F16) +#define USE_DECODE_V (FaTypeV != FA_TYPE_F16) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 526e8da38..bffcc095b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -14,6 +14,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -127,13 +128,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -227,14 +224,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { f16vec4 K_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; @@ -256,47 +253,40 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If K is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) { -#endif - barrier(); - [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { - uint32_t col_vec = (idx + tid) % (MatBr / 4); - uint32_t row = (idx + tid) / (MatBr / 4); - if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); - if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); -#endif + // For quants we always need to dequant into kvsh; for f16 we can load + // directly from global memory when alignment / bounds allow it. + const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; + if (stage_k) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + f16vec4 K_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { + if (USE_DECODE_K) { + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } + } + + kvsh[row * kvsh_stride + col_vec] = K_Tf; } - - kvsh[row * kvsh_stride + col_vec] = K_Tf; } + barrier(); } - barrier(); -#if BLOCK_SIZE == 1 - } -#endif -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) -#endif - { + if (stage_k) { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); - } -#if BLOCK_SIZE == 1 - else { + } else { const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); } -#endif } else { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); @@ -397,14 +387,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { f16vec4 V_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -431,36 +421,33 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If V is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - // For f16, only preload if not aligned - if (KV_bounds_check) { -#endif - [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { - const uint idx = i * gl_WorkGroupSize.x + tid; - const uint row = idx / v_cols; - const uint col = idx % v_cols; + // For quants we always preload via kvsh. For f16 we only preload when + // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). + const bool stage_v = USE_DECODE_V || KV_bounds_check; + if (stage_v) { + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; - const uint v_row = j * Bc + row; - const uint v_col = hsv_tile * MatBc * row_split + col * 4; + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; - const uint coord = v_row * v_stride * BLOCK_SIZE + v_col; - const uint ib = coord / BLOCK_SIZE; - const uint iqs = coord % BLOCK_SIZE; + const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col; + const uint ib = coord / BLOCK_SIZE_V; + const uint iqs = coord % BLOCK_SIZE_V; - if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { -#if BLOCK_SIZE > 1 - kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; -#endif - } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { + if (USE_DECODE_V) { + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + } + } else { + kvsh[row * vsh_stride + col] = f16vec4(0.0f); + } } } - -#if BLOCK_SIZE == 1 - } -#endif } barrier(); @@ -471,15 +458,12 @@ void main() { coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (!KV_bounds_check) { + if (!USE_DECODE_V && !KV_bounds_check) { // F16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); - } else -#endif - { + } else { const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 8a7bbaeb9..141bb8708 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; -uint fa_block_elems(uint ty) { - switch (ty) { - case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) - case 1u: return 1u; // GGML_TYPE_F16 - case 2u: return uint(QUANT_K_Q4_0); - case 3u: return uint(QUANT_K_Q4_1); - case 6u: return uint(QUANT_K_Q5_0); - case 7u: return uint(QUANT_K_Q5_1); - case 8u: return uint(QUANT_K_Q8_0); - case 41u: return uint(QUANT_K_Q1_0); - default: - return 1u; - } -} - float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeV) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl new file mode 100644 index 000000000..02106f33c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -0,0 +1,123 @@ +// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V) +// covering every supported FA element type, plus an uber dequantize4() that +// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver +// folds away every path except the one matching the K/V type for this pipeline. +// +// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by +// flash_attn_cm2.comp, which has its own buffer_reference-based decode path. +// +// We use macros (rather than per-quant decode functions taking a struct) on +// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16 +// when FLOAT16 isn't defined, which makes float16-containing struct values +// illegal to return from / pass to functions. Macros expand inline where the +// float16 stays in storage and is converted to FLOAT_TYPE at use. + +// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl +// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32. +layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32; +layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32; + +layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0; +layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0; +layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1; +layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1; +layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0; +layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0; +layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1; +layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1; +layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; +layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; + +// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 +// views, used by the MMQ K-side hot path for fast 4-uint loads. +layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; +layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32; + +// Per-quant decode bodies are expanded once for the K view set and once for +// the V view set. The macros take the buffer name as a parameter. +#define FA_DEQUANT4_F32(BUF) \ + return FLOAT_TYPEV4(BUF.data[a_offset + ib]); + +#define FA_DEQUANT4_Q4_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \ +} + +#define FA_DEQUANT4_Q4_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q5_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = uint(BUF.data[a_offset + ib].qh[0]) \ + | (uint(BUF.data[a_offset + ib].qh[1]) << 16); \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \ +} + +#define FA_DEQUANT4_Q5_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = BUF.data[a_offset + ib].qh; \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q8_0(BUF) { \ + const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \ + const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ +} + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + switch (FaTypeK) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + } + } else { + switch (FaTypeV) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + } + } + return FLOAT_TYPEV4(0); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl index e14e62d54..6bf10a7cf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -1,149 +1,203 @@ -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and +// reads from the matching aliased K binding declared in flash_attn_dequant.glsl. +// Spec-constant specialization folds the unused paths. + int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q4_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - return int32_t(vui & 0x0F0F0F0F); + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q4_1: { // uses packed32 alias + uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q5_0: { + uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16 + uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed_q5_1.data[a_offset + ib].qh; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2], + k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1])); + } + default: return 0; + } } -#endif -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q5_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0], - k_packed.k_data_packed16[a_offset + ib].qh[1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - uint qh_bits = (qh >> iqs) & 0xF; - return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0) +// return (d, 0) so call sites always see the same shape. +FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0); + default: return FLOAT_TYPEV2(0); + } } -#endif - -#if defined(DATA_A_Q8_0) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])); -} -#endif - -#if defined(DATA_A_IQ4_NL) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - u8vec4 idx = unpack8(vui & 0x0F0F0F0F); - return pack32(i8vec4(kvalues_iq4nl_const[idx.x], - kvalues_iq4nl_const[idx.y], - kvalues_iq4nl_const[idx.z], - kvalues_iq4nl_const[idx.w])); -} -#endif - -#if QUANT_AUXF == 1 -FLOAT_TYPE get_k_d(uint ib, uint a_offset) { - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d); -} -#else -FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) { - return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm); -} -#endif void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { -#if defined(DATA_A_Q4_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_Q4_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; -#elif defined(DATA_A_Q5_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - if (iqs == 0) { - kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0], - k_packed.k_data_packed16[a_offset + global_ib].qh[1])); + // kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need + // explicit casts. The bit pattern is what we care about here -- the actual + // signed/unsigned interpretation happens downstream in the dot product. + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + break; + } + case FA_TYPE_Q4_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]); + break; + } + case FA_TYPE_Q5_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0], + k_packed_q5_0.data[a_offset + global_ib].qh[1])); + } + break; + } + case FA_TYPE_Q5_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]); + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh; + } + break; + } + case FA_TYPE_Q8_0: { + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1])); + break; + } } -#elif defined(DATA_A_Q5_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; - if (iqs == 0) { - kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh; - } -#elif defined(DATA_A_Q8_0) - kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_IQ4_NL) - const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); - const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); - kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y], - kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w])); - kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y], - kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w])); -#endif if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d); -#else - kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm); -#endif + // Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair. + switch (FaTypeK) { + case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break; + } } } +// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed +// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads +// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble +// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned). +// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't +// share this nibble shape. +// +// Returned via a struct so the caller's k_quants array (sized from spec +// constants) doesn't need to match a fixed[8] out-parameter type. +struct fa_k_qs_block8 { + int32_t qs[8]; +}; + +fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) { + fa_k_qs_block8 r; + uint qh = 0; + if (FaTypeK == FA_TYPE_Q5_0) { + qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + } else if (FaTypeK == FA_TYPE_Q5_1) { + qh = k_packed_q5_1.data[a_offset + ib].qh; + } + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = 0; + switch (FaTypeK) { + case FA_TYPE_Q4_0: { // packed16 + vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q4_1: { // packed32 alias + vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d]; + break; + } + case FA_TYPE_Q5_0: { // packed16 + vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q5_1: { // packed32 alias + vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d]; + break; + } + } + r.qs[d ] = int32_t( vui & 0x0F0F0F0F); + r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); + if (has_qh) { + uint qh_lo = (qh >> (d * 4)) & 0xFu; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu; + r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } + } + return r; +} + int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); - uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF; - return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); -#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - return kblocksh[buf_ib].qs[pos]; -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: + case FA_TYPE_Q4_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + } + case FA_TYPE_Q5_0: + case FA_TYPE_Q5_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return kblocksh[buf_ib].qs[pos]; + } + default: return 0; + } } ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { -#if defined(DATA_A_Q4_0) - return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q5_0) - return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) - return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; -#else - return ACC_TYPE(0.0); -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q4_1: + case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; + default: return ACC_TYPE(0.0); + } } void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { kblocksh[buf_ib].qs[iqs] = 0; -#if defined(DATA_A_IQ4_NL) - kblocksh[buf_ib].qs[iqs + 4] = 0; -#endif if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f); -#else kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); -#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 10552d013..79c933f40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -1,4 +1,13 @@ -#if defined(DATA_A_Q4_0) +#if defined(FA_MMQ_MIXED) +// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0. +// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale +// types (Q4_0/Q5_0/Q8_0) leave dm.y unused. +struct block_a_cache { + int32_t qs[8]; + uint32_t qh; + FLOAT_TYPEV2 dm; +}; +#elif defined(DATA_A_Q4_0) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 6f2a929c4..d99b2b5d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -643,42 +643,22 @@ void process_shaders() { if (fp16) { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); #endif - } - for (const auto& tname : type_names) { - if (tname == "bf16") continue; - - if (fp16) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } + string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); #endif - } - - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (tname != "f32") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); - } -#endif - } } + + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); +#endif } } From 7dbb0e998a125973091914aec2928a5104b36725 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 13:00:57 +0200 Subject: [PATCH 04/17] examples : update args speculative-simple README.md [no ci] (#22938) This commit updates the command line arguments to use the correct names and values which are now required. The motivation for this change is that currently running the example command as is will generate the following errors: ```console error while handling argument "--color": error: unknown value for --color: '--sampling-seq' usage: -co, --color [on|off|auto] Colorize output to distinguish prompt and user input from generations ('on', 'off', or 'auto', default: 'auto') 'auto' enables colors when output is to a terminal error while handling argument "-fa": error: unknown value for --flash-attn: '--temp' usage: -fa, --flash-attn [on|off|auto] set Flash Attention use ('on', 'off', or 'auto', default: 'auto') (env: LLAMA_ARG_FLASH_ATTN) error while handling argument "--draft-max": the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max usage: --draft, --draft-n, --draft-max N the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max (env: LLAMA_ARG_DRAFT_MAX) error while handling argument "--draft-min": the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min usage: --draft-min, --draft-n-min N the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min (env: LLAMA_ARG_DRAFT_MIN) ``` --- examples/speculative-simple/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/speculative-simple/README.md b/examples/speculative-simple/README.md index e3a6c6b4a..f72129b3f 100644 --- a/examples/speculative-simple/README.md +++ b/examples/speculative-simple/README.md @@ -6,7 +6,7 @@ Demonstration of basic greedy speculative decoding ./bin/llama-speculative-simple \ -m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \ -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \ - -f test.txt -c 0 -ngl 99 --color \ - --sampling-seq k --top-k 1 -fa --temp 0.0 \ - -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9 + -f test.txt -c 0 -ngl 99 --color on \ + --sampling-seq k --top-k 1 -fa on --temp 0.0 \ + -ngld 99 --spec-draft-n-max 16 --spec-draft-n-draft-min 5 --draft-p-min 0.9 ``` From 928b486b0c8ef4a126086e078126cdb42e977fc7 Mon Sep 17 00:00:00 2001 From: Kevin Pouget Date: Mon, 11 May 2026 15:38:22 +0200 Subject: [PATCH 05/17] ggml-virtgpu: Add a GHA build check (#22943) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [ggml-virtgpu] Add a GHA build check * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- .github/workflows/build-virtgpu.yml | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 .github/workflows/build-virtgpu.yml diff --git a/.github/workflows/build-virtgpu.yml b/.github/workflows/build-virtgpu.yml new file mode 100644 index 000000000..5b740590d --- /dev/null +++ b/.github/workflows/build-virtgpu.yml @@ -0,0 +1,50 @@ +name: CI (virtgpu) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build-virtgpu.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp' + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-virtgpu.yml', + 'ggml/src/ggml-virtgpu/**' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-24-virtgpu: + runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install -y build-essential libdrm-dev pkg-config libssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DGGML_VIRTGPU=ON \ + -DGGML_VIRTGPU_BACKEND=ON + cmake --build build --config Release -j $(nproc) From 68e7ea3eabef29a3e222681c81e0cc7ed070c09d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 11 May 2026 19:09:43 +0300 Subject: [PATCH 06/17] spec : parallel drafting support (#22838) * spec : refactor * spec : drop support for incompatible vocabs * spec : update common_speculative_init() * cont : pass seq_id * cont : dedup ctx_seq_rm_type * server : sketch the ctx_dft decode loop * server : draft prompt cache and checkpoints * server : improve ctx names * server, spec : transition to unified spec context * cont : sync main and drft contexts * cont : async drft eval when possible * cont : handle non-ckpt models * cont : pass correct n_past for drafting * cont : process images throught the draft context * spec : handle draft running out of context * server : fix mtmd draft processing * server : fix URL for draft model * server : add comment * server : clean-up + dry * speculative-simple : update * spec : fix n_past type * server : fix slot ctx_drft ptr * tools : update readme * naming : improve consistency * spec : refactor for multi-sequence speculative context * cont : prepare params * cont : prepare params * spec : support parallel drafts * server : support parallel drafting * llama : reuse device buffers when possible * server, spec : clean-up * cont : clean-up * cont : minor * spec : reset `drafting` flag at the end * spec : introduce `common_speculative_process()` * spec : allow for multiple spec types (chain of speculators) * replace old type field of type common_speculative_type in the common_params_speculative struct with a vector to allow multiple types to be specified * introduce common_get_enabled_speculative_impls(const std::vector) to figure out which implementations the user has enabled * introduce common_speculative_type_from_names(const std::vector & names) to parse the already user provided spec types * all speculators run sequentially, best one wins (we verify its drafted tokens) * maximize expected accepted tokens for current round by calculating the product between the probability of accepting current token (n_acc_tokens / n_gen_drafts) and the draft's length --------- Co-authored-by: Petros Sideris --- common/arg.cpp | 41 +- common/common.cpp | 101 +- common/common.h | 59 +- common/speculative.cpp | 1239 ++++++++--------- common/speculative.h | 53 +- .../speculative-simple/speculative-simple.cpp | 130 +- include/llama.h | 2 + src/llama-context.cpp | 27 +- tools/cli/README.md | 2 - tools/server/README.md | 2 - tools/server/server-context.cpp | 592 ++++---- tools/server/server-task.cpp | 66 +- tools/server/server-task.h | 41 +- tools/server/tests/unit/test_speculative.py | 2 +- 14 files changed, 1286 insertions(+), 1071 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 55ec9389b..9fefe411e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -622,10 +622,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { string_process_escapes(seq_breaker); } - for (auto & pair : params.speculative.draft.replacements) { - string_process_escapes(pair.first); - string_process_escapes(pair.second); - } } if (!params.kv_overrides.empty()) { @@ -3518,13 +3514,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.draft.p_min = std::stof(value); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN")); - add_opt(common_arg( - {"--spec-draft-ctx-size", "-cd", "--ctx-size-draft"}, "N", - string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.draft.n_ctx), - [](common_params & params, int value) { - params.speculative.draft.n_ctx = value; - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_CTX_SIZE")); add_opt(common_arg( {"--spec-draft-device", "-devd", "--device-draft"}, "", "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" @@ -3561,32 +3550,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL")); add_opt(common_arg( - {"--spec-draft-replace", "--spec-replace"}, "TARGET", "DRAFT", - "translate the string in TARGET into DRAFT if the draft model and main model are not compatible", - [](common_params & params, const std::string & tgt, const std::string & dft) { - params.speculative.draft.replacements.push_back({ tgt, dft }); - } - ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); - add_opt(common_arg( - {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", + {"--spec-type"}, common_speculative_all_types_str(), string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", - common_speculative_type_to_str(params.speculative.type).c_str()), + common_speculative_type_name_str(params.speculative.types).c_str()), [](common_params & params, const std::string & value) { - if (value == "none") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; - } else if (value == "ngram-cache") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE; - } else if (value == "ngram-simple") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE; - } else if (value == "ngram-map-k") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; - } else if (value == "ngram-map-k4v") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; - } else if (value == "ngram-mod") { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; - } else { - throw std::invalid_argument("unknown speculative decoding type without draft model"); - } + const auto enabled_types = string_split(value, ','); + params.speculative.types = common_speculative_types_from_names(enabled_types); } ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE")); add_opt(common_arg( @@ -4075,7 +4044,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--spec-default"}, string_format("enable default speculative decoding config"), [](common_params & params) { - params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD }; params.speculative.ngram_mod.n_match = 24; params.speculative.ngram_mod.n_min = 48; params.speculative.ngram_mod.n_max = 64; diff --git a/common/common.cpp b/common/common.cpp index 793b8fee7..352af0b17 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1422,7 +1422,7 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + LOG_WRN("%s: the context does not support partial sequence removal\n", __func__); res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL; goto done; } @@ -1960,3 +1960,102 @@ bool common_prompt_batch_decode( return true; } + +size_t common_prompt_checkpoint::size() const { + return data_tgt.size() + data_dft.size(); +} + +bool common_prompt_checkpoint::empty() const { + return data_tgt.empty(); +} + +void common_prompt_checkpoint::clear() { + n_tokens = 0; + + pos_min = 0; + pos_max = 0; + + data_tgt.clear(); + data_dft.clear(); +} + +void common_prompt_checkpoint::update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max) { + this->n_tokens = n_tokens; + this->pos_min = pos_min; + this->pos_max = pos_max; +} + +void common_prompt_checkpoint::update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_tgt.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_tgt.data(), ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) { + if (ctx == nullptr) { + return; + } + + const size_t ckpt_size = llama_state_seq_get_size_ext(ctx, seq_id, flags); + + data_dft.resize(ckpt_size); + + const size_t n = llama_state_seq_get_data_ext(ctx, data_dft.data(), ckpt_size, seq_id, flags); + if (n != ckpt_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", ckpt_size, n); + } +} + +void common_prompt_checkpoint::load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_tgt.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_tgt.data(), data_tgt.size(), seq_id, flags); + if (n != data_tgt.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_tgt.size(), n); + } +} + +void common_prompt_checkpoint::load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const { + if (ctx == nullptr) { + return; + } + + if (data_dft.empty()) { + return; + } + + const size_t n = llama_state_seq_set_data_ext(ctx, data_dft.data(), data_dft.size(), seq_id, flags); + if (n != data_dft.size()) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n); + } +} diff --git a/common/common.h b/common/common.h index a564b3b8c..aafc376f2 100644 --- a/common/common.h +++ b/common/common.h @@ -295,8 +295,6 @@ struct common_params_model { std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; -struct common_ngram_mod; - // draft-model-based speculative decoding parameters struct common_params_speculative_draft { int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding @@ -307,11 +305,9 @@ struct common_params_speculative_draft { common_params_model mparams; - llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; - llama_context_params cparams; // these are the parameters for the draft llama_context - - int32_t n_ctx = 0; // draft context size int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K @@ -322,7 +318,6 @@ struct common_params_speculative_draft { std::vector devices; // devices to use for offloading - std::vector> replacements; // main to speculative model replacements std::vector tensor_buft_overrides; }; @@ -331,9 +326,6 @@ struct common_params_speculative_ngram_mod { int32_t n_max = 64; int32_t n_min = 48; - - // shared instance of the ngram container for all speculative decoding contexts - std::shared_ptr obj; }; struct common_params_speculative_ngram_map { @@ -348,8 +340,7 @@ struct common_params_speculative_ngram_cache { }; struct common_params_speculative { - // TODO: become a vector in order to support "chains of speculators" - common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; + std::vector types = { COMMON_SPECULATIVE_TYPE_NONE }; common_params_speculative_draft draft; @@ -1026,3 +1017,47 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std // "adamw" or "sgd" (case insensitive) enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); + +// +// prompt utils +// + +struct common_prompt_checkpoint { + int64_t n_tokens; + + llama_pos pos_min; + llama_pos pos_max; + + std::vector data_tgt; + std::vector data_dft; + + size_t size() const; + + bool empty() const; + void clear(); + + void update_pos( + int64_t n_tokens, + llama_pos pos_min, + llama_pos pos_max); + + void update_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void update_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + void load_tgt( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; + + void load_dft( + llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags) const; +}; diff --git a/common/speculative.cpp b/common/speculative.cpp index e9fa751e2..e487e003d 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -10,6 +10,7 @@ #include "sampling.h" #include +#include #include #include #include @@ -18,26 +19,15 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -const std::vector common_speculative_types = { - COMMON_SPECULATIVE_TYPE_NONE, - COMMON_SPECULATIVE_TYPE_DRAFT, - COMMON_SPECULATIVE_TYPE_EAGLE3, - COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, - COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, - COMMON_SPECULATIVE_TYPE_NGRAM_MOD, - COMMON_SPECULATIVE_TYPE_NGRAM_CACHE -}; - -const std::map common_speculative_type_from_name_map = { +const std::map common_speculative_type_from_name_map = { {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, - {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, - {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, - {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, - {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, - {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} + {"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, + {"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, + {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, + {"ngram-mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, + {"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} }; struct common_speculative_config { @@ -115,12 +105,16 @@ static bool common_speculative_are_compatible( return true; } +using common_speculative_draft_params_vec = std::vector; + // state of an implementation of speculative decoding // // each implementation has a unique type and a state that is implementation-specific -// in a subclass of common_speculative_state -struct common_speculative_state { - const enum common_speculative_type type; +// in a subclass of common_speculative_impl +struct common_speculative_impl { + const common_speculative_type type; + + uint32_t n_seq; size_t n_call_begin = 0; // number of times this implementation was called for refresh. size_t n_call_draft = 0; // number of times this implementation was called for generation. @@ -138,65 +132,34 @@ struct common_speculative_state { int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds. int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. - common_speculative_state(enum common_speculative_type type) : type(type) {} + common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {} - virtual ~common_speculative_state() = default; + virtual ~common_speculative_impl() = default; - virtual void begin(const llama_tokens & prompt) = 0; + virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0; - virtual void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) = 0; + virtual bool process(const llama_batch & batch) = 0; - virtual void accept(uint16_t n_accepted) = 0; + virtual void draft(common_speculative_draft_params_vec & dparams) = 0; - virtual int32_t n_max(const common_params_speculative & params) const = 0; - virtual int32_t n_min(const common_params_speculative & params) const = 0; + virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; }; -struct common_speculative_checkpoint { - llama_pos pos_min = 0; - llama_pos pos_max = 0; +struct common_speculative_state_draft : public common_speculative_impl { + common_params_speculative_draft params; - int64_t n_tokens = 0; + llama_batch batch; - std::vector data; + std::vector smpls; - size_t size() const { - return data.size(); - } -}; - -struct common_speculative_state_draft : public common_speculative_state { - llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - llama_context * ctx_dft; - - bool use_ckpt = false; - common_speculative_checkpoint ckpt; - - common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - - bool vocab_cmpt = true; // whether retokenization is needed - std::unordered_map vocab_map; - - common_speculative_state_draft( - enum common_speculative_type type, - llama_context * ctx_tgt, - llama_context * ctx_dft, - const std::vector> & replacements, - bool use_ckpt) - : common_speculative_state(type) - , ctx_tgt(ctx_tgt) - , ctx_dft(ctx_dft) - , use_ckpt(use_ckpt) + common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq) + , params(params.draft) { + auto * ctx_dft = this->params.ctx_dft; + auto * ctx_tgt = this->params.ctx_tgt; + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - smpl = nullptr; // TODO: optimize or pass from outside? // { @@ -214,7 +177,9 @@ struct common_speculative_state_draft : public common_speculative_state { // // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); // } - { + + smpls.resize(n_seq); + for (auto & smpl : smpls) { common_params_sampling params; params.no_perf = false; params.top_k = 10; @@ -222,482 +187,321 @@ struct common_speculative_state_draft : public common_speculative_state { COMMON_SAMPLER_TYPE_TOP_K, }; - smpl = common_sampler_init(llama_get_model(ctx_dft), params); + smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params)); } - vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); - LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt); if (!vocab_cmpt) { - LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__); - for (const auto & pair : replacements) { - vocab_map[pair.first] = pair.second; - } + throw std::runtime_error("draft model vocab type must match target model to use speculation"); + } + + if (n_seq != llama_n_seq_max(ctx_dft)) { + LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft)); + + throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq"); } } ~common_speculative_state_draft() override { - llama_perf_context_print(ctx_dft); - - llama_free(ctx_dft); - - common_sampler_free(smpl); - llama_batch_free(batch); } - void begin(const llama_tokens & /*prompt*/) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - size_t create_checkpoint(int n_tokens_prompt) { - int slot_id = 0; - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + bool process(const llama_batch & batch) override { + auto * ctx_dft = params.ctx_dft; - ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); - ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); - ckpt.n_tokens = n_tokens_prompt; - ckpt.data.resize(checkpoint_size); + const int ret = llama_decode(ctx_dft, batch); - const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + if (ret != 0) { + LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret); + + return false; } - LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, - ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); - return n; + return true; } - size_t restore_checkpoint() { - int slot_id = 0; - LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); - const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size()); - } - llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1); + void draft(common_speculative_draft_params_vec & dparams) override { + auto & ctx_dft = params.ctx_dft; - return n; - } + common_batch_clear(batch); - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - const auto & sparams = params.draft; + // keep track of which sequences are still drafting + int n_drafting = 0; + std::vector drafting(n_seq); - auto * spec = this; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; + if (!dp.drafting) { + continue; + } - auto * mem_dft = llama_get_memory(ctx_dft); + n_drafting++; + drafting[seq_id] = true; + common_sampler_reset(smpls[seq_id].get()); - int reuse_i = 0; // index of part to be reused in prompt_dft - int reuse_n = 0; // length of part to be reused in prompt_dft - - const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max; - - llama_tokens prompt_cnv; - if (!spec->vocab_cmpt) { - std::string text; - - text = common_detokenize(ctx_tgt, prompt_tgt, true); - text = replace_to_dft(text); - - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - - prompt_cnv = common_tokenize(ctx_dft, text, false, true); - - // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); - - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(text); - - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; + common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); } - const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; - - const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); - - if (use_ckpt && i_start > 0) { - LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { - cur++; - } + int i = 0; - if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; - } + while (n_drafting > 0) { + int i_batch = 0; - if (use_ckpt) { - break; - } - } - - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", - __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); - if (use_ckpt && ckpt.n_tokens > reuse_n) { - LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens); - - reuse_i = 0; - reuse_n = 0; - - ckpt = {}; - } - - result.clear(); - result.reserve(sparams.n_max); - - if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) { - llama_memory_clear(mem_dft, false); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); - - if (sparams.n_max <= (int) result.size()) { - break; - } - } - - return; - } - - if (reuse_i > 0) { - GGML_ASSERT(!use_ckpt); - - bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); - return; - } - llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); - - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); - } - - if (reuse_n < (int) prompt_dft.size()) { - if (use_ckpt) { - if (ckpt.n_tokens > 0) { - LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - restore_checkpoint(); - reuse_n = ckpt.n_tokens; - prompt_dft.resize(reuse_n); - } - } else { - const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1); - if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); - return; - } - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); - } - } - } - - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); - common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_cur[i]); - } - - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens); - - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", - __func__, ret, prompt_cur.size()); - } - - if (use_ckpt) { - create_checkpoint(prompt_dft.size()); - } - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - int ret = llama_decode(ctx_dft, batch); - if (ret != 0 && ret != 1) { - LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, ret, prompt_cur.size(), prompt_dft.size()); - } - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < sparams.n_max; ++i) { common_batch_clear(batch); - common_sampler_sample(smpl, ctx_dft, 0, true); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (!drafting[seq_id]) { + continue; + } - const auto * cur_p = common_sampler_get_candidates(smpl, true); + auto * smpl = smpls[seq_id].get(); - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + common_sampler_sample(smpl, ctx_dft, i_batch, true); + ++i_batch; + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, + common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + drafting[seq_id] = false; + n_drafting--; + + continue; + } + + common_sampler_accept(smpl, id, true); + + auto & dp = dparams.at(seq_id); + auto & result = *dp.result; + + result.push_back(id); + + if ((params.n_max <= (int) result.size()) || + (dp.n_max > 0 && dp.n_max <= (int) result.size())) { + drafting[seq_id] = false; + n_drafting--; + continue; + } + + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); } - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; - - common_sampler_accept(smpl, id, true); - - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < sparams.p_min) { + if (batch.n_tokens == 0) { break; } - result.push_back(id); - - if (sparams.n_max <= (int) result.size()) { - break; - } - - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); - // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { - LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", - __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; } - prompt_dft.push_back(id); + ++i; } - if (!spec->vocab_cmpt) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t) sparams.n_max) { - result.resize(sparams.n_max); + for (auto & dp : dparams) { + if (!dp.drafting) { + continue; } - } - if (result.size() < (size_t) sparams.n_min) { - result.clear(); + if (dp.result->size() < (size_t) params.n_min) { + dp.result->clear(); + } } } - void accept(uint16_t n_accepted) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { // noop - GGML_UNUSED(n_accepted); - } - - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; - } - - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; - } - - std::string replace_to_dft(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + pair.second.length()); - } - } - - return result; - } - - std::string replace_to_tgt(const std::string & input) const { - std::string result = input; - - for (const auto & pair : this->vocab_map) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - - return result; } }; -struct common_speculative_state_eagle3 : public common_speculative_state { - common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} +struct common_speculative_state_eagle3 : public common_speculative_impl { + //common_params_speculative_eagle3 params; - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } + common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {} - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & draft_tokens) override { - // TODO: implement - GGML_UNUSED(params); - GGML_UNUSED(prompt_tgt); - GGML_UNUSED(id_last); - GGML_UNUSED(draft_tokens); - } - - void accept(uint16_t n_accepted) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop - GGML_UNUSED(n_accepted); } - int32_t n_max(const common_params_speculative & params) const override { - return params.draft.n_max; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & params) const override { - return params.draft.n_min; + void draft(common_speculative_draft_params_vec & /*dparams*/) override { + // TODO: implement + } + + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; // state of self-speculation (simple implementation, not ngram-map) -struct common_speculative_state_ngram_simple : public common_speculative_state { +struct common_speculative_state_ngram_simple : public common_speculative_impl { + common_params_speculative_ngram_map params; + + // shared across all sequences common_ngram_simple_config config; common_speculative_state_ngram_simple( - enum common_speculative_type type, + const common_params_speculative & params, uint32_t n_seq, common_ngram_simple_config config) - : common_speculative_state(type), config(config) {} + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) + , params(params.ngram_simple) + , config(config) {} - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); - } - - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - - result = common_ngram_simple_draft(config, prompt_tgt, id_last); - GGML_UNUSED(params); - } - - void accept(uint16_t n_accepted) override { + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop - GGML_UNUSED(n_accepted); } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_mgram; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + *dp.result = common_ngram_simple_draft(config, *dp.prompt, dp.id_last); + } + } + + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; -struct common_speculative_state_ngram_map_k : public common_speculative_state { - // draft ngram map for speculative decoding without draft model - common_ngram_map config; +struct common_speculative_state_ngram_map_k : public common_speculative_impl { + common_params_speculative_ngram_map params; + + // n_seq configs + std::vector config; common_speculative_state_ngram_map_k( - enum common_speculative_type type, - common_ngram_map config) - : common_speculative_state(type), config(std::move(config)) {} - - void begin(const llama_tokens & prompt) override { - common_ngram_map_begin(config, prompt); - } - - void draft( const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - common_ngram_map_draft(config, prompt_tgt, id_last, result); - GGML_UNUSED(params); + const common_ngram_map & config, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) + , params(params.ngram_map_k) { + for (uint32_t i = 0; i < n_seq; i++) { + this->config.push_back(config); + } } - void accept(uint16_t n_accepted) override { - common_ngram_map_accept(config, n_accepted); + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + GGML_ASSERT(seq_id < (llama_seq_id) n_seq); + + common_ngram_map_begin(config[seq_id], prompt); } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return config.size_value; + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return config.size_value; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + common_ngram_map_draft(config[seq_id], *dp.prompt, dp.id_last, *dp.result); + } + } + + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + GGML_ASSERT((seq_id < (llama_seq_id) config.size())); + + common_ngram_map_accept(config[seq_id], n_accepted); } }; -struct common_speculative_state_ngram_mod : public common_speculative_state { - common_ngram_mod & mod; +struct common_speculative_state_ngram_mod : public common_speculative_impl { + common_params_speculative_ngram_mod params; - // the last position in the prompt that was added to the ngram container - size_t i_last = 0; - - // length of the last drafted n‑gram (number of tokens returned by draft) - size_t n_draft_last = 0; - - // consecutive accept rounds with low acceptance fraction (< 0.5) - int n_low = 0; + // shared across all sequences + common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; - common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod) - : common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) { + struct seq_info { + // the last position in the prompt that was added to the ngram container + size_t i_last = 0; + + // length of the last drafted n‑gram (number of tokens returned by draft) + size_t n_draft_last = 0; + + // consecutive accept rounds with low acceptance fraction (< 0.5) + int n_low = 0; + }; + + std::vector sinfos; + + common_speculative_state_ngram_mod( + const common_params_speculative & params, + uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) + , params(params.ngram_mod) + , mod(params.ngram_mod.n_match, 4*1024*1024) + , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); + + LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, + this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); + + if (this->params.n_match < 16) { + LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " + "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match); + } + + sinfos.resize(n_seq); } - void begin(const llama_tokens & prompt) override { - i_last = 0; + void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { + auto & sinfo = sinfos[seq_id]; - n_draft_last = 0; + sinfo.i_last = 0; + sinfo.n_draft_last = 0; const size_t n = mod.get_n(); - if (prompt.size() < n) { return; } @@ -706,7 +510,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { mod.add(prompt.data() + i); } - i_last = prompt.size() - n; + sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); @@ -719,16 +523,17 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - const auto & sparams = params.ngram_mod; + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; - n_draft_last = 0; + const auto & prompt = *dparams.prompt; - const size_t cur_len = prompt_tgt.size(); + sinfo.n_draft_last = 0; + + const size_t cur_len = prompt.size(); if (cur_len < mod.get_n()) { return; } @@ -736,24 +541,24 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { const size_t n = mod.get_n(); // add new ngrams in chunks - if (i_last + 32 < cur_len) { - for (size_t i = i_last; i < cur_len - n; ++i) { - mod.add(prompt_tgt.data() + i); + if (sinfo.i_last + 32 < cur_len) { + for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { + mod.add(prompt.data() + i); } - i_last = cur_len - n; + sinfo.i_last = cur_len - n; } - result.resize(n + sparams.n_max); + result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { - result[i] = prompt_tgt[cur_len - n + 1 + i]; + result[i] = prompt.at(cur_len - n + 1 + i); } - result[n - 1] = id_last; + result[n - 1] = dparams.id_last; - for (int i = 0; i < sparams.n_max; ++i) { + for (int i = 0; i < params.n_max; ++i) { const llama_token token = mod.get(result.data() + i); if (token == common_ngram_mod::EMPTY) { - if (i < sparams.n_min) { + if (i < params.n_min) { result.clear(); return; } @@ -771,65 +576,92 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { result.resize(result.size() - n); // store length of drafted n‑gram for later acceptance analysis - n_draft_last = result.size(); + sinfo.n_draft_last = result.size(); } - void accept(uint16_t n_accepted) override { - // compute acceptance fraction if we have a recorded draft length - if (n_draft_last > 0) { - const double f_acc = (double)n_accepted / (double)n_draft_last; - if (f_acc < 0.5) { - n_low++; - if (n_low >= 3) { - if (verbose) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); - } + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; + } - mod.reset(); - n_low = 0; - i_last = 0; - } - } else { - n_low = 0; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; } + + draft_one(seq_id, dp); } } - int32_t n_max(const common_params_speculative & params) const override { - return params.ngram_mod.n_max; - } + void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + auto & sinfo = sinfos[seq_id]; - int32_t n_min(const common_params_speculative & params) const override { - return params.ngram_mod.n_min; + // compute acceptance fraction if we have a recorded draft length + if (sinfo.n_draft_last > 0) { + const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; + if (f_acc < 0.5) { + sinfo.n_low++; + if (sinfo.n_low >= 3) { + if (verbose) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); + } + + mod.reset(); + sinfo.n_low = 0; + sinfo.i_last = 0; + } + } else { + sinfo.n_low = 0; + } + } } }; -struct common_speculative_state_ngram_cache : public common_speculative_state { +struct common_speculative_state_ngram_cache : public common_speculative_impl { + common_params_speculative_ngram_cache params; + uint16_t n_draft; + bool save_dynamic; bool save_static; - common_ngram_cache ngram_cache_context; - common_ngram_cache ngram_cache_dynamic; - common_ngram_cache ngram_cache_static; + struct seq_info { + size_t cache_size = 0; // number of tokens in n-gram cache - size_t cache_size = 0; // number of tokens in n-gram cache + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + }; + + std::vector sinfos; common_speculative_state_ngram_cache( - const enum common_speculative_type type, + const common_params_speculative & params, + uint32_t n_seq, + uint16_t n_draft, const std::string & path_static, const std::string & path_dynamic, - uint16_t n_draft, - bool save_dynamic, - bool save_static) - : common_speculative_state(type) + bool save_dynamic, + bool save_static) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) + , params(params.ngram_cache) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { + sinfos.resize(n_seq); + if (!path_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(path_static); + auto ngram_cache_static = common_ngram_cache_load(path_static); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_static = ngram_cache_static; + } } catch (...) { LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); @@ -838,7 +670,11 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { if (!path_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + + for (auto & sinfo : sinfos) { + sinfo.ngram_cache_dynamic = ngram_cache_dynamic; + } } catch (...) { LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); @@ -846,44 +682,48 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { + // noop } - void draft( - const common_params_speculative & params, - const llama_tokens & prompt_tgt, - llama_token id_last, - llama_tokens & result) override { - GGML_UNUSED(params); + void draft_one( + llama_seq_id seq_id, + common_speculative_draft_params & dparams) { + auto & sinfo = sinfos[seq_id]; + auto & result = *dparams.result; - if (cache_size < prompt_tgt.size() + 1) { + const auto & prompt = *dparams.prompt; + + if (sinfo.cache_size < prompt.size() + 1) { llama_tokens tokens_new; - tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); - for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { - tokens_new.push_back(prompt_tgt[j]); + tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size); + for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) { + tokens_new.push_back(prompt[j]); } - tokens_new.push_back(id_last); // add the last token + tokens_new.push_back(dparams.id_last); // add the last token - // Update context ngram cache with new prompt_tgt: - common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + // Update context ngram cache with new dparams.prompt: + common_ngram_cache_update( + sinfo.ngram_cache_context, + LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); - cache_size = prompt_tgt.size() + 1; + sinfo.cache_size = prompt.size() + 1; } llama_tokens inp; - inp.reserve(prompt_tgt.size() + 1); - for (size_t j = 0; j < prompt_tgt.size(); ++j) { - inp.push_back(prompt_tgt[j]); + inp.reserve(prompt.size() + 1); + for (size_t j = 0; j < prompt.size(); ++j) { + inp.push_back(prompt[j]); } - inp.push_back(id_last); + inp.push_back(dparams.id_last); - result.push_back(id_last); + result.push_back(dparams.id_last); - common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, - ngram_cache_context, - ngram_cache_dynamic, - ngram_cache_static); + common_ngram_cache_draft( + inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + sinfo.ngram_cache_context, + sinfo.ngram_cache_dynamic, + sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) @@ -891,24 +731,37 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { - // TODO: noop - GGML_UNUSED(n_accepted); + bool process(const llama_batch & /*batch*/) override { + // TODO: implement + return true; } - int32_t n_max(const common_params_speculative & /*params*/) const override { - return n_draft; + void draft(common_speculative_draft_params_vec & dparams) override { + assert(dparams.size() == n_seq); + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + auto & dp = dparams[seq_id]; + if (!dp.drafting) { + continue; + } + + draft_one(seq_id, dp); + } } - int32_t n_min(const common_params_speculative & /*params*/) const override { - return 0; + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + // noop } }; struct common_speculative { - std::vector> impls; // list of implementations to use and their states + common_speculative_draft_params_vec dparams; - common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) + // list of implementations to use and their states + std::vector> impls; + + // which implementaion was used for a given seq_id + std::vector impl_last; }; static common_ngram_map get_common_ngram_map( @@ -923,45 +776,79 @@ static common_ngram_map get_common_ngram_map( } static common_speculative_state_ngram_cache create_state_ngram_cache( - const std::string & path_static, const std::string & path_dynamic, - const common_speculative_config & config) { + const common_speculative_config & config, + uint32_t n_seq, + const std::string & path_static, + const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic? bool save_static = false; bool save_dynamic = false; - common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + common_speculative_state_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } -std::string common_speculative_type_name_str() { +std::string common_speculative_type_name_str(const std::vector & types) { std::string result; - for (size_t i = 0; i < common_speculative_types.size(); i++) { + + for (size_t i = 0; i < types.size(); i++) { if (i > 0) { - result += ", "; + result += ","; } - result += common_speculative_type_to_str(common_speculative_types[i]); + result += common_speculative_type_to_str(types[i]); } return result; } -std::string common_speculative_type_to_str(enum common_speculative_type type) { +const char * common_speculative_all_types_str() { + static std::string all_types_str = []() { + std::vector types; + types.reserve(COMMON_SPECULATIVE_TYPE_COUNT); + for (int i = 0; i < COMMON_SPECULATIVE_TYPE_COUNT; i++) { + types.push_back((common_speculative_type) i); + } + return common_speculative_type_name_str(types); + }(); + return all_types_str.c_str(); +} + +std::string common_speculative_type_to_str(common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; - case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; - case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; - case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram-mod"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram-cache"; default: return "unknown"; } } -enum common_speculative_type common_speculative_type_from_name(const std::string & name) { +std::vector common_speculative_types_from_names(const std::vector & names) { + std::vector types; + types.reserve(names.size()); + + for (const auto & name : names) { + auto type = common_speculative_type_from_name_map.find(name); + if (type != common_speculative_type_from_name_map.end()) { + if (type->second == COMMON_SPECULATIVE_TYPE_NONE) { + return std::vector { COMMON_SPECULATIVE_TYPE_NONE }; + } + types.push_back(type->second); + continue; + } + throw std::invalid_argument("unknown speculative type: " + name); + } + + return types; +} + +common_speculative_type common_speculative_type_from_name(const std::string & name) { const auto it = common_speculative_type_from_name_map.find(name); if (it == common_speculative_type_from_name_map.end()) { return COMMON_SPECULATIVE_TYPE_COUNT; @@ -969,34 +856,39 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } +static uint32_t common_get_enabled_speculative_configs(const std::vector & configs) { + uint32_t result = 0; + for (size_t i = 0; i < configs.size(); i++) { + result |= (1u << configs[i]); + } + return result; +} + // initialization of the speculative decoding system // -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt) { - llama_context * ctx_dft = nullptr; - if (params.draft.model) { - ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams); - if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); - return nullptr; - } - } - +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { - bool has_draft = !params.draft.mparams.path.empty(); + uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types); + + bool has_draft = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT)); + bool has_draft_model = !params.draft.mparams.path.empty(); + + // bool has_mtp = false; // TODO: add MTP here bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 - bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); - bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); - bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); - bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); + bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); + bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); + bool has_ngram_map_k = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K)); + bool has_ngram_map_k4v = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V)); + bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); - // In a more complex implementation we could use the same implementation but with different parameters. - // This was initially used in PR-18471 but removed to simplify the code. + // when adding a new type - update here the logic above + static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8); + + // this list here defines the priority of the speculators + // the one with highest priority are listed first if (has_ngram_simple) { // This implementation can guess a lot of tokens without any draft model. configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); @@ -1009,53 +901,43 @@ common_speculative * common_speculative_init( configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { - auto & sparams = params.ngram_mod; - - if (!sparams.obj) { - sparams.obj = std::make_shared(sparams.n_match, 4*1024*1024); - - LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, - sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024); - - if (sparams.n_match < 16) { - LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " - "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match); - } - } - configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_draft) { + if (!has_draft_model) { + LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__); + has_draft = false; + } + } else if (has_draft_model) { + LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__); + has_draft = true; + } + if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } + // TODO: add MTP here if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); } } - std::vector> impls = {}; + std::vector> impls = {}; for (const common_speculative_config & config : configs) { - LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str()); switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { - const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - - impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.draft.replacements, - /* .use_ckpt = */ use_ckpt - )); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_EAGLE3: { - impls.push_back(std::make_unique(config.type)); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { @@ -1069,27 +951,30 @@ common_speculative * common_speculative_init( /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( - /* .type = */ config.type, - /* .state = */ config_simple + /* .params = */ config.params, + /* .n_seq = */ n_seq, + /* .state = */ config_simple ); impls.push_back(std::move(state)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { - impls.push_back(std::make_unique( - (config.type), - get_common_ngram_map(config.type, config.params.ngram_map_k) - )); + impls.push_back( + std::make_unique( + config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { - GGML_ASSERT(config.params.ngram_mod.obj); - impls.push_back(std::make_unique(config.type, *config.params.ngram_mod.obj)); + impls.push_back( + std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { - auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config); + auto state = create_state_ngram_cache( + config, n_seq, + params.ngram_cache.lookup_cache_static, + params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } @@ -1104,8 +989,9 @@ common_speculative * common_speculative_init( } auto * result = new common_speculative { + /* .dparams = */ common_speculative_draft_params_vec(n_seq), /* .impls = */ std::move(impls), - /* .curr_impl = */ nullptr, + /* .impl_last = */ std::vector(n_seq, nullptr) }; return result; @@ -1119,65 +1005,128 @@ void common_speculative_free(common_speculative * spec) { delete spec; } -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { +common_speculative_draft_params & common_speculative_get_draft_params( + common_speculative * spec, + llama_seq_id seq_id) { + GGML_ASSERT(spec); + GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size()); + + return spec->dparams[seq_id]; +} + +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); + impl->begin(seq_id, prompt); impl->n_call_begin++; } } -llama_tokens common_speculative_draft( - common_speculative * spec, - const common_params_speculative & params, - const llama_tokens & prompt_tgt, // specified in target model vocab - llama_token id_last) { - llama_tokens result; +bool common_speculative_process(common_speculative * spec, const llama_batch & batch) { + bool result = true; - spec->curr_impl = nullptr; // reset current implementation + if (spec == nullptr) { + return result; + } for (auto & impl : spec->impls) { - { - common_time_meas tm(impl->t_draft_us, !impl->gen_perf); - impl->draft(params, prompt_tgt, id_last, result); - impl->n_call_draft++; - } - - { - const int n_min = impl->n_min(params); - - if (!result.empty() && (int) result.size() < n_min) { - LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min); - result.clear(); - } - } - - if (!result.empty()) { - LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, - common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->n_call_draft, result.size()); - - spec->curr_impl = impl.get(); // set current implementation for stats - impl->n_gen_drafts++; - impl->n_gen_tokens += result.size(); - - break; // we have a draft, so break out of the loop and return it. - } + result = result && impl->process(batch); } return result; } -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { +void common_speculative_draft(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + auto & dparams = spec->dparams; + + { + int n_drafting = 0; + + for (auto & dp : dparams) { + GGML_ASSERT(!dp.drafting || dp.result->empty()); + + if (dp.drafting) { + n_drafting++; + } + } + + if (n_drafting == 0) { + return; + } + } + + for (auto & impl : spec->impls) { + { + common_time_meas tm(impl->t_draft_us, !impl->gen_perf); + impl->draft(dparams); + impl->n_call_draft++; + } + + int n_drafting = 0; + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp = dparams[seq_id]; + + auto & result = *dp.result; + + // a new draft has been sampled + if (dp.drafting && !result.empty()) { + dp.drafting = false; + + if (dp.n_max > 0) { + if (!result.empty() && (int) result.size() > dp.n_max) { + LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max); + result.resize(dp.n_max); + } + } + + if (!result.empty()) { + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(), + impl.get()->n_call_draft, result.size()); + + // remember which implementation was used + spec->impl_last[seq_id] = impl.get(); + + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); + } + } + + if (dp.drafting) { + n_drafting++; + } + } + + if (n_drafting == 0) { + break; + } + } + + // these sequences failed to generate a draft + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { + auto & dp = dparams[seq_id]; + + if (dp.drafting) { + dp.drafting = false; + } + } +} + +void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { if (n_accepted == 0) { return; } - common_speculative_state * impl = spec->curr_impl; + common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); @@ -1188,37 +1137,11 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { impl->n_acc_tokens += n_accepted; } - impl->accept(n_accepted); + impl->accept(seq_id, n_accepted); impl->n_call_accept++; } } -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_max = 0; - for (const auto & impl : spec->impls) { - n_max = std::max(n_max, impl->n_max(params)); - } - - return n_max; -} - -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) { - if (spec == nullptr) { - return 0; - } - - int32_t n_min = 0; - for (const auto & impl : spec->impls) { - n_min = std::max(n_min, impl->n_min(params)); - } - - return n_min; -} - void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index 147447631..51f0b059f 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -5,8 +5,14 @@ struct common_speculative; +// comma separated list the provided types +std::string common_speculative_type_name_str(const std::vector & types); + // comma separated list of all types -std::string common_speculative_type_name_str(); +const char * common_speculative_all_types_str(); + +// parse user provided types +std::vector common_speculative_types_from_names(const std::vector & names); // convert string to type enum common_speculative_type common_speculative_type_from_name(const std::string & name); @@ -14,27 +20,44 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); -common_speculative * common_speculative_init( - common_params_speculative & params, - llama_context * ctx_tgt); +common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq); void common_speculative_free(common_speculative * spec); +struct common_speculative_draft_params { + // this flag is used to chain the drafts through all the available implementations + // after the first successful draft from an implementation, we set it + // to false to prevent further drafts for that sequence + // at the end of the draft() call, all drafting flags will be reset to false + bool drafting = false; + + // overrides individual configurations (-1 disabled) + // can be used to constraint the max draft based on the remaining context size + int32_t n_max = -1; + + llama_pos n_past; + llama_token id_last; + + // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls + const llama_tokens * prompt; + + // the generated draft from the last _draft() call + llama_tokens * result; +}; + +common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id); + // optionally call once at the beginning of a new generation -void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); +void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt); -// sample up to n_draft tokens and add them to the batch using the draft model -llama_tokens common_speculative_draft( - common_speculative * spec, - const common_params_speculative & params, - const llama_tokens & prompt, - llama_token id_last); +// process the batch and update the internal state of the speculative context +bool common_speculative_process(common_speculative * spec, const llama_batch & batch); -// informs the speculative decoder that n_accepted tokens were accepted by the target model -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); +// generate drafts for the sequences specified with `common_speculative_get_draft_params` +void common_speculative_draft(common_speculative * spec); -int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params); -int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params); +// informs the speculative context that n_accepted tokens were accepted by the target model +void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 5b61b62a1..5325bcc9e 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -13,20 +13,6 @@ #include #include -struct spec_checkpoint { - int64_t n_tokens = 0; - - std::vector data; - - size_t size() const { - return data.size(); - } - - bool empty() const { - return data.empty(); - } -}; - int main(int argc, char ** argv) { std::setlocale(LC_NUMERIC, "C"); @@ -43,11 +29,6 @@ int main(int argc, char ** argv) { return 1; } - if (params.speculative.draft.mparams.path.empty()) { - LOG_ERR("%s: --model-draft is required\n", __func__); - return 1; - } - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -62,18 +43,11 @@ int main(int argc, char ** argv) { model_tgt = llama_init_tgt->model(); ctx_tgt = llama_init_tgt->context(); - // check if the context supports partial sequence removal - const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt); - const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); - - if (use_ckpt) { - LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); - } - const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model llama_model_ptr model_dft; + llama_context_ptr ctx_dft; // TODO: simplify this logic { @@ -81,9 +55,6 @@ int main(int argc, char ** argv) { auto params_dft = params; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); params_dft.devices = params_spec.devices; params_dft.model = params_spec.mparams; params_dft.n_gpu_layers = params_spec.n_gpu_layers; @@ -103,8 +74,19 @@ int main(int argc, char ** argv) { return 1; } - params.speculative.draft.model = model_dft.get(); - params.speculative.draft.cparams = common_context_params_to_llama(params_dft); + auto cparams = common_context_params_to_llama(params_dft); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + + params.speculative.draft.ctx_tgt = ctx_tgt; + params.speculative.draft.ctx_dft = ctx_dft.get(); + } + + // check if the context supports partial sequence removal + const bool use_ckpt_tgt = (common_context_can_seq_rm(ctx_tgt) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + const bool use_ckpt_dft = (common_context_can_seq_rm(ctx_dft.get()) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + + if (use_ckpt_tgt) { + LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); } // Tokenize the prompt @@ -136,6 +118,8 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + llama_seq_id seq_id = 0; + // ================================================ // everything until here is standard initialization // the relevant stuff for speculative decoding starts here @@ -146,7 +130,8 @@ int main(int argc, char ** argv) { common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling)); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_decode(ctx_dft.get(), llama_batch_get_one(inp.data(), inp.size() - 1)); // note: keep the last token separate! llama_token id_last = inp.back(); @@ -160,16 +145,16 @@ int main(int argc, char ** argv) { // init the speculator const auto & params_spec = params.speculative; - struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt); + struct common_speculative * spec = common_speculative_init(params.speculative, 1); - common_speculative_begin(spec, prompt_tgt); + common_speculative_begin(spec, seq_id, prompt_tgt); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); size_t n_draft = 0; llama_tokens draft; - spec_checkpoint spec_ckpt; + common_prompt_checkpoint ckpt; const auto t_enc_end = ggml_time_us(); @@ -184,40 +169,57 @@ int main(int argc, char ** argv) { // from a cache or lookup tables. // if (draft.empty()) { + ckpt.update_pos( + prompt_tgt.size(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), seq_id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id)); + + if (use_ckpt_dft) { + ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + // generate a new draft - draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + common_speculative_get_draft_params(spec, seq_id) = { + /* .drafting = */ true, + /* .n_max = */ -1, + /* .n_past = */ n_past, + /* .id_last = */ id_last, + /* .prompt = */ &prompt_tgt, + /* .result = */ &draft, // output + }; + common_speculative_draft(spec); // save the original draft size n_draft = draft.size(); // save a checkpoint of the target context before evaluating the draft // this allows us to restore the state if partial draft acceptance occurs - if (!draft.empty() && use_ckpt) { - const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - spec_ckpt.data.resize(ckpt_size); + if (!draft.empty()) { + if (use_ckpt_tgt) { + ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + } - const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == ckpt_size); + { + ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); - LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n", - spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); } } else { // we have a previous (partial) draft to reuse from checkpoint restoration - if (use_ckpt) { - GGML_ASSERT(!spec_ckpt.empty()); + if (use_ckpt_tgt) { + GGML_ASSERT(!ckpt.empty()); } } // always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + common_batch_add (batch_tgt, id_last, n_past++, { seq_id }, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + common_batch_add(batch_tgt, draft[i], n_past + i, { seq_id }, true); } //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); @@ -225,9 +227,15 @@ int main(int argc, char ** argv) { llama_decode(ctx_tgt, batch_tgt); } + // evaluate the same batch with the draft model + { + // TODO: extend to support MTP, Eagle, etc. See server code for reference + llama_decode(ctx_dft.get(), batch_tgt); + } + // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(smpl.get())); } @@ -247,17 +255,24 @@ int main(int argc, char ** argv) { // check for partial draft acceptance: // if the context doesn't support partial sequence removal, restore the checkpoint // and make the accepted tokens the new partial draft for the next iteration - if (use_ckpt && ids.size() - 1 < draft.size()) { + if (use_ckpt_tgt && ids.size() - 1 < draft.size()) { LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size()); draft = std::move(ids); - const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == spec_ckpt.size()); + { + ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1); + } - prompt_tgt.resize(spec_ckpt.n_tokens); + { + ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1); + } + + prompt_tgt.resize(ckpt.n_tokens); smpl = std::move(smpl_save); n_past = (int) prompt_tgt.size(); @@ -265,7 +280,7 @@ int main(int argc, char ** argv) { continue; } - common_speculative_accept(spec, ids.size() - 1); + common_speculative_accept(spec, seq_id, ids.size() - 1); // full acceptance: consume the draft and commit accepted tokens n_past += ids.size() - 1; @@ -305,7 +320,8 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, n_past, -1); } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git a/include/llama.h b/include/llama.h index 2ea226726..308e8ba9d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -858,6 +858,8 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 + // for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 71a59395e..3d9714ab1 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2475,11 +2475,29 @@ public: } if (need_alloc) { - mbuf_cur = std::move(mbuf); + if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) { + mbuf_cur = std::move(mbuf); - mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); - LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } else { + //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + + // save the old buffer and allocate the new tensors in it + auto buf = std::move(mbuf_cur.buf); + + mbuf_cur = std::move(mbuf); + + ggml_tallocr talloc = ggml_tallocr_new(buf.get()); + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_view_init(mbuf_cur.org[i]); + ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]); + } + + mbuf_cur.buf = std::move(buf); + } } for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { @@ -2559,8 +2577,7 @@ public: mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); - auto & view = mbuf.org.back(); - view->buffer = rinfo.tensor->buffer; + ggml_backend_view_init(mbuf.org.back()); } for (auto & [buft, mbuf] : mbufs_new) { diff --git a/tools/cli/README.md b/tools/cli/README.md index bca4da7ef..02c564a29 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -195,11 +195,9 @@ | `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) | | `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)
(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) | | `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)
(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) | -| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) | | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | | `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | diff --git a/tools/server/README.md b/tools/server/README.md index 024760da6..77eddb335 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith | `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) | | `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)
(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) | | `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)
(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) | -| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) | | `--spec-draft-device, -devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_MODEL) | -| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | | `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)

(env: LLAMA_ARG_SPEC_TYPE) | | `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) | | `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) | diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 637f8d216..0a51390af 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -36,32 +36,6 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; -static void server_prompt_checkpoint_update(server_prompt_checkpoint & ckpt, llama_context * ctx, int id, int64_t n_tokens, bool on_device, llama_pos pos_min = -1, llama_pos pos_max = -1) { - if (pos_min == -1) { - pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id); - } - if (pos_max == -1) { - pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id); - } - - auto flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY; - if (on_device) { - flags |= LLAMA_STATE_SEQ_FLAGS_ON_DEVICE; - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, flags); - - ckpt.pos_min = pos_min; - ckpt.pos_max = pos_max; - ckpt.n_tokens = n_tokens; - ckpt.data.resize(checkpoint_size); - - const size_t n = llama_state_seq_get_data_ext(ctx, ckpt.data.data(), checkpoint_size, id, flags); - if (n != checkpoint_size) { - GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); - } -} - // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, @@ -80,18 +54,19 @@ enum server_state { struct server_slot { int id; - llama_context * ctx = nullptr; - - common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + llama_context * ctx_tgt = nullptr; + llama_context * ctx_dft = nullptr; // multimodal mtmd_context * mctx = nullptr; // speculative decoding + common_speculative * spec; + llama_tokens spec_draft; + llama_tokens spec_prompt; std::vector spec_i_batch; - server_prompt_checkpoint spec_ckpt; - common_speculative_ptr spec; + common_prompt_checkpoint spec_ckpt; // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -135,21 +110,27 @@ struct server_slot { void prompt_save(server_prompt_cache & prompt_cache) const { GGML_ASSERT(prompt.data.size() == 0); - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + const size_t cur_size_dft = ctx_dft ? llama_state_seq_get_size_ext(ctx_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE) : 0; - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + const size_t cur_size = cur_size_tgt + cur_size_dft; - auto * cur = prompt_cache.alloc(prompt, cur_size); + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB (draft: %.3f MiB)\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0), cur_size_dft / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft); if (cur == nullptr) { return; } - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE); + if (ctx_dft) { + llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE); + } } bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); + bool res = prompt_cache.load(prompt, tokens, ctx_tgt, ctx_dft, id); if (!res) { SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); } @@ -164,7 +145,11 @@ struct server_slot { SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size()); - llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1); + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1); + } + prompt.tokens.clear(); } @@ -222,7 +207,7 @@ struct server_slot { task_prev = std::move(task); task.reset(); - llama_set_sampler(ctx, id, nullptr); + llama_set_sampler(ctx_tgt, id, nullptr); // clear alora start alora_invocation_start = -1; @@ -259,7 +244,7 @@ struct server_slot { return !task->need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx_tgt) && llama_pooling_type(ctx_tgt) == LLAMA_POOLING_TYPE_LAST); } bool can_batch_with(server_slot & other_slot) const { @@ -310,14 +295,10 @@ struct server_slot { return 0; } - const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative); - // determine the max draft that fits the current slot state - int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative); - // note: slot.prompt is not yet expanded with the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); + int n_draft_max = n_ctx - prompt.n_tokens() - 2; if (n_remaining > 0) { n_draft_max = std::min(n_draft_max, n_remaining - 1); @@ -325,61 +306,10 @@ struct server_slot { SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < n_draft_min) { - SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min); - n_draft_max = 0; - } - return n_draft_max; } void update_batch(llama_batch & batch) { - const int n_draft_max = get_n_draft_max(); - if (n_draft_max > 0) { - GGML_ASSERT(can_speculate()); - - // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - const llama_tokens & tokens = prompt.tokens.get_text_tokens(); - - const auto & params_spec = task->params.speculative; - - if (!spec_draft.empty()) { - // we have a previous (partial) draft to reuse - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - GGML_ASSERT(!spec_ckpt.empty()); - } - } else { - GGML_ASSERT(spec_i_batch.empty()); - - // generate a new draft - spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled); - n_draft_total += spec_draft.size(); - - if (spec_draft.size() > (size_t) n_draft_max) { - SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max); - spec_draft.resize(n_draft_max); - } - - if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { - const auto n_tokens = prompt.tokens.size(); - - //const int64_t t_start = ggml_time_us(); - - server_prompt_checkpoint_update(spec_ckpt, ctx, this->id, n_tokens, true); - - //const int64_t t_total = ggml_time_us() - t_start; - //printf("checkpoint total: %f ms\n", t_total / 1000.0); - - SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", - spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); - } - } - - GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max); - } - if (spec_draft.empty()) { // no speculative decoding i_batch = batch.n_tokens; @@ -511,7 +441,7 @@ struct server_slot { ); } - common_speculative_print_stats(spec.get()); + common_speculative_print_stats(spec); } json to_json(bool only_metrics = false) const { @@ -539,7 +469,7 @@ struct server_slot { }; if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["prompt"] = ptask->tokens.detokenize(ctx_tgt, true); res["generated"] = generated_text.empty() ? debug_generated_text : generated_text; } } @@ -550,8 +480,13 @@ struct server_slot { void copy_state_to(server_slot & other) const { GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT); - llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1); - llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1); + + if (ctx_dft) { + llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1); + llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1); + } other.n_decoded = n_decoded; other.n_remaining = n_remaining; @@ -642,7 +577,8 @@ public: // only use these pointers outside of this class: // - when not in sleeping state // - and, with thread-safe APIs (e.g., tokenizer calls) - llama_model * model = nullptr; + llama_model * model_tgt = nullptr; + mtmd_context * mctx = nullptr; const llama_vocab * vocab = nullptr; @@ -669,11 +605,17 @@ private: // note: keep these alive - they determine the lifetime of the model, context, etc. common_init_result_ptr llama_init; - llama_context * ctx = nullptr; + llama_context * ctx_tgt = nullptr; llama_batch batch {}; llama_model_ptr model_dft; + llama_context_ptr ctx_dft; + + common_context_seq_rm_type ctx_tgt_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + common_context_seq_rm_type ctx_dft_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + + common_speculative_ptr spec; bool add_bos_token = true; @@ -708,18 +650,12 @@ private: void destroy() { llama_init.reset(); - ctx = nullptr; - model = nullptr; + ctx_tgt = nullptr; + model_tgt = nullptr; mtmd_free(mctx); mctx = nullptr; - for (server_slot & slot : slots) { - if (slot.can_speculate()) { - slot.spec.reset(); - } - } - llama_batch_free(batch); } @@ -759,17 +695,17 @@ private: llama_init = common_init_from_params(params_base); - model = llama_init->model(); - ctx = llama_init->context(); + model_tgt = llama_init->model(); + ctx_tgt = llama_init->context(); - if (model == nullptr) { + if (model_tgt == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); return false; } - vocab = llama_model_get_vocab(model); + vocab = llama_model_get_vocab(model_tgt); - n_ctx = llama_n_ctx(ctx); + n_ctx = llama_n_ctx(ctx_tgt); add_bos_token = llama_vocab_get_add_bos(vocab); @@ -781,9 +717,6 @@ private: auto params_dft = params_base; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx); params_dft.devices = params_spec.devices; params_dft.model = params_spec.mparams; params_dft.n_gpu_layers = params_spec.n_gpu_layers; @@ -805,8 +738,13 @@ private: return false; } - params_base.speculative.draft.model = model_dft.get(); - params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft); + auto cparams = common_context_params_to_llama(params_dft); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + + params_base.speculative.draft.ctx_tgt = ctx_tgt; + params_base.speculative.draft.ctx_dft = ctx_dft.get(); } std::string & mmproj_path = params_base.mmproj.path; @@ -826,7 +764,7 @@ private: mparams.image_max_tokens = params_base.image_max_tokens; mparams.media_marker = get_media_marker(); - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + mctx = mtmd_init_from_file(mmproj_path.c_str(), model_tgt, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); return false; @@ -844,7 +782,7 @@ private: } } - if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (!llama_memory_can_shift(llama_get_memory(ctx_tgt))) { if (params_base.ctx_shift) { params_base.ctx_shift = false; SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); @@ -856,14 +794,14 @@ private: } } - if (llama_model_n_swa(model) == 0) { + if (llama_model_n_swa(model_tgt) == 0) { if (params_base.swa_full) { params_base.swa_full = false; SRV_WRN("%s\n", "swa_full is not supported by this model, it will be disabled"); } } - n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model); + n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model_tgt); // Necessary similarity of prompt for slot selection slot_prompt_similarity = params_base.slot_prompt_similarity; @@ -871,9 +809,9 @@ private: // setup slots SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - const int n_ctx_train = llama_model_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model_tgt); - int n_ctx_slot = llama_n_ctx_seq(ctx); + int n_ctx_slot = llama_n_ctx_seq(ctx_tgt); if (n_ctx_slot > n_ctx_train) { SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); n_ctx_slot = n_ctx_train; @@ -881,12 +819,12 @@ private: slots.clear(); - const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx); - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + ctx_tgt_seq_rm_type = common_context_can_seq_rm(ctx_tgt); + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } - if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + if (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { SRV_WRN("%s", "speculative decoding will use checkpoints\n"); } @@ -895,27 +833,33 @@ private: slots.emplace_back(); } + // try speculative decoding + if (ctx_tgt_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { + try { + spec.reset(common_speculative_init(params_base.speculative, params_base.n_parallel)); + } catch (const std::exception & e) { + SRV_ERR("failed to initialize speculative decoding context: %s\n", e.what()); + } + } + + if (spec) { + SRV_INF("%s", "speculative decoding context initialized\n"); + } else { + ctx_dft.reset(); + } + for (int i = 0; i < params_base.n_parallel; i++) { server_slot & slot = slots[i]; - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - - slot.ctx_seq_rm_type = ctx_seq_rm_type; + slot.id = i; + slot.ctx_tgt = ctx_tgt; + slot.ctx_dft = ctx_dft.get(); + slot.spec = spec.get(); + slot.n_ctx = n_ctx_slot; slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; - // try speculative decoding - if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { - slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); - - if (slot.spec) { - SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } - } - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); slot.callback_on_release = [this](int id_slot) { @@ -946,7 +890,7 @@ private: // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { - const int32_t n_batch = llama_n_batch(ctx); + const int32_t n_batch = llama_n_batch(ctx_tgt); batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } @@ -990,8 +934,9 @@ private: // unlike load_model(), this is only called once during initialization bool init() { - GGML_ASSERT(ctx != nullptr); - GGML_ASSERT(model != nullptr); + GGML_ASSERT(ctx_tgt != nullptr); + GGML_ASSERT(model_tgt != nullptr); + GGML_ASSERT(!sleeping); // wiring up server queues @@ -1037,7 +982,7 @@ private: common_chat_templates_ptr chat_templates; try { - chat_templates = common_chat_templates_init(model, params_base.chat_template); + chat_templates = common_chat_templates_init(model_tgt, params_base.chat_template); LOG_INF("%s: chat template, example_format: '%s'\n", __func__, common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); @@ -1300,7 +1245,7 @@ private: } } - if (!task.tokens.validate(ctx)) { + if (!task.tokens.validate(ctx_tgt)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -1310,7 +1255,7 @@ private: // initialize samplers if (task.need_sampling()) { try { - slot.smpl.reset(common_sampler_init(model, task.params.sampling)); + slot.smpl.reset(common_sampler_init(model_tgt, task.params.sampling)); } catch (std::exception & e) { std::string err_msg = std::string("Failed to initialize samplers: ") + e.what(); send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST); @@ -1324,16 +1269,16 @@ private: backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0); + backend_sampling &= !(slot.can_speculate()); // TODO: getting pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_pre_sample_logits; // TODO: tmp until backend sampling is fully implemented if (backend_sampling) { - llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get())); + llama_set_sampler(ctx_tgt, slot.id, common_sampler_get(slot.smpl.get())); } else { - llama_set_sampler(ctx, slot.id, nullptr); + llama_set_sampler(ctx_tgt, slot.id, nullptr); } SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); @@ -1512,13 +1457,13 @@ private: result.probs.push_back({ cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), + common_token_to_piece(ctx_tgt, cur_p->data[i].id, special), cur_p->data[i].p }); } } else { // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); + std::vector cur = get_token_probabilities(ctx_tgt, idx); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request); @@ -1536,7 +1481,7 @@ private: for (size_t i = 0; i < n_probs; i++) { result.probs.push_back({ cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), + common_token_to_piece(ctx_tgt, cur[i].id, special), cur[i].p }); } @@ -1639,7 +1584,7 @@ private: res->tokens = std::move(slot.generated_tokens); } res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); + res->prompt = slot.task->tokens.detokenize(ctx_tgt, true); res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; @@ -1662,7 +1607,7 @@ private: // populate res.probs_output if (slot.task->params.sampling.n_probs > 0) { if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + const llama_tokens stop_word_toks = common_tokenize(ctx_tgt, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); res->probs_output = std::vector( @@ -1687,7 +1632,7 @@ private: res->n_tokens = slot.task->n_tokens(); res->res_type = slot.task->params.res_type; - const int n_embd_out = llama_model_n_embd_out(model); + const int n_embd_out = llama_model_n_embd_out(model_tgt); std::vector embd_res(n_embd_out, 0.0f); @@ -1697,10 +1642,10 @@ private: } const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); + if (llama_pooling_type(slot.ctx_tgt) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(slot.ctx_tgt, i); } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd = llama_get_embeddings_seq(slot.ctx_tgt, batch.seq_id[i][0]); } if (embd == nullptr) { @@ -1711,7 +1656,7 @@ private: } // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + if (llama_pooling_type(slot.ctx_tgt) != LLAMA_POOLING_TYPE_NONE) { common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; @@ -1736,9 +1681,9 @@ private: continue; } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + const float * embd = llama_get_embeddings_seq(ctx_tgt, batch.seq_id[i][0]); if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); + embd = llama_get_embeddings_ith(ctx_tgt, i); } if (embd == NULL) { @@ -1843,18 +1788,22 @@ private: const auto & cur = slot.prompt.checkpoints.front(); SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } auto & cur = slot.prompt.checkpoints.emplace_back(); - server_prompt_checkpoint_update(cur, ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, false, pos_min, pos_max); + + cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max); + + cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + cur.pos_max, cur.n_tokens, (float) cur.size() / 1024 / 1024); } void process_single_task(server_task && task) { @@ -2009,7 +1958,7 @@ private: std::string filepath = task.slot_action.filepath; const llama_tokens & tokens = slot->prompt.tokens.get_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2048,7 +1997,7 @@ private: llama_tokens tokens; tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx_tgt, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { slot->prompt.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); @@ -2207,8 +2156,13 @@ private: SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard); + } // add generated tokens to cache // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 @@ -2236,12 +2190,10 @@ private: // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + std::vector generating; + std::vector drafting; - // first, add sampled tokens from any ongoing sequences + // determine which slots are generating and drafting for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; @@ -2254,12 +2206,103 @@ private: continue; } + generating.push_back(&slot); + + if (spec) { + common_speculative_get_draft_params(spec.get(), slot.id).drafting = false; + + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + const int n_draft_max = slot.get_n_draft_max(); + + if (n_draft_max > 0) { + GGML_ASSERT(slot.can_speculate()); + + if (!slot.spec_draft.empty()) { + // we have a previous (partial) draft to reuse + if (use_ckpt_tgt) { + GGML_ASSERT(!slot.spec_ckpt.empty()); + } + } else { + GGML_ASSERT(slot.spec_i_batch.empty()); + + slot.spec_ckpt.update_pos( + slot.prompt.n_tokens(), + llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id), + llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id)); + + if (use_ckpt_dft) { + slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + } + + slot.spec_prompt = slot.prompt.tokens.get_text_tokens(); + + common_speculative_get_draft_params(spec.get(), slot.id) = { + /* .drafting = */ true, + /* .n_max = */ n_draft_max, + /* .n_past = */ slot.prompt.n_tokens(), + /* .id_last = */ slot.sampled, + /* .prompt = */ &slot.spec_prompt, + /* .result = */ &slot.spec_draft, + }; + + drafting.push_back(&slot); + } + } + } + } + + // generate the actual drafts (if any) + { + common_speculative_draft(spec.get()); + } + + // make checkpoints if needed + for (auto * slot_ptr : drafting) { + auto & slot = *slot_ptr; + + auto & draft = slot.spec_draft; + auto & ckpt = slot.spec_ckpt; + + slot.n_draft_total += draft.size(); + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + if (ctx_dft) { + ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1); + } + + if (!draft.empty()) { + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + + if (use_ckpt_tgt) { + //const int64_t t_start = ggml_time_us(); + + ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + //const int64_t t_total = ggml_time_us() - t_start; + //printf("checkpoint total: %f ms\n", t_total / 1000.0); + + SLT_DBG(slot, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n", + ckpt.pos_min, ckpt.pos_max, slot.prompt.n_tokens(), + (float) ckpt.size() / 1024 / 1024, + (float) ckpt.data_dft.size() / 1024 / 1024); + } + } + } + + // update the batch with the sampled/drafted tokens + for (auto * slot_ptr : generating) { + auto & slot = *slot_ptr; + slot.update_batch(batch); } // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); + int32_t n_batch = llama_n_batch(ctx_tgt); + int32_t n_ubatch = llama_n_ubatch(ctx_tgt); float alora_scale = -1.0f; size_t alora_disabled_id = 0; @@ -2303,12 +2346,12 @@ private: /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } } else { // all for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx_tgt, input_tokens[i]).c_str()); } }*/ @@ -2327,7 +2370,7 @@ private: } // TODO: support memory-less logits computation - if (slot.task->need_logits() && !llama_get_memory(ctx)) { + if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); continue; @@ -2379,7 +2422,7 @@ private: const auto n_cache_reuse = slot.task->params.n_cache_reuse; const bool can_cache_reuse = - llama_memory_can_shift(llama_get_memory(ctx)) && + llama_memory_can_shift(llama_get_memory(ctx_tgt)) && !slot.prompt.tokens.has_mtmd; if (!can_cache_reuse && n_cache_reuse > 0) { @@ -2413,13 +2456,18 @@ private: if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx_tgt, prompt_tokens[i]).c_str()); //} const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift); + + if (ctx_dft) { + llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift); + } for (size_t i = 0; i < n_match; i++) { slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); @@ -2446,7 +2494,7 @@ private: const auto pos_min_thold = std::max(0, pos_next - n_swa); if (n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); @@ -2475,14 +2523,14 @@ private: { const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss0 << piece; st0 << std::setw(8) << token; } { const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx_tgt, token) : "[mtmd]"; ss1 << piece; st1 << std::setw(8) << token; } @@ -2514,18 +2562,13 @@ private: if (!do_reset) { // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); - n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) checkpoint_size / 1024 / 1024); - } + it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); + n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_past = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, n_past, (float) it->size() / 1024 / 1024); } if (do_reset) { @@ -2542,7 +2585,7 @@ private: for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; if (cur.pos_max > pos_next) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, pos_next = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, pos_next, (float) cur.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; @@ -2582,14 +2625,18 @@ private: SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) { SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); slot.prompt_clear(true); // there is no common part left slot.n_prompt_tokens_cache = 0; - } + } else { + if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) { + GGML_ABORT("failed to truncate draft context\n"); + } + } // If using an alora, there may be uncached tokens that come // before the invocation sequence. When this happens, the @@ -2615,7 +2662,7 @@ private: // - the model does not support partial sequence removal // - the model uses SWA (and we are not using `swa_full`) do_checkpoint = do_checkpoint && ( - (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || (n_swa > 0)); bool has_mtmd = false; @@ -2624,7 +2671,7 @@ private: while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { // process the image size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + int32_t res = input_tokens.process_chunk(ctx_tgt, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); @@ -2632,6 +2679,16 @@ private: continue; } + if (ctx_dft) { + // TODO: in the future, figure out how to infuse target embeddings to the images + // for now, we skip this for simplicity + // maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above? + res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + if (res != 0) { + GGML_ABORT("failed to process multi-modal data on draft context\n"); + } + } + slot.n_prompt_tokens_processed += n_tokens_out; // add the image chunk to cache @@ -2733,8 +2790,8 @@ private: SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); // no need for empty or small checkpoints do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64); @@ -2765,9 +2822,14 @@ private: SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + if (slot_batched) { // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); + common_set_adapter_lora(ctx_tgt, slot_batched->lora); // if the lora is temporarily disabled for an alora, re-enable it // for next time @@ -2776,7 +2838,7 @@ private: slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx, slot_batched->task->need_embd()); + llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd()); } if (batch.n_tokens == 0) { @@ -2805,7 +2867,7 @@ private: batch.logits + i, }; - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode(ctx_tgt, batch_view); metrics.on_decoded(slots); @@ -2858,11 +2920,63 @@ private: continue; // continue loop of n_batch } + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + // + // | spec type | need re-eval | + // | --- | --- | + // | draft model | no | because the draft model does not use embeddings from the target + // | MTP (std) | yes | + // | MTP Gemma4 | no | because the KV cache is shared + // | Eagle3 | yes | + // | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982 + // + // note: this logic is now moved in `common_speculative_process()` + // keeping the sketch here until for a bit, until the logic is finalized + // + //if (ctx_dft) { + // // TODO: update as needed for MTP, Eagle3, etc. + // const bool need_tgt_embd = false; + + // if (need_tgt_embd) { + // llama_synchronize(ctx_tgt); + // } + + // // the logic here varies depending on the speculative decoding method + // // - some draft contexts require embeddings from the target context, others don't + // // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings + // // TODO: extract this in a function ? + // { + // // TODO: hook the embeddings from the last target batch here + // if (llama_model_has_encoder(model_dft.get())) { + // //llama_encode(ctx_dft, ...); + + // GGML_ABORT("not implemented yet\n"); + // } + + // const int ret = llama_decode(ctx_dft.get(), batch_view); + + // if (ret != 0) { + // SRV_ERR("failed to decode draft batch, ret = %d\n", ret); + + // // TODO: handle error + // break; + // } + // } + //} + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + break; + } + // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); + n_batch = llama_n_batch(ctx_tgt); // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { @@ -2921,7 +3035,7 @@ private: slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -2933,7 +3047,7 @@ private: const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); slot.i_batch = -1; @@ -2954,7 +3068,7 @@ private: completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.task->params.sampling.n_probs > 0) { @@ -2985,23 +3099,23 @@ private: // verify and try to accept the draft { - const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; // only save the sampler sampler state if we use checkpoints common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt_tgt) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); slot.spec_i_batch.clear(); GGML_ASSERT(accepted.size() >= 1); // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt) { + if (use_ckpt_tgt) { if (trace > 0) { SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); } @@ -3011,16 +3125,19 @@ private: const auto & ckpt = slot.spec_ckpt; - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", - ckpt.pos_min, ckpt.pos_max, ckpt.size()); + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); - if (n != ckpt.size()) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n); + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1); } - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1); + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); + + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1); + } slot.prompt.tokens.keep_first(ckpt.n_tokens); slot.smpl = std::move(smpl_save); @@ -3033,7 +3150,7 @@ private: SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); } - common_speculative_accept(slot.spec.get(), accepted.size() - 1); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); slot.spec_draft = std::move(accepted); } @@ -3055,13 +3172,16 @@ private: slot.sampled = ids.back(); // last accepted token SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1); + llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1); + } for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3113,7 +3233,7 @@ void server_context::terminate() { } llama_context * server_context::get_llama_context() const { - return impl->ctx; + return impl->ctx_tgt; } server_response_reader server_context::get_response_reader() { @@ -3123,8 +3243,8 @@ server_response_reader server_context::get_response_reader() { server_context_meta server_context::get_meta() const { auto bos_id = llama_vocab_bos(impl->vocab); auto eos_id = llama_vocab_eos(impl->vocab); - auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : ""; - auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, eos_id, true) : ""; + auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, bos_id, true) : ""; + auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx_tgt, eos_id, true) : ""; return server_context_meta { /* build_info */ std::string(llama_build_info()), @@ -3137,7 +3257,7 @@ server_context_meta server_context::get_meta() const { /* has_inp_audio */ impl->chat_params.allow_audio, /* json_webui_settings */ impl->json_webui_settings, /* slot_n_ctx */ impl->get_slot_n_ctx(), - /* pooling_type */ llama_pooling_type(impl->ctx), + /* pooling_type */ llama_pooling_type(impl->ctx_tgt), /* chat_params */ impl->chat_params, /* chat_template_caps */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()), @@ -3155,10 +3275,10 @@ server_context_meta server_context::get_meta() const { /* model_vocab_type */ llama_vocab_type(impl->vocab), /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), - /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model), - /* model_n_embd_inp */ llama_model_n_embd(impl->model), - /* model_n_params */ llama_model_n_params(impl->model), - /* model_size */ llama_model_size(impl->model), + /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model_tgt), + /* model_n_embd_inp */ llama_model_n_embd(impl->model_tgt), + /* model_n_params */ llama_model_n_params(impl->model_tgt), + /* model_size */ llama_model_size(impl->model_tgt), }; } @@ -4045,7 +4165,7 @@ void server_routes::init_routes() { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + auto tmp = format_prompt_rerank(ctx_server.model_tgt, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = rd.get_new_id(); task.tokens = std::move(tmp); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 45e5168fa..b9b4a704a 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const { {"reasoning_in_content", chat_parser_params.reasoning_in_content}, {"generation_prompt", chat_parser_params.generation_prompt}, {"samplers", samplers}, - {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.types", common_speculative_type_name_str(speculative.types)}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl( params.speculative = defaults.speculative; + // TODO: to keep things simple, we disable speculative parameter adjustments for now +#if 0 // TODO: for now, be able to adjust only the draft-model based speculative parameters params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min); params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max); @@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl( params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0); params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0); -#if 0 // for debugging and research purposes params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type))); @@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const { return res; } -server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) { +server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) { // first check if the current state is contained fully in the cache for (auto it = states.begin(); it != states.end(); ++it) { const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); @@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t } } - std::vector state_data; + std::vector state_data_tgt; + std::vector state_data_dft; // check if we can allocate enough memory for the new state try { - state_data.resize(state_size); + state_data_tgt.resize(state_size_tgt); + state_data_dft.resize(state_size_dft); } catch (const std::bad_alloc & e) { SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); @@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t return nullptr; } - auto & cur = states.emplace_back(); - cur = { + states.push_back({ /*.tokens =*/ prompt.tokens.clone(), - /*.data =*/ std::move(state_data), + /*.data =*/ { + /*.main =*/ std::move(state_data_tgt), + /*.drft =*/ std::move(state_data_dft), + }, /*.checkpoints =*/ prompt.checkpoints, - }; + }); - return &cur; + return &states.back(); } -bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { +bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) { const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins @@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok if (it_best != states.end()) { SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); - if (n != size) { - SRV_WRN("failed to restore state with size %zu\n", size); + { + auto & data = it_best->data.main; - return false; + const size_t size = data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + data.clear(); + data.shrink_to_fit(); } - it_best->data.clear(); - it_best->data.shrink_to_fit(); + { + auto & data = it_best->data.drft; + + if (!data.empty()) { + GGML_ASSERT(ctx_dft); + + const size_t size = data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + data.clear(); + data.shrink_to_fit(); + } + } prompt = std::move(*it_best); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 289e1fb8d..64bdecd79 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result { virtual json to_json() override; }; -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - int64_t n_tokens; - - std::vector data; +struct server_prompt_data { + std::vector main; + std::vector drft; size_t size() const { - return data.size(); - } - - bool empty() const { - return data.empty(); - } - - void clear() { - pos_min = 0; - pos_max = 0; - n_tokens = 0; - data.clear(); + return main.size() + drft.size(); } }; struct server_prompt { server_tokens tokens; - std::vector data; + server_prompt_data data; - std::list checkpoints; + std::list checkpoints; size_t size() const { - size_t res = data.size(); + size_t res = 0; - for (const auto & checkpoint : checkpoints) { - res += checkpoint.size(); + res += data.size(); + + for (const auto & ckpt : checkpoints) { + res += ckpt.size(); } return res; @@ -614,7 +601,7 @@ struct server_prompt { return server_prompt { tokens.clone(), data, - checkpoints + checkpoints, }; } }; @@ -637,9 +624,9 @@ struct server_prompt_cache { size_t n_tokens() const; - server_prompt * alloc(const server_prompt & prompt, size_t state_size); + server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft); - bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot); + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot); void update(); }; diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py index eebd3cc8f..84cd77e6f 100644 --- a/tools/server/tests/unit/test_speculative.py +++ b/tools/server/tests/unit/test_speculative.py @@ -5,7 +5,7 @@ from utils import * server = ServerPreset.stories15m_moe() -MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" +MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf" def create_server(): global server From ef22b3e4ac9444d1dca1c44164861e0317b5579d Mon Sep 17 00:00:00 2001 From: willjoha Date: Mon, 11 May 2026 18:32:26 +0200 Subject: [PATCH 07/17] docs: fix metrics endpoint description in server README (#22879) * docs: fix metrics endpoint description in server README Required model query parameter for router mode described. Removed metrics: - llamacpp:kv_cache_usage_ratio - llamacpp:kv_cache_tokens Added metrics: - llamacpp:prompt_seconds_total - llamacpp:tokens_predicted_seconds_total - llamacpp:n_decode_total - llamacpp:n_busy_slots_per_decode * server: fix metrics type for n_busy_slots_per_decode metric --- tools/server/README.md | 27 +++++++++++++++++---------- tools/server/server-context.cpp | 8 ++++---- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/tools/server/README.md b/tools/server/README.md index 77eddb335..7f856faa8 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1043,16 +1043,23 @@ If query param `?fail_on_no_slot=1` is set, this endpoint will respond with stat This endpoint is only accessible if `--metrics` is set. -Available metrics: -- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed. -- `llamacpp:tokens_predicted_total`: Number of generation tokens processed. -- `llamacpp:prompt_tokens_seconds`: Average prompt throughput in tokens/s. -- `llamacpp:predicted_tokens_seconds`: Average generation throughput in tokens/s. -- `llamacpp:kv_cache_usage_ratio`: KV-cache usage. `1` means 100 percent usage. -- `llamacpp:kv_cache_tokens`: KV-cache tokens. -- `llamacpp:requests_processing`: Number of requests processing. -- `llamacpp:requests_deferred`: Number of requests deferred. -- `llamacpp:n_tokens_max`: High watermark of the context size observed. +In *router mode* the query param `?model={model_id}` has to be set. This endpoint will respond with status code 400 `model name is missing from the request` if not set. + +#### Available metrics + +| Metric | Type | Description | +| ------ | ---------------------- | ----------- | +| `llamacpp:prompt_tokens_total` | Counter | Number of prompt tokens processed. | +| `llamacpp:prompt_seconds_total` | Counter | Prompt process time in seconds. | +| `llamacpp:prompt_tokens_seconds` | Gauge | Average prompt throughput in tokens/s. | +| `llamacpp:tokens_predicted_total` | Counter | Number of generation tokens processed. | +| `llamacpp:tokens_predicted_seconds_total` | Counter | Predict process time in seconds. | +| `llamacpp:predicted_tokens_seconds` | Gauge | Average generation throughput in tokens/s. | +| `llamacpp:requests_processing` | Gauge | Number of requests processing. | +| `llamacpp:requests_deferred` | Gauge | Number of requests deferred. | +| `llamacpp:n_tokens_max` | Counter | High watermark of the context size observed. | +| `llamacpp:n_decode_total` | Counter | Total Number of llama_decode() calls. | +| `llamacpp:n_busy_slots_per_decode` | Gauge | Average number of busy slots per llama_decode() call. | ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0a51390af..ce743e665 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3622,10 +3622,6 @@ void server_routes::init_routes() { {"name", "n_tokens_max"}, {"help", "Largest observed n_tokens."}, {"value", res_task->n_tokens_max} - }, { - {"name", "n_busy_slots_per_decode"}, - {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, @@ -3643,6 +3639,10 @@ void server_routes::init_routes() { {"name", "requests_deferred"}, {"help", "Number of requests deferred."}, {"value", (uint64_t) res_task->n_tasks_deferred} + },{ + {"name", "n_busy_slots_per_decode"}, + {"help", "Average number of busy slots per llama_decode() call"}, + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}} }; From e93666076038c0bd26397feed6cfb8a6c6d04f74 Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 11 May 2026 18:42:08 +0200 Subject: [PATCH 08/17] Ggml/cuda snake fusion hardening (#22912) * cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++++++++++++++------ tests/test-backend-ops.cpp | 26 +++++++++++++++----------- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b92a20870..e25be3592 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph // closure check: the trailing add must read the same x as the leading mul const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; - const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + // Kernel iterates over total = T * C, so x and add must be 2D and + // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled. + const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) && + (add->ne[2] == 1 && add->ne[3] == 1) && + (a->ne[2] == 1 && a->ne[3] == 1); const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; - if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + // x must be in the supported whitelist and every operand / intermediate + // result must share x's type, since launch_snake casts a / inv_b as + // float and templates the kernel on a single T. Mixed precision chains + // fall back to the naive path. + const ggml_tensor * sin1 = cgraph->nodes[i + 1]; + const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) && + (a->type == x->type) && (inv_b->type == x->type) && + (mul0->type == x->type) && (sin1->type == x->type) && + (sqr->type == x->type) && (mul1->type == x->type) && + (add->type == x->type); + + if (types_ok && shape_ok && dim_ok && x_in_add == x) { ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); return 4; } @@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 922ad493a..333119486 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case { // and dispatches a single fused kernel. struct test_snake_fuse : public test_case { const ggml_type type; - const std::array ne; // [T, C] + const std::array ne; // [T, C, D2, D3] std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case { } test_snake_fuse(ggml_type type = GGML_TYPE_F32, - std::array ne = {256, 192}) + std::array ne = {256, 192, 1, 1}) : type(type), ne(ne) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); ggml_set_name(x, "x"); ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]); @@ -7558,11 +7558,15 @@ static std::vector> make_test_cases_eval() { // SNAKE activation fusion: x + sin(a*x)^2 * inv_b for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { - test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block - test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary - test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride - test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two - test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish + test_cases.emplace_back(new test_snake_fuse(type, { 5, 7, 1, 1})); // primes sub-block + test_cases.emplace_back(new test_snake_fuse(type, { 33, 32, 1, 1})); // boundary + test_cases.emplace_back(new test_snake_fuse(type, {1025, 13, 1, 1})); // large prime, grid-stride + test_cases.emplace_back(new test_snake_fuse(type, { 128, 16, 1, 1})); // power-of-two + test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish + // higher-rank shapes: matcher must reject fusion, fallback to naive chain + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 1})); // ne[2] > 1 + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 1, 2})); // ne[3] > 1 + test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 3})); // ne[2] > 1 and ne[3] > 1 } // glu ops @@ -9093,9 +9097,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); // SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192) - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192})); - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192})); - test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192, 1, 1})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192, 1, 1})); + test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416)); From 8e1f9d083453678808a26465a3563a19a0b49699 Mon Sep 17 00:00:00 2001 From: CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Date: Mon, 11 May 2026 19:48:29 +0200 Subject: [PATCH 09/17] CUDA: handle OW > 65535 in im2col (2D and 3D) (#22944) `im2col_cuda` and `im2col_3d_cuda` both dispatch with `block_nums.y = OW`. CUDA caps grid Y at 65535. Conv1d encoders on raw 16 kHz audio with T > 65535 (~ 4 s) trip the limit -- e.g. SEANet at 11 s lands at OW = 176000 -- and the launch returns `invalid configuration argument`. Clamp `block_nums.y` to `MIN(OW, MAX_GRIDDIM_Y)` and loop inside the kernel with stride `MAX_GRIDDIM_Y`. Same in-kernel stride pattern already used for the z axis (`MAX_GRIDDIM_Z`). Both 2D `im2col_kernel` and 3D `im2col_3d_kernel` need the same fix. Bit-identical for OW <= 65535 (single iteration of the new outer loop). Tested on T4 / Jetson Orin with a SEANet encoder running on 11 s / 16 kHz audio (im2col reaching OW ~ 176000); pre-fix launch returns `invalid configuration argument`, post-fix runs to completion. Existing test-backend-ops im2col cases unchanged. --- ggml/src/ggml-cuda/im2col.cu | 61 +++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 56dc05457..28c79ab46 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -1,5 +1,6 @@ #include "im2col.cuh" +#define MAX_GRIDDIM_Y 65535 #define MAX_GRIDDIM_Z 65535 template @@ -18,22 +19,23 @@ static __global__ void im2col_kernel( const int64_t ikh = rem / KW; const int64_t ikw = rem - ikh * KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OH; - const int64_t ioh = iz - in * OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t offset_dst = - ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } } } @@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst, const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; const int64_t N_OH = N * OH; const int64_t KH_KW = KW*KH; - dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z)); im2col_kernel<<>>(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, s0, s1, p0, p1, d0, d1); @@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel( const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; const int64_t ikw = i % KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OD_OH; - const int64_t iod = (iz - in*OD_OH) / OH; - const int64_t ioh = iz % OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t iid = iod * s2 + ikd * d2 - p2; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); - dst[offset_dst] = src[offset_src]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } } } } @@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst, const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z)); im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, From 1ec7ba0c14f33f17e980daeeda5f35b225d41994 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Mon, 11 May 2026 11:57:26 -0700 Subject: [PATCH 10/17] opencl: add q4_1 MoE for Adreno (#22856) * Q4_1 MoE CLC pass sanity check * remove unnecessary code * opencl: remove unnecessary asserts and reformat * opencl: fix supports_op for q4_1 moe * q4_1 moe is supported by Adreno with certain shapes --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 366 ++++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 90 +++++ .../kernels/gemm_moe_q4_1_f32_ns.cl | 254 ++++++++++++ .../kernels/gemv_moe_q4_1_f32_ns.cl | 119 ++++++ 5 files changed, 798 insertions(+), 33 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ffde6a4f0..7edb3eb4e 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -104,6 +104,8 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat gemm_moe_q4_0_f32_ns gemv_moe_q4_0_f32_ns + gemm_moe_q4_1_f32_ns + gemv_moe_q4_1_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4e6f6fb43..73a58f74a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -544,6 +544,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; + cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -602,6 +603,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; + cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -958,6 +960,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -2856,6 +2860,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -cl-mad-enable " " -cl-fast-relaxed-math"; + // gemv_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3749,11 +3785,14 @@ struct ggml_tensor_extra_cl_q4_1 { CL_CHECK(clReleaseMemObject(m)); m = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } // Currently, q_img and d_img are only initialized when SMALL_ALLOC is // enabled. They point to the images in ggml_backend_opencl_buffer_context. // So, there is no need to release them here. // TODO: initialize them for non SMALL_PATH path, or remove them. - q_img = nullptr; d_img = nullptr; m_img = nullptr; size_q = 0; @@ -4189,6 +4228,35 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm return GGML_STATUS_SUCCESS; } +// The optimized gemm and gemv kernels are used for large matrices without batch. +// tensor is the quantized weights matrix. +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && + tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + GGML_UNUSED(backend_ctx); + int ne01 = tensor->ne[1]; + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); +} + +inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + + bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); + + size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + + return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; @@ -4385,6 +4453,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } } + // q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support, + // the quantizations here currently do not - they are only supported by Adreno with certain shapes + if (op->src[0]->type == GGML_TYPE_Q4_1) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (op->src[1]->type == GGML_TYPE_F32) { + return use_adreno_moe_kernels(backend_ctx, op->src[0]) + && ggml_is_contiguous(op->src[0]) + && ggml_is_contiguous(op->src[1]); + } +#endif + return false; + } return false; case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -4555,6 +4635,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1) { + delete e; + } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + delete e; + } for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { delete e; } @@ -4868,35 +4954,6 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff return GGML_STATUS_SUCCESS; } -// The optimized gemm and gemv kernels are used for large matrices without batch. -// tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - int64_t threshold_ne0 = 512; - int64_t threshold_ne1 = 512; - if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && - backend_ctx->adreno_cl_compiler_version.type != DX) { - threshold_ne0 = 128; - threshold_ne1 = 128; - } - return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && - tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - GGML_UNUSED(backend_ctx); - int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); -} - -inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - - bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); - - size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; - - return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 -} - static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); @@ -5097,15 +5154,54 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // normal q4_1 repack +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -5862,6 +5958,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static ggml_cl_buffer buf_trans_q; static ggml_cl_buffer buf_trans_m; @@ -12862,6 +12988,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13131,6 +13258,179 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, break; } + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_1_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns; + + if (strstr(src0->name, "as") != NULL) { + moe_router_reoerder(backend, src2, ne20); + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32_flat; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c87450dc4..5bbf09710 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -370,6 +370,96 @@ kernel void kernel_restore_block_q4_1_noshuffle( } } +kernel void kernel_convert_block_q4_1_trans4_ns( + __global struct block_q4_1 * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_m, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q4_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q4_1_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_m, + __global struct block_q4_1 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_dm_offset]; + b->m = src_m[src_dm_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl new file mode 100644 index 000000000..e2574ae01 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_1(q4, a_f16, scale, m) \ + a_f16.s0 = (half)(q4.s0 & 0x000F) * scale + m; \ + a_f16.s1 = (half)((q4.s0 & 0x00F0) >> 4) * scale + m; \ + a_f16.s2 = (half)((q4.s0 & 0x0F00) >> 8) * scale + m; \ + a_f16.s3 = (half)((q4.s0 & 0xF000) >> 12) * scale + m; \ + a_f16.s4 = (half)(q4.s1 & 0x000F) * scale + m; \ + a_f16.s5 = (half)((q4.s1 & 0x00F0) >> 4) * scale + m; \ + a_f16.s6 = (half)((q4.s1 & 0x0F00) >> 8) * scale + m; \ + a_f16.s7 = (half)((q4.s1 & 0xF000) >> 12) * scale + m; \ + a_f16.s8 = (half)(q4.s2 & 0x000F) * scale + m; \ + a_f16.s9 = (half)((q4.s2 & 0x00F0) >> 4) * scale + m; \ + a_f16.sa = (half)((q4.s2 & 0x0F00) >> 8) * scale + m; \ + a_f16.sb = (half)((q4.s2 & 0xF000) >> 12) * scale + m; \ + a_f16.sc = (half)(q4.s3 & 0x000F) * scale + m; \ + a_f16.sd = (half)((q4.s3 & 0x00F0) >> 4) * scale + m; \ + a_f16.se = (half)((q4.s3 & 0x0F00) >> 8) * scale + m; \ + a_f16.sf = (half)((q4.s3 & 0xF000) >> 12) * scale + m; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_1_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q4_1 block + uint sm_offset = s_sub_offset + get_global_id(0); + half s = src0_d[sm_offset]; + half m = src0_m[sm_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl new file mode 100644 index 000000000..3739a2157 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -0,0 +1,119 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_1_to_fp32_packed8(ushort2 q4x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) * s + m); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) * s + m); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) * s + m); + fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) * s + m); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) * s + m); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) * s + m); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) * s + m); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_1_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s0), regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s1), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s2), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s3), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From da44953329daec5f16b7b19d7ebfdc6552415362 Mon Sep 17 00:00:00 2001 From: guyfischman <138163913+guyfischman@users.noreply.github.com> Date: Tue, 12 May 2026 07:15:02 +0200 Subject: [PATCH 11/17] metal : promote mul_mv/mul_mm batch divisors to function constants (#22711) * metal : promote mul_mv/mul_mm batch divisors to function constants * metal : take op directly in get_pipeline_mul_mv_ext --- ggml/src/ggml-metal/ggml-metal-device.cpp | 46 +++++- ggml/src/ggml-metal/ggml-metal-device.h | 2 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 165 +++++++++++----------- 4 files changed, 127 insertions(+), 88 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d211bf79f..f0147af84 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m return res; } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + const int ne12 = op->src[1]->ne[2]; + const int r2 = ne12 / op->src[0]->ne[2]; + const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3]; + + GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX); + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); - snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); + GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX); + const int16_t ne12 = (int16_t) op->src[1]->ne[2]; + const int16_t ne13 = (int16_t) op->src[1]->ne[3]; + const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]); + const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]); + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d", + base, bc_inp, bc_out, ne12, ne13, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2); + ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; + GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX); + const int16_t r2 = (int16_t) (ne12 / ne02); + const int16_t r3 = (int16_t) (ne13 / ne03); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); - snprintf(name, 256, "%s_nsg=%d", base, nsg); + snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4718ca083..1f212a92f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c87..a114391c2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaede..3882b9558 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write( constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; +constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]]; +constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]]; +constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]]; template void mul_vec_q_n_f32_impl( @@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); @@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[NR0]; FOR_UNROLL (int row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; @@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl( device const block_q1_0 * ax[nr0]; for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); } @@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); @@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const T0 * x = (device const T0 *) (src0 + offset0); @@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl( // pointers to src0 rows device const T0 * ax [NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const T0 *) ((device char *) src0 + offset0); } @@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl( device const T0 * ax [NR0]; device const T04 * ax4[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax [row] = (device const T0 *) ((device char *) src0 + offset0); ax4[row] = (device const T04 *) ((device char *) src0 + offset0); @@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl( return; } - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -7462,10 +7465,10 @@ void kernel_mul_mv_q2_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); @@ -7567,10 +7570,10 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); @@ -7741,10 +7744,10 @@ void kernel_mul_mv_q4_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); @@ -7853,10 +7856,10 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); @@ -7989,10 +7992,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); @@ -8094,10 +8097,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); @@ -8202,10 +8205,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); @@ -8321,10 +8324,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); @@ -8433,10 +8436,10 @@ void kernel_mul_mv_iq3_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); @@ -8545,10 +8548,10 @@ void kernel_mul_mv_iq2_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); @@ -8658,10 +8661,10 @@ void kernel_mul_mv_iq1_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); @@ -8757,10 +8760,10 @@ void kernel_mul_mv_iq1_m_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); @@ -8866,10 +8869,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); @@ -8975,10 +8978,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); @@ -9086,10 +9089,10 @@ void kernel_mul_mv_mxfp4_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); @@ -9304,6 +9307,10 @@ kernel void kernel_diag_f32( constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]]; +constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]]; +constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]]; +constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]]; // each block_q contains 16*nl weights #ifdef GGML_METAL_HAS_TENSOR @@ -9330,11 +9337,11 @@ kernel void kernel_mul_mm( // Batch dimension handling const int im = tgpig.z; - const int i12 = im % args.ne12; - const int i13 = im / args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; // Batch offsets for srcA and srcB - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; // Tile dimensions constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; @@ -9473,10 +9480,10 @@ kernel void kernel_mul_mm( short il = il0; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; const short offset1 = il0/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; From 78fbbc2c0788efc8857a2c0dc9802ec689fa12c1 Mon Sep 17 00:00:00 2001 From: Jesus Talavera <145992175+jesus-talavera-ibm@users.noreply.github.com> Date: Tue, 12 May 2026 07:17:04 +0200 Subject: [PATCH 12/17] convert : add split() to LoraTorchTensor in LoRA converter (#22832) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * convert : add split() method to LoraTorchTensor * Fix python type-check * Fix flake8 Lint * fix: handle positional dim arg in torch.split dispatch * Fix type-check again * Fix type-checks * Remove unit test per reviewers feedback * work around ty deficiency --------- Co-authored-by: Sigbjørn Skjæret --- convert_lora_to_gguf.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index d58334205..ad4751bb9 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -188,6 +188,24 @@ class LoraTorchTensor: def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: return self.transpose(axis0, axis1) + def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]: + shape = self.shape + ndim = len(shape) + if dim < 0: + dim += ndim + if dim == ndim - 1: + A_chunks = self._lora_A.split(split_size, dim=-1) + return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks) + elif dim == ndim - 2: + B_chunks = self._lora_B.split(split_size, dim=-2) + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + else: + B_chunks = self._lora_B.split(split_size, dim=dim) + if self._lora_A.shape[dim] == 1: + return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks) + A_chunks = self._lora_A.split(split_size, dim=dim) + return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks)) + def to(self, *args, **kwargs): return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) @@ -230,6 +248,11 @@ class LoraTorchTensor: ) else: raise NotImplementedError + elif func is torch.split: + assert len(args) and len(args) >= 2 + tensor, split_size = args[0], args[1] + dim = args[2] if len(args) > 2 else kwargs.get("dim", 0) + return tensor.split(split_size, dim=dim) else: raise NotImplementedError From 41782591303ada159467fc75b232b2c9110f45aa Mon Sep 17 00:00:00 2001 From: AesSedai <7980540+AesSedai@users.noreply.github.com> Date: Tue, 12 May 2026 02:11:14 -0700 Subject: [PATCH 13/17] mtmd: add MiMo v2.5 vision (#22883) * mimo-v2.5: vision support * mimo-v2.5: use fused qkv for vision * mimi-v2.5: fix f16 vision overflow * mimo-v2.5: comment cleanups * mimo-v2.5: Flash doesn't have mmproj more cleanup remember to use filter_tensors * mimo-v2.5: fix trailing whitespace --- convert_hf_to_gguf.py | 67 +++++++++++ gguf-py/gguf/constants.py | 48 ++++---- gguf-py/gguf/gguf_writer.py | 6 + gguf-py/gguf/tensor_mapping.py | 4 + tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip-graph.h | 3 +- tools/mtmd/clip-impl.h | 5 + tools/mtmd/clip-model.h | 4 + tools/mtmd/clip.cpp | 126 +++++++++++++++++++- tools/mtmd/models/mimovl.cpp | 209 +++++++++++++++++++++++++++++++++ tools/mtmd/models/models.h | 9 ++ tools/mtmd/mtmd.cpp | 1 + 12 files changed, 460 insertions(+), 23 deletions(-) create mode 100644 tools/mtmd/models/mimovl.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bf76fa406..d79372cea 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9760,6 +9760,73 @@ class MimoV2Model(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("MiMoV2ForCausalLM") +class MiMoV2VisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + hp = self.hparams_vision + + hp["image_size"] = hp.get("image_size", 560) + hp["num_attention_heads"] = hp.get("num_heads", 32) + hp["num_hidden_layers"] = hp.get("depth", 28) + + self.n_q_heads = int(hp["num_heads"]) + self.num_kv_heads = int(hp.get("num_key_value_heads", 8)) + self.head_dim = int(hp.get("qk_channels", 64)) + self.spatial_merge_size = int(hp["spatial_merge_size"]) + # MiMoV2 vision RMSNorm: HF uses getattr(config, "rms_norm_eps", 1e-6) and the + # field is absent from MiMo-V2.5's vision_config + self.rms_norm_eps = float(hp.get("rms_norm_eps", 1e-6)) + + # fullatt_block_indexes are also reflected in vit_window_attn_types as -1 + self.fullatt_block_indexes = list(hp.get("fullatt_block_indexes") or []) + self.vit_window_attn_types = list(hp.get("vit_window_attn_types") or []) + self.visual_token_window_size = int(hp.get("visual_token_window_size", -1)) + self.use_sink = bool(hp.get("use_sink", False)) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MIMOVL) + self.gguf_writer.add_vision_use_silu(True) + self.gguf_writer.add_vision_head_count_kv(self.num_kv_heads) + self.gguf_writer.add_vision_spatial_merge_size(self.spatial_merge_size) + self.gguf_writer.add_uint32(gguf.Keys.ClipVision.WINDOW_SIZE, self.visual_token_window_size) + self.gguf_writer.add_vision_wa_pattern_mode(self.vit_window_attn_types) + self.gguf_writer.add_vision_attention_layernorm_eps(self.rms_norm_eps) + self.gguf_writer.add_vision_min_pixels(int(self.preprocessor_config["min_pixels"])) + self.gguf_writer.add_vision_max_pixels(int(self.preprocessor_config["max_pixels"])) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + # Sinks must be F32: any sink-style softmax/mask add in ggml requires + # F32, and we fold sinks into a host-built F32 mask at encode time. + if new_name.endswith(".attn_sinks"): + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + @classmethod + def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: + name, _ = item + if not name.startswith("visual."): + return None + return super().filter_tensors(item) + + def modify_tensors(self, data_torch, name, bid): + # Conv3D patch embed: split along the temporal axis (kt=2) into two Conv2D + # weights that the existing qwen2vl-style two-Conv2D path consumes. + if name == "visual.patch_embed.proj.weight": + _, _, kt, _, _ = data_torch.shape + if kt != 2: + raise ValueError(f"unexpected temporal_patch_size: {kt}") + embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + yield (embd_name + ".weight", data_torch[:, :, 0, ...]) + yield (embd_name + ".weight.1", data_torch[:, :, 1, ...]) + return + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Step3p5ForCausalLM") class Step35Model(TextModel): model_arch = gguf.MODEL_ARCH.STEP35 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 617cbc49d..4055ec287 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -299,30 +299,32 @@ class Keys: HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" class ClipVision: - PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models - IMAGE_SIZE = "clip.vision.image_size" - IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" - IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" - PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles" - PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles" - PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" - PATCH_SIZE = "clip.vision.patch_size" - EMBEDDING_LENGTH = "clip.vision.embedding_length" - FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" - PROJECTION_DIM = "clip.vision.projection_dim" - BLOCK_COUNT = "clip.vision.block_count" - IMAGE_MEAN = "clip.vision.image_mean" - IMAGE_STD = "clip.vision.image_std" - SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" - USE_GELU = "clip.use_gelu" - USE_SILU = "clip.use_silu" - N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl - WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl - IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" - WINDOW_SIZE = "clip.vision.window_size" + PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models + IMAGE_SIZE = "clip.vision.image_size" + IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels" + IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels" + PREPROC_MIN_TILES = "clip.vision.preproc_min_tiles" + PREPROC_MAX_TILES = "clip.vision.preproc_max_tiles" + PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" + PATCH_SIZE = "clip.vision.patch_size" + EMBEDDING_LENGTH = "clip.vision.embedding_length" + FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" + PROJECTION_DIM = "clip.vision.projection_dim" + BLOCK_COUNT = "clip.vision.block_count" + IMAGE_MEAN = "clip.vision.image_mean" + IMAGE_STD = "clip.vision.image_std" + SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" + USE_GELU = "clip.use_gelu" + USE_SILU = "clip.use_silu" + N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl + WA_PATTERN_MODE = "clip.vision.wa_pattern_mode" # used by mimovl, per-layer -1/0/1 + IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + WINDOW_SIZE = "clip.vision.window_size" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" + HEAD_COUNT_KV = "clip.vision.attention.head_count_kv" # used by mimovl (GQA) LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon" class Projector: @@ -733,6 +735,7 @@ class MODEL_TENSOR(IntEnum): V_ENC_ATTN_V = auto() V_ENC_ATTN_O = auto() V_ENC_ATTN_O_NORM = auto() + V_ENC_ATTN_SINKS = auto() # mimovl V_ENC_POST_ATTN_NORM = auto() V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() @@ -1246,6 +1249,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", + MODEL_TENSOR.V_ENC_ATTN_SINKS: "v.blk.{bid}.attn_sinks", MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", @@ -1426,6 +1430,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_ENC_ATTN_V, MODEL_TENSOR.V_ENC_ATTN_O, MODEL_TENSOR.V_ENC_ATTN_O_NORM, + MODEL_TENSOR.V_ENC_ATTN_SINKS, MODEL_TENSOR.V_ENC_POST_ATTN_NORM, MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, @@ -4258,6 +4263,7 @@ class VisionProjectorType: HUNYUANVL = "hunyuanvl" MINICPMV4_6 = "minicpmv4_6" GRANITE_SPEECH = "granite_speech" # audio + MIMOVL = "mimovl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 35fb01470..a10138271 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1151,6 +1151,9 @@ class GGUFWriter: def add_vision_head_count(self, value: int) -> None: self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value) + def add_vision_head_count_kv(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT_KV, value) + def add_vision_attention_layernorm_eps(self, value: float) -> None: self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value) @@ -1222,6 +1225,9 @@ class GGUFWriter: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_wa_pattern_mode(self, modes: Sequence[int]) -> None: + self.add_array(Keys.ClipVision.WA_PATTERN_MODE, modes) + def add_vision_window_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index f27f0e4c9..f40cb8282 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1569,6 +1569,10 @@ class TensorNameMap: "vision_model.transformer.resblocks.{bid}.attn.out_proj", # Step3-VL ), + MODEL_TENSOR.V_ENC_ATTN_SINKS: ( + "visual.blocks.{bid}.attn.sinks", # mimovl + ), + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", "model.vision_tower.encoder.layers.{bid}.layer_norm2", # minicpmv4_6 diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 21d17dbaa..a76adc9b8 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(mtmd models/pixtral.cpp models/qwen2vl.cpp models/qwen3vl.cpp + models/mimovl.cpp models/qwen3a.cpp models/step3vl.cpp models/siglip.cpp diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index d3e7b1ed0..39f069501 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -98,7 +98,8 @@ struct clip_graph { ggml_tensor * v_cur, ggml_tensor * kq_mask, float kq_scale, - int il) const; + int il, + ggml_tensor * sinks = nullptr) const; // implementation of the 2D RoPE without adding a new op in ggml // this is not efficient (use double the memory), but works on all backends diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 817bf26b2..8e09f26e9 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -31,6 +31,7 @@ #define KEY_N_BLOCK "clip.%s.block_count" #define KEY_PROJ_DIM "clip.%s.projection_dim" #define KEY_N_HEAD "clip.%s.attention.head_count" +#define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv" #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" // vision-specific @@ -53,6 +54,7 @@ #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" #define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes" +#define KEY_WA_PATTERN_MODE "clip.vision.wa_pattern_mode" #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" @@ -86,6 +88,7 @@ #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_ATTN_SINKS "%s.blk.%d.attn_sinks" #define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s" #define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" @@ -344,6 +347,7 @@ enum projector_type { PROJECTOR_TYPE_HUNYUANVL, PROJECTOR_TYPE_MINICPMV4_6, PROJECTOR_TYPE_GRANITE_SPEECH, + PROJECTOR_TYPE_MIMOVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -393,6 +397,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_HUNYUANVL, "hunyuanvl"}, { PROJECTOR_TYPE_MINICPMV4_6, "minicpmv4_6"}, { PROJECTOR_TYPE_GRANITE_SPEECH, "granite_speech"}, + { PROJECTOR_TYPE_MIMOVL, "mimovl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 48f8b1a19..ce15dbcd1 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -42,6 +42,7 @@ struct clip_hparams { int32_t n_ff = 0; int32_t projection_dim = 0; int32_t n_head = 0; + int32_t n_head_kv = 0; int32_t n_layer = 0; // idefics3 int32_t n_merge = 0; // number of patch merges **per-side** @@ -83,6 +84,7 @@ struct clip_hparams { int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) + std::vector wa_pattern_mode; // mimovl: per-layer window-attention mode // deepseek-ocr (sam) int32_t sam_n_layer = 0; @@ -166,6 +168,8 @@ struct clip_layer { ggml_tensor * o_w = nullptr; ggml_tensor * o_b = nullptr; + ggml_tensor * attn_sinks = nullptr; + ggml_tensor * k_norm = nullptr; ggml_tensor * q_norm = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 513b94f2a..f0c63d375 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -642,7 +642,8 @@ ggml_tensor * clip_graph::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_mask, float kq_scale, - int il) const { + int il, + ggml_tensor * sinks) const { // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -665,6 +666,9 @@ ggml_tensor * clip_graph::build_attn( cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + if (sinks != nullptr) { + ggml_flash_attn_ext_add_sinks(cur, sinks); + } cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); @@ -677,6 +681,9 @@ ggml_tensor * clip_graph::build_attn( // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + if (sinks != nullptr) { + ggml_soft_max_add_sinks(kq, sinks); + } ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); @@ -866,6 +873,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_MIMOVL: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_STEP3VL: { builder = std::make_unique(ctx, img); @@ -1389,6 +1400,22 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_MIMOVL: + { + hparams.n_merge = 2; // spatial_merge_size + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(string_format(KEY_N_HEAD_KV, "vision"), hparams.n_head_kv); + // 1D banded sliding-window radius (visual_token_window_size); required + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size); + std::vector pat; + get_arr_int(KEY_WA_PATTERN_MODE, pat, true); + GGML_ASSERT((int) pat.size() == hparams.n_layer && "mimovl wa_pattern_mode length must equal n_layer"); + hparams.wa_pattern_mode.assign(pat.begin(), pat.end()); + get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels); + get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels); + hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_STEP3VL: { hparams.n_merge = 4; // two stride-2 downsamplers after patching @@ -1729,6 +1756,8 @@ struct clip_model_loader { layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); + // mimovl per-head attention sink bias + layer.attn_sinks = get_tensor(string_format(TN_ATTN_SINKS, prefix, il), false); // qwen3vl deepstack layer layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false); @@ -1913,6 +1942,13 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_MIMOVL: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + } break; case PROJECTOR_TYPE_STEP3VL: { model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); @@ -3011,6 +3047,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANOCR: @@ -3032,6 +3069,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_HUNYUANVL: @@ -3110,6 +3148,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: case PROJECTOR_TYPE_GLM4V: case PROJECTOR_TYPE_YOUTUVL: { @@ -3681,6 +3720,89 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; + case PROJECTOR_TYPE_MIMOVL: + { + const int merge = hparams.n_merge; // 2 + const int merge_unit = merge * merge; // 4 + const int patch = hparams.patch_size; // 16 + const int H = image_size_height / patch; + const int W = image_size_width / patch; + const int n_pos_full = H * W; + const int llm_h = H / merge; + const int llm_w = W / merge; + const int n_units = llm_h * llm_w; // n_pos / merge_unit + + // Row-major merge-tile-ordered (h, w) positions + std::vector pos_h_row(n_pos_full); + std::vector pos_w_row(n_pos_full); + { + int idx = 0; + for (int ty = 0; ty < llm_h; ty++) { + for (int tx = 0; tx < llm_w; tx++) { + for (int dy = 0; dy < merge; dy++) { + for (int dx = 0; dx < merge; dx++) { + pos_h_row[idx] = ty * merge + dy; + pos_w_row[idx] = tx * merge + dx; + idx++; + } + } + } + } + } + + // Col-major merge-unit permutation + std::vector idx_col(n_units); + for (int r = 0; r < llm_h; r++) { + for (int c = 0; c < llm_w; c++) { + int u_row = r * llm_w + c; + int u_col = c * llm_h + r; + idx_col[u_col] = (float) u_row; + } + } + + // Col-mode positions: permute pos_*_row by idx_col + std::vector pos_h_col(n_pos_full); + std::vector pos_w_col(n_pos_full); + for (int u = 0; u < n_units; u++) { + int src = (int) idx_col[u]; + for (int k = 0; k < merge_unit; k++) { + pos_h_col[u * merge_unit + k] = pos_h_row[src * merge_unit + k]; + pos_w_col[u * merge_unit + k] = pos_w_row[src * merge_unit + k]; + } + } + + // Pack into ggml_rope_multi VISION-mode layout. The non-CPU kernels + // only read slots 0 and 1, so pack h in slot 0, w in slot 1: + // positions[0..n_pos) = h + // positions[n_pos..2*n_pos) = w + // positions[2*n_pos..3*n_pos) = 0 + // positions[3*n_pos..4*n_pos) = 0 + std::vector positions_row(static_cast(n_pos_full) * 4, 0); + std::vector positions_col(static_cast(n_pos_full) * 4, 0); + for (int i = 0; i < n_pos_full; i++) { + positions_row[0 * n_pos_full + i] = pos_h_row[i]; + positions_row[1 * n_pos_full + i] = pos_w_row[i]; + positions_col[0 * n_pos_full + i] = pos_h_col[i]; + positions_col[1 * n_pos_full + i] = pos_w_col[i]; + } + + // Banded 1D sliding-window mask + const int window = hparams.attn_window_size; + GGML_ASSERT(window > 0); + std::vector mask(static_cast(n_pos_full) * n_pos_full, std::numeric_limits::lowest()); + for (int q = 0; q < n_pos_full; q++) { + int lo = std::max(0, q - window); + int hi = std::min(n_pos_full - 1, q + window); + for (int k = lo; k <= hi; k++) { + mask[static_cast(q) * n_pos_full + k] = 0.0f; + } + } + + set_input_i32("mimovl_positions_row", positions_row); + set_input_i32("mimovl_positions_col", positions_col); + set_input_f32("mimovl_idx_col", idx_col); + set_input_f32("mimovl_window_mask", mask); + } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_KIMIK25: @@ -4081,6 +4203,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers); + case PROJECTOR_TYPE_MIMOVL: + return ctx->model.mm_1_w->ne[1]; case PROJECTOR_TYPE_STEP3VL: return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_GEMMA3: diff --git a/tools/mtmd/models/mimovl.cpp b/tools/mtmd/models/mimovl.cpp new file mode 100644 index 000000000..19db88f13 --- /dev/null +++ b/tools/mtmd/models/mimovl.cpp @@ -0,0 +1,209 @@ +#include "models.h" + +ggml_tensor * clip_graph_mimovl::build_mm(ggml_tensor * w, ggml_tensor * x) const { + ggml_tensor * cur = ggml_mul_mat(ctx0, w, x); + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + return cur; +} + +// MiMoVL vision tower for MiMo-V2.5 (non-Pro). Qwen2.5-VL-shaped ViT, except: +// 1. GQA in attention (32 Q / 8 KV heads, head_dim 64). +// 2. Per-head attention sinks on every windowed layer. The sinks adjust +// the softmax denominator (equivalently, a virtual extra K column with V=0), +// so they decay attention weight without contributing to the output. +// 3. Per-layer window-attention mode in hparams.wa_pattern_mode: +// -1 -> full, 0 -> row-window+sinks, 1 -> col-window+sinks. +// Col mode transposes the merge-unit grid on entry and restores +// it on exit. Both patch and rotary orderings are pre-computed +// host-side. +// 4. 1D banded sliding window (|q-k| > window_size -> -inf) as a +// single 2D mask broadcast across heads. +// 5. Per-block MLP biases. +ggml_cgraph * clip_graph_mimovl::build() { + GGML_ASSERT(model.patch_embeddings_0 != nullptr); + GGML_ASSERT(model.patch_embeddings_1 != nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + GGML_ASSERT(hparams.n_head_kv > 0); + GGML_ASSERT(n_head % hparams.n_head_kv == 0); + GGML_ASSERT((int) hparams.wa_pattern_mode.size() == n_layer); + + const int batch_size = 1; + const int n_pos = n_patches; + const int n_head_kv = hparams.n_head_kv; + const int merge = hparams.n_merge > 0 ? hparams.n_merge : 2; + const int merge_unit = merge * merge; + const int n_units = n_pos / merge_unit; + GGML_ASSERT(n_units * merge_unit == n_pos); + + // MiMoVL has head_dim=64 with n_embd=1280, so n_embd is NOT n_head*head_dim + // (the base class's d_head = n_embd/n_head = 40 is wrong here). Derive + // head_dim from the fused QKV projection: rows = (n_head + 2*n_head_kv)*head_dim. + GGML_ASSERT(model.layers[0].qkv_w != nullptr); + const int qkv_rows = model.layers[0].qkv_w->ne[1]; + const int head_dim = qkv_rows / (n_head + 2 * n_head_kv); + GGML_ASSERT(head_dim * (n_head + 2 * n_head_kv) == qkv_rows); + const float attn_scale = 1.0f / std::sqrt((float) head_dim); + const int rope_n_dims = head_dim / 2; + int mrope_sections[4] = {rope_n_dims/2, rope_n_dims/2, 0, 0}; + + // Patch embed: Conv3D(kt=2) split into two Conv2D, then interleave-merge + // along the height axis to match the merge-tile token order. + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, + patch_size, patch_size, 0, 0, 1, 1); + { + ggml_tensor * inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, + patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + GGML_ASSERT(img.nx % (patch_size * 2) == 0); + GGML_ASSERT(img.ny % (patch_size * 2) == 0); + + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w,h,c,b] -> [c,w,h,b] + inp = ggml_cont_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + inp = ggml_reshape_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d(ctx0, inp, n_embd, n_patches_x * n_patches_y, batch_size); + } + cb(inp, "patch_embed", -1); + + ggml_tensor * positions_row = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos * 4); + ggml_set_name(positions_row, "mimovl_positions_row"); + ggml_set_input(positions_row); + + ggml_tensor * positions_col = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos * 4); + ggml_set_name(positions_col, "mimovl_positions_col"); + ggml_set_input(positions_col); + + // idx_col is the col-major merge-unit permutation. Take it as F32 so we can + // derive the inverse permutation in-graph via ggml_argsort; + // ggml_get_rows requires its index tensor to be I32, so cast back as well. + ggml_tensor * idx_col_f = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_units); + ggml_set_name(idx_col_f, "mimovl_idx_col"); + ggml_set_input(idx_col_f); + ggml_tensor * idx_col = ggml_cast(ctx0, idx_col_f, GGML_TYPE_I32); + ggml_tensor * idx_col_inv = ggml_argsort(ctx0, idx_col_f, GGML_SORT_ORDER_ASC); + + ggml_tensor * window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "mimovl_window_mask"); + ggml_set_input(window_mask); + + ggml_tensor * window_mask_attn = (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) + ? ggml_cast(ctx0, window_mask, GGML_TYPE_F16) + : window_mask; + + // Reorder helper: permute patches at merge-unit granularity. The patch + // sequence is laid out as n_units groups of merge_unit (=4) consecutive + // patches; the row<->col transpose only permutes whole groups. We keep + // the per-group (h,w) ordering intact by reshaping to + // [n_embd*merge_unit, n_units] before ggml_get_rows. + auto reorder = [&](ggml_tensor * x, ggml_tensor * idx) { + ggml_tensor * y = ggml_reshape_2d(ctx0, x, n_embd * merge_unit, n_units); + y = ggml_get_rows(ctx0, y, idx); + return ggml_reshape_3d(ctx0, y, n_embd, n_pos, batch_size); + }; + + ggml_tensor * inpL = inp; + int prev_mode = -1; + + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const int mode = hparams.wa_pattern_mode[il]; + const bool is_full = (mode == -1); + const bool is_col = (mode == 1); + + // Reorder transitions on entry/exit of a col-mode run. + if (is_col && prev_mode != 1) { + inpL = reorder(inpL, idx_col); + cb(inpL, "reorder_to_col", il); + } else if (!is_col && prev_mode == 1) { + inpL = reorder(inpL, idx_col_inv); + cb(inpL, "reorder_to_row", il); + } + + ggml_tensor * cur = inpL; + + // Pre-attention RMSNorm. + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_RMS, eps, il); + cb(cur, "ln1", il); + + // Fused QKV with GQA. + ggml_tensor * qkv = build_mm(layer.qkv_w, cur); + qkv = ggml_add(ctx0, qkv, layer.qkv_b); + + const size_t row = ggml_row_size(qkv->type, head_dim); + const size_t off_k = ggml_row_size(qkv->type, n_head * head_dim); + const size_t off_v = ggml_row_size(qkv->type, (n_head + n_head_kv) * head_dim); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim, n_head, n_pos, row, qkv->nb[1], 0); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_pos, row, qkv->nb[1], off_k); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim, n_head_kv, n_pos, row, qkv->nb[1], off_v); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // 2D RoPE + ggml_tensor * pos = is_col ? positions_col : positions_row; + Qcur = ggml_rope_multi(ctx0, Qcur, pos, nullptr, rope_n_dims, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); + Kcur = ggml_rope_multi(ctx0, Kcur, pos, nullptr, rope_n_dims, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f); + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + // Full layers: plain attention. Windowed layers: banded mask and per-head sinks. + ggml_tensor * mask = is_full ? nullptr : window_mask_attn; + ggml_tensor * sinks = is_full ? nullptr : layer.attn_sinks; + if (!is_full) { + GGML_ASSERT(layer.attn_sinks != nullptr); + } + ggml_tensor * attn_out = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, mask, attn_scale, il, sinks); + cb(attn_out, "attn_out", il); + + // Residual 1. + cur = ggml_add(ctx0, attn_out, inpL); + inpL = cur; + cb(cur, "ffn_inp", il); + + // Pre-FFN RMSNorm. + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_RMS, eps, il); + cb(cur, "ffn_inp_normed", il); + + // SwiGLU MLP with biases + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + cb(cur, "ffn_out", il); + + // Residual 2. + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + inpL = cur; + prev_mode = mode; + } + + // If the last block was col-mode, undo the transpose so the merger sees patches in row order. + if (prev_mode == 1) { + inpL = reorder(inpL, idx_col_inv); + cb(inpL, "reorder_to_row_final", -1); + } + + // Merger: post-LayerNorm + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, 1e-6f, n_layer); + cb(inpL, "post_ln", -1); + + // Spatial merge: pack each merge_unit (=4) of patches into a single + // (n_embd*merge_unit)-wide row, then run the 2-layer MLP. + ggml_tensor * embeddings = ggml_reshape_3d(ctx0, inpL, n_embd * merge_unit, n_units, batch_size); + embeddings = build_ffn(embeddings, + model.mm_0_w, nullptr, + nullptr, nullptr, + model.mm_1_w, nullptr, + FFN_GELU, -1); + cb(embeddings, "vit_out", -1); + + ggml_build_forward_expand(gf, embeddings); + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index dbba233b1..955daa6d6 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -33,6 +33,15 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_mimovl : clip_graph { + clip_graph_mimovl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + // Force F32 mat-mul accumulation to avoid F16 overflow in the FFN down-proj + // when the mmproj is stored in F16 (the source weights are BF16; downcasting + // to F16 reduces dynamic range below the SwiGLU output magnitude on the last few layers). + ggml_tensor * build_mm(ggml_tensor * w, ggml_tensor * x) const override; +}; + struct clip_graph_step3vl : clip_graph { clip_graph_step3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 87da6876f..22092f6a6 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -325,6 +325,7 @@ struct mtmd_context { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_MIMOVL: { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; From fa62042af9cc858832b613ec51cce98f9884ce01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 12 May 2026 11:34:10 +0200 Subject: [PATCH 14/17] ci : bump ty to 0.0.35 (#22961) --- .github/workflows/python-type-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml index 2d3fa163d..cbeeb39d0 100644 --- a/.github/workflows/python-type-check.yml +++ b/.github/workflows/python-type-check.yml @@ -31,7 +31,7 @@ jobs: uses: actions/setup-python@v6 with: python-version: "3.11" - pip-install: -r requirements/requirements-all.txt ty==0.0.33 + pip-install: -r requirements/requirements-all.txt ty==0.0.35 # - name: Type-check with Pyright # uses: jakebailey/pyright-action@v2 # with: From 706fbd8ab62797bd3702a3b20312efb801db49a6 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 12 May 2026 04:41:58 -0500 Subject: [PATCH 15/17] vulkan: Check shared memory size for mmq shaders (#22693) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 168 ++++++++++++++++++++++++--- 1 file changed, 149 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7e450a559..90ea7cc1a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -681,6 +681,15 @@ struct vk_device_struct { bool mul_mat_id_m[GGML_TYPE_COUNT]; bool mul_mat_id_s[GGML_TYPE_COUNT]; + // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses + // a different shared-memory layout than the float matmul shaders. + bool mul_mat_l_int[GGML_TYPE_COUNT]; + bool mul_mat_m_int[GGML_TYPE_COUNT]; + bool mul_mat_s_int[GGML_TYPE_COUNT]; + bool mul_mat_id_l_int[GGML_TYPE_COUNT]; + bool mul_mat_id_m_int[GGML_TYPE_COUNT]; + bool mul_mat_id_s_int[GGML_TYPE_COUNT]; + vk::DescriptorSetLayout dsl; vk_matmul_pipeline pipeline_matmul_f32 {}; @@ -3207,6 +3216,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec return supported; } +// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses +// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather +// than the float load buffers checked by ggml_vk_matmul_shmem_support. +// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline. +static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float. + const uint32_t fp_size = device->fp16 ? 2u : 4u; + const uint32_t fp_align = fp_size; + const uint32_t fp2_size = 2u * fp_size; + const uint32_t fp2_align = device->fp16 ? 4u : 8u; + + struct member { uint32_t size, align; }; + auto std430_size = [](std::initializer_list members) { + uint32_t off = 0, struct_align = 1; + for (const auto &m : members) { + off = (off + m.align - 1) & ~(m.align - 1); + off += m.size; + struct_align = std::max(struct_align, m.align); + } + return (off + struct_align - 1) & ~(struct_align - 1); + }; + + uint32_t block_a_size = 0; + switch (src0_type) { + case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm + case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2) + case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm + case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2) + case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm + case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d + case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2) + case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2) + case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2) + case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2) + case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2) + default: + return false; + } + + // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; } + const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); + + const uint32_t BM = warptile[1]; + const uint32_t BN = warptile[2]; + // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise. + const uint32_t BK_STEP = mul_mat_id ? 1u : 4u; + + const uint32_t buf_a_size = BM * BK_STEP * block_a_size; + const uint32_t buf_b_size = BN * BK_STEP * block_b_size; + const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u; + + const uint32_t warps = warptile[0] / warptile[10]; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u; + + const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported); + + return supported; +} + struct GpuPipelineConfig { // GPU architecture identifier. // Example: vk_device_architecture::AMD_GCN @@ -3453,6 +3526,40 @@ static void ggml_vk_load_shaders(vk_device& device) { } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } + + // The q8_1 mmq path has its own (larger) shmem layout, check it separately. + // K-quants use the _int_k warptiles, others use _int. + const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K || + t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K || + t == GGML_TYPE_Q6_K); + const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int; + const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int; + const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int; + const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int; + const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int; + const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int; + + if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) { + device->mul_mat_s_int[i] = false; + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) { + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) { + device->mul_mat_l_int[i] = false; + } + + if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) { + device->mul_mat_id_s_int[i] = false; + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) { + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) { + device->mul_mat_id_l_int[i] = false; + } } } @@ -5613,6 +5720,13 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_s[i] = true; break; } + + device->mul_mat_l_int[i] = true; + device->mul_mat_m_int[i] = true; + device->mul_mat_s_int[i] = true; + device->mul_mat_id_l_int[i] = true; + device->mul_mat_id_m_int[i] = true; + device->mul_mat_id_s_int[i] = true; } @@ -7220,6 +7334,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type]; + if (ctx->device->coopmat2) { const uint32_t shader_core_count = ctx->device->shader_core_count; const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); @@ -7236,26 +7357,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, // split_k==3 with large tiles likely better than medium tiles with no split_k. (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; - - GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -7312,35 +7431,42 @@ static void ggml_vk_matmul( ctx->prealloc_split_k_need_sync = true; } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type]; if (ctx->device->coopmat2) { // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul_id( @@ -7636,10 +7762,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -8471,10 +8599,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); From ef93e98d01d7c04dae1ebb145793a5e397a57133 Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Tue, 12 May 2026 03:15:34 -0700 Subject: [PATCH 16/17] vulkan: Fix Windows performance regression on Intel GPU BF16 workloads for Xe2 and newer (#22461) * refactor * Use l_warptile only when coopamt is available for BF16 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 90ea7cc1a..a0a556206 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4260,11 +4260,6 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; - if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { - // Xe2/Xe3 - bf16 warptile performance tuning - l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; - } - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -5689,19 +5684,19 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; - case VK_VENDOR_ID_INTEL: - if (!device->coopmat_support || device->architecture != INTEL_XE2) { - device->mul_mat_l[i] = false; - device->mul_mat_id_l[i] = false; - } else { - device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel - device->mul_mat_id_l[i] = true; - } + case VK_VENDOR_ID_INTEL: { + // Current Windows driver does not expose BF16 support. + // We only want to use l_warptile if coopmat is available and is Xe2+ + const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2; + const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat; + device->mul_mat_l[i] = use_l_warptile; + device->mul_mat_id_l[i] = use_l_warptile; device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; + } case VK_VENDOR_ID_APPLE: device->mul_mat_l[i] = false; device->mul_mat_m[i] = true; From fde69a3607ff48b20bbb73e9a0059d5a142a5cf1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 12 May 2026 15:07:00 +0300 Subject: [PATCH 17/17] examples : add llama-eval (#21152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * working llama-eval mc and math suite * multi source llama-eval * Add readme * add checkpointing * examples: add llama-server simulator for testing eval scripts Add a standalone Python script that simulates a llama-server HTTP endpoint for testing the eval script. The simulator: - Implements /v1/chat/completions endpoint with OpenAI-compatible format - Loads AIME dataset from HuggingFace with local caching - Uses Levenshtein distance for intelligent question matching - Supports configurable success rate for correct/wrong answer generation - Provides debug logging for troubleshooting Also includes test scripts and documentation for testing and understanding the simulator functionality. * examples: refactor test-simulator.sh for better readability Extract repeating question string into TEST_QUESTION variable and create make_request() helper function to reduce code duplication. Add proper error handling for error responses. * docs: update llama-eval-discussion.md with session work summary Add summary of llama-server-simulator implementation work including features, testing results, technical decisions, and refactoring. * examples: add simplified llama-eval-new.py for AIME evaluation - Create new simplified evaluation script focused only on AIME - Implement EvalState and Processor dataclasses for structured state management - Add real-time feedback showing correct/incorrect status per case - Abstract grading interface for external grader support - Use structured JSON output for eval state - Apply HuggingFace dataset caching to avoid repeated downloads - Remove Levenshtein matching - eval script only sends requests and validates answers * docs: remove README.md from llama-eval * examples: implement flexible grader system for answer validation - Add Grader class supporting regex and CLI-based grading - Implement built-in regex patterns for AIME, GSM8K, MMLU, HellaSwag, ARC, WinoGrande - Add CLI grader interface: python script.py --answer --expected - Add HF telemetry disable to avoid warnings - Support exact match requirement for regex patterns - Add 30-second timeout for CLI grader - Handle both boxed and plain text formats for AIME answers * examples: use HF_HUB_OFFLINE to avoid HF Hub warnings * examples: remove HF_HUB_OFFLINE to allow dataset download * examples: use cached dataset path to avoid HF Hub requests * examples: use cached dataset path in simulator to avoid HF Hub requests * docs: update llama-eval-discussion.md with session work summary * examples: add threading support and model parameter to llama-eval-new.py - Add ThreadPoolExecutor for parallel request processing controlled by --threads - Add --model argument to specify model name in request data - Refactor process() to use thread-safe _process_single_case() method - Update progress tracking to work with concurrent execution * docs: update llama-eval-discussion.md with threading and model parameter updates - Add threading support implementation details - Document ThreadPoolExecutor usage and thread safety - Add model parameter implementation details - Include testing results for both features * examples: add task summary table to llama-eval-new.py * eval : print progress * eval : add prompts * test : fix path * sim : fix answer matching * eval : support multiple dataset runs * minor * improve grader * docs * remove old files * datasets : add gsm8k * add gpqa + sampling + docs * rename * grader : improve example answers * cont * datasets : add aime2025 * grader : update prompt * grade : improve regex + logs * datasets : fix aime2025 * cleanup * add AGENTS.md * ignore errors * resume eval * cleanup * fix counts * simplify * fix prompts * add html * store full response * add tokens * resoning and error handling * refactor * track total time * remove junk * eval : unify "judge" terminology to "grader" Replace all occurrences of "judge" with "grader" for consistency across the codebase (CLI args, Grader class fields, help text). Assisted-by: llama.cpp:local pi * eval : add Wilson score confidence interval to results Compute 95% CI on-the-fly from completed cases. Displayed in terminal output, HTML report, and JSON state. * llama-eval : add per-task generation speed from server timings Extract predicted_per_second from the server timings response and store it as tps_gen per task. Display in console progress, print_all_tasks, and HTML report. Assisted-by: llama.cpp:local pi * llama-eval : add per-task generation time from server timings Extract predicted_ms from the server timings response and store it as t_gen_ms per task. Display in seconds with one decimal digit in console progress, print_all_tasks, and HTML report. Assisted-by: llama.cpp:local pi * llama-eval : rename display, escaped, and count variables to use prefix convention - _display suffix → display_ prefix (answer, tokens, tps, t_gen) - _escaped suffix → escaped_ prefix (response, prompt, reasoning) - _count suffix → n_ prefix (correct, incorrect, pending) Assisted-by: llama.cpp:local pi * llama-eval : support multiple evaluation endpoints with dynamic task distribution - Add ServerConfig dataclass (url, threads, name) - Accept comma-separated --server, --threads, --server-name CLI args - Dynamic shared-queue task distribution across servers (fast servers do more work) - One ThreadPoolExecutor per server, workers pull from shared Queue - Track which server processed each task (server_name in results) - Thread-safe EvalState with threading.Lock for concurrent mutations - Server column in HTML report and console output - Backward compatible: single server works as before Assisted-by: llama.cpp:local pi * llama-server-simulator : replace Flask with stdlib http.server - Use HTTPServer + BaseHTTPRequestHandler instead of Flask - RequestHandler handles POST /v1/chat/completions - Server runs in daemon thread with clean Ctrl+C shutdown - Remove flask and unused asdict imports Assisted-by: llama.cpp:local pi * llama-eval : update README with PR link and quick-start examples Assisted-by: llama.cpp:local pi * llama-eval : track model name in eval state and verify on resume - Store model_name in EvalState and JSON output - Display model in HTML summary table - Verify --model matches stored model when resuming Assisted-by: llama.cpp:local pi * llama-server-simulator : fix comment - Dice coefficient, not Levenshtein Assisted-by: llama.cpp:local pi * llama-eval : require --grader-model or --model when using --grader-type llm Assisted-by: llama.cpp:local pi * llama-eval : protect dump() with lock for thread safety Assisted-by: llama.cpp:local pi * llama-eval : compact HTML report output - Replace verbose summary table with single inline bar - Shorten status text: '✓'/'✗'/'–'/'!' instead of full words - Flatten CSS: remove box-shadows, border-radius, reduce padding - Use system-ui font, 13px table, 12px details - Conditional reasoning section (only shown when present) - Single toggle JS function instead of two - Shorter column headers Assisted-by: llama.cpp:local pi * llama-eval : check server connectivity on startup - Hit /v1/models for each server before evaluation - Exit with error if any server is unreachable - Print comma-separated model IDs per server in startup output - Sequential checks, no retries, no timeout override Assisted-by: llama.cpp:local pi * llama-eval : use server1/server2 instead of gpu1/gpu2 in README Assisted-by: llama.cpp:local pi --------- Co-authored-by: gatbontonpc --- examples/llama-eval/README.md | 26 + examples/llama-eval/llama-eval.py | 1416 +++++++++++++++++ examples/llama-eval/llama-server-simulator.py | 317 ++++ examples/llama-eval/test-simulator.sh | 86 + 4 files changed, 1845 insertions(+) create mode 100644 examples/llama-eval/README.md create mode 100755 examples/llama-eval/llama-eval.py create mode 100755 examples/llama-eval/llama-server-simulator.py create mode 100755 examples/llama-eval/test-simulator.sh diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md new file mode 100644 index 000000000..3c5c35f78 --- /dev/null +++ b/examples/llama-eval/README.md @@ -0,0 +1,26 @@ +# llama-eval + +Simple evaluation tool for llama.cpp with support for multiple datasets. + +For a full description, usage examples, and sample results, see: + +- [PR 21152](https://github.com/ggml-org/llama.cpp/pull/21152) + +## Quick start + +```bash +# Single server +python3 llama-eval.py \ + --server http://localhost:8033 \ + --model my-model \ + --dataset gsm8k --n_cases 100 \ + --grader-type regex --threads 32 + +# Multiple servers (comma-separated URLs and thread counts) +python3 llama-eval.py \ + --server http://server1:8033,http://server2:8033 \ + --server-name server1,server2 \ + --threads 16,16 \ + --dataset aime2025 --n_cases 240 \ + --grader-type regex +``` diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py new file mode 100755 index 000000000..b33a3615b --- /dev/null +++ b/examples/llama-eval/llama-eval.py @@ -0,0 +1,1416 @@ +#!/usr/bin/env python3 +# type: ignore + +import argparse +import json +import os +import re +import subprocess +import sys +import threading +import time +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, asdict, field +from pathlib import Path +from queue import Queue +from typing import Dict, List, Optional, Any, Tuple +import requests +from tqdm import tqdm +import random +from math import sqrt + + +@dataclass +class ServerConfig: + url: str + threads: int + name: str = "" + +def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]: + """Wilson score confidence interval for a proportion.""" + if total == 0: + return (0.0, 1.0) + p = correct / total + z2 = z * z / total + center = (p + z2 / 2) / (1 + z2) + margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2) + return (center - margin, center + margin) + +cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" +cache_dir.mkdir(parents=True, exist_ok=True) +os.environ["HF_DATASETS_CACHE"] = str(cache_dir) +os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" + +GRADER_PATTERNS = { + "aime": r'\boxed{(\d+)}|\b(\d+)\b', + "aime2025": r'\boxed{(\d+)}|\b(\d+)\b', + "gsm8k": r'\b(\d+)\b', +} + +SAMPLE_ANSWERS = { + "aime": [ + "42", + "-123", + "999" + ], + "aime2025": [ + "42", + "-123", + "999" + ], + "gsm8k": [ + "42", + "-123", + "999" + ], + "gpqa": [ + "A", + "D", + "C" + ], +} + +TEMPLATE_REGISTRY = { + "aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}. + +{question} + +Remember to put your answer inside \\boxed{{}}. +""", + "aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}. + +{question} + +Remember to put your answer inside \\boxed{{}}. +""", + "gsm8k": """{question} +Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters. +""", + "gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A'). + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""", +} + + +class BaseDataset(ABC): + @abstractmethod + def get_question(self, index: int) -> Dict: + pass + + @abstractmethod + def get_question_text(self, question: Dict) -> str: + pass + + @abstractmethod + def get_answer(self, question: Dict) -> str: + pass + + @abstractmethod + def get_prompt(self, question: Dict) -> str: + pass + + def __len__(self) -> int: + return len(self.questions) + + +@dataclass +class TaskState: + task_id: str + prompt: str + expected: str + question_text: str = "" + response: Optional[str] = None + answer: Optional[str] = None + grader_log: Dict[str, Any] = field(default_factory=dict) + correct: bool = False + status: str = "pending" + tokens: Optional[int] = None + tps_gen: Optional[float] = None + t_gen_ms: Optional[float] = None + reasoning_content: Optional[str] = None + server_name: Optional[str] = None + + +class EvalState: + def __init__( + self, + dataset_type: str, + sampling_config: Dict[str, Any], + output_file: Path = Path("llama-eval-state.json"), + model_name: Optional[str] = None + ): + self.dataset_type = dataset_type + self.sampling_config = sampling_config + self.output_file = output_file + self.model_name = model_name + self.dataset: Optional[BaseDataset] = None + self.tasks: List[Tuple[int, str]] = [] + self.all_tasks: List[Tuple[int, str]] = [] + self.task_states: Dict[str, Any] = {} + self.total = 0 + self.correct = 0 + self.processed = 0 + self.total_time: float = 0.0 + self._lock = threading.Lock() + + def load_dataset(self, seed: int = 1234): + if self.dataset_type == "aime": + self.dataset = AimeDataset() + elif self.dataset_type == "aime2025": + self.dataset = Aime2025Dataset() + elif self.dataset_type == "gsm8k": + self.dataset = Gsm8kDataset() + elif self.dataset_type == "gpqa": + self.dataset = GpqaDataset(variant="diamond", seed=seed) + else: + raise ValueError(f"Unknown dataset type: {self.dataset_type}") + + def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234): + if self.dataset is None: + raise ValueError("Dataset not loaded. Call load_dataset() first.") + + if n_cases is None: + n_cases = len(self.dataset) + + dataset_size = len(self.dataset) + rng = random.Random(seed) + + self.tasks = [] + for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size): + chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size) + indices = list(range(dataset_size)) + rng.shuffle(indices) + chunk_indices = indices[:chunk_size] + + for i in chunk_indices: + task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}" + self.tasks.append((i, task_id)) + + self.all_tasks = list(self.tasks) + + def get_case(self, index: int) -> Tuple[str, str, str]: + if self.dataset is None: + raise ValueError("Dataset not loaded.") + question = self.dataset.get_question(index) + question_text = self.dataset.get_question_text(question) + prompt = self.dataset.get_prompt(question) + expected = self.dataset.get_answer(question) + return question_text, prompt, expected + + def add_result( + self, + task_id: str, + prompt: str, + expected: str, + response: Optional[str], + answer: Optional[str], + grader_log: Dict[str, Any], + correct: bool, + status: str, + tokens: Optional[int] = None, + tps_gen: Optional[float] = None, + t_gen_ms: Optional[float] = None, + reasoning_content: Optional[str] = None, + server_name: Optional[str] = None + ): + with self._lock: + if "cases" not in self.task_states: + self.task_states["cases"] = {} + + self.task_states["cases"][task_id] = { + "task_id": task_id, + "prompt": prompt, + "expected": expected, + "response": response, + "answer": answer, + "grader_log": grader_log, + "correct": correct, + "status": status, + "tokens": tokens, + "tps_gen": tps_gen, + "t_gen_ms": t_gen_ms, + "reasoning_content": reasoning_content, + "server_name": server_name + } + + self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False)) + + def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0): + display_answer = task_state.answer if task_state.answer else "N/A" + display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A" + display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A" + display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A" + display_server = task_state.server_name if task_state.server_name else "N/A" + success_ratio = n_correct / self.processed if self.processed > 0 else 0.0 + first_line = task_state.question_text.split('\n')[0] + truncated_question = first_line[:43] + if len(first_line) > 43: + truncated_question += "..." + else: + truncated_question = truncated_question.ljust(43) + "..." + print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}") + + def print_summary(self): + if self.total == 0: + print(f"\n{'='*60}") + print(f"Results: 0/0 correct (0.0%)") + print(f"{'='*60}") + else: + ci_lower, ci_upper = self.accuracy_ci() + print(f"\n{'='*60}") + print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]") + print(f"{'='*60}") + + def dump(self): + with self._lock: + tasks_to_save = self.all_tasks if self.all_tasks else self.tasks + all_cases = {} + for i, task_id in tasks_to_save: + question_text, prompt, expected = self.get_case(i) + if task_id in self.task_states.get("cases", {}): + all_cases[task_id] = self.task_states["cases"][task_id] + else: + all_cases[task_id] = { + "task_id": task_id, + "prompt": prompt, + "expected": expected, + "question_text": question_text, + "response": None, + "answer": None, + "grader_log": {}, + "correct": False, + "status": "pending", + "tokens": None, + "tps_gen": None, + "t_gen_ms": None, + "reasoning_content": None, + "server_name": None + } + + ci_lower, ci_upper = self.accuracy_ci() + data = { + "id": self.dataset_type, + "model_name": self.model_name, + "tasks": [tid for _, tid in tasks_to_save], + "task_states": { + "total": self.total, + "correct": self.correct, + "total_time": self.total_time, + "ci_lower": ci_lower, + "ci_upper": ci_upper, + "cases": all_cases, + }, + "sampling_config": self.sampling_config + } + with open(self.output_file, "w") as f: + json.dump(data, f, indent=2) + + self.dump_html(tasks_to_save, all_cases) + + def dump_html(self, tasks_to_save: List[Tuple[int, str]], all_cases: Dict[str, Any]): + html_file = Path(str(self.output_file) + ".html") + + cases = all_cases + completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + n_correct = sum(1 for c in completed.values() if c.get("correct", False)) + n_incorrect = len(completed) - n_correct + n_pending = len(tasks_to_save) - len(completed) + accuracy = n_correct / len(completed) * 100 if completed else 0.0 + ci_lower, ci_upper = wilson_interval(n_correct, len(completed)) if completed else (0.0, 1.0) + + sampling_parts = [] + for k, v in self.sampling_config.items(): + if v is not None: + sampling_parts.append(f"{k}={v}") + sampling_str = ", ".join(sampling_parts) if sampling_parts else "default" + + rows = [] + for i, task_id in tasks_to_save: + case = cases.get(task_id, {}) + status = case.get("status", "pending") + expected = case.get("expected", "") + answer = case.get("answer", "") if status == "ok" else "" + is_correct = case.get("correct", False) if status == "ok" else False + response = case.get("response", "") or "" + prompt = case.get("prompt", "") or "" + grader_log = case.get("grader_log", {}) + + if status == "ok": + status_class = "correct" if is_correct else "incorrect" + status_text = "✓" if is_correct else "✗" + elif status == "pending": + status_class = "pending" + status_text = "–" + else: + status_class = "error" + status_text = "!" + + tokens = case.get("tokens") + tokens_str = str(tokens) if tokens is not None else "" + tps_gen = case.get("tps_gen") + tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "" + t_gen_ms = case.get("t_gen_ms") + t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "" + reasoning_content = case.get("reasoning_content", "") or "" + server_name = case.get("server_name", "") or "" + + escaped_response = self._escape_html(response) + escaped_prompt = self._escape_html(prompt) + escaped_reasoning = self._escape_html(reasoning_content) + grader_log_str = self._escape_html(json.dumps(grader_log, indent=2)) + escaped_server = self._escape_html(server_name) + + rows.append(f""" + {task_id} + {status_text} + {self._escape_html(expected)} + {self._escape_html(answer)} + {tokens_str} + {tps_str} + {t_gen_str} + {escaped_server} + + + +
+ Prompt
{escaped_prompt}
+ Response
{escaped_response}
+ {f'Reasoning
{escaped_reasoning}
' if escaped_reasoning else ''} + Grader
{grader_log_str}
+
+ + """) + + rows_html = "\n".join(rows) + + html_content = f""" + + + +{self.dataset_type.upper()} Eval + + + +
+ {self.dataset_type.upper()} + Model: {self.model_name or 'N/A'} + Accuracy: {accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%] + Correct: {n_correct} / {len(completed)} + Pending: {n_pending} + Time: {self.total_time:.1f}s + Sampling: {sampling_str} +
+ + + + + + + + + + + + + + + {rows_html} + +
IDGoldAnswerTokensT/sGen sServer
+ + +""" + + with open(html_file, "w") as f: + f.write(html_content) + + def _escape_html(self, s: str) -> str: + return (s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'")) + + @classmethod + def load(cls, path: Path) -> "EvalState": + with open(path, "r") as f: + data = json.load(f) + + eval_state = cls( + dataset_type=data["id"], + sampling_config=data["sampling_config"], + output_file=path, + model_name=data.get("model_name") + ) + eval_state.load_dataset() + + eval_state.tasks = [] + eval_state.all_tasks = [] + for task_id in data.get("tasks", []): + parts = task_id.rsplit("_", 2) + if len(parts) >= 3: + idx = int(parts[-1]) + else: + idx = 0 + eval_state.tasks.append((idx, task_id)) + eval_state.all_tasks.append((idx, task_id)) + + eval_state.task_states = data.get("task_states", {}) + + cases = eval_state.task_states.get("cases", {}) + eval_state.total = eval_state.task_states.get("total", 0) + eval_state.correct = eval_state.task_states.get("correct", 0) + eval_state.total_time = eval_state.task_states.get("total_time", 0.0) + + if eval_state.total == 0: + eval_state.total = len(cases) + eval_state.correct = sum(1 for c in cases.values() if c.get("correct", False)) + + return eval_state + + def is_complete(self) -> bool: + if not self.all_tasks: + return False + cases = self.task_states.get("cases", {}) + completed = {tid for tid in self.task_states.get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"} + return len(completed) == len(self.all_tasks) + + def get_pending_tasks(self) -> List[Tuple[int, str]]: + cases = self.task_states.get("cases", {}) + pending = [] + for i, task_id in self.all_tasks: + status = cases.get(task_id, {}).get("status", "pending") + if status != "ok": + pending.append((i, task_id)) + return pending + + def print_all_tasks(self): + cases = self.task_states.get("cases", {}) + tasks_to_show = self.all_tasks if self.all_tasks else self.tasks + print() + print("Tasks:") + print(" Task ID Dataset Prompt (first 40 chars) Expected Answer Tokens T/s Gen s Status") + for i, task_id in tasks_to_show: + question, prompt, expected = self.get_case(i) + case = cases.get(task_id, {}) + status = case.get("status", "pending") + answer = case.get("answer", "N/A") if status == "ok" else "N/A" + tokens = case.get("tokens") + tokens_str = str(tokens) if tokens is not None else "N/A" + tps_gen = case.get("tps_gen") + tps_str = f"{tps_gen:.1f}" if tps_gen is not None else "N/A" + t_gen_ms = case.get("t_gen_ms") + t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else "N/A" + server_name = case.get("server_name", "") or "" + is_correct = case.get("correct", False) if status == "ok" else False + symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "") + first_line = question.split('\n')[0] + question_trunc = first_line[:43] + if len(first_line) > 43: + question_trunc += "..." + else: + question_trunc = question_trunc.ljust(43) + "..." + print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {expected:<10} {answer:<10} {tokens_str:<6} {tps_str:<6} {t_gen_str:<8} {symbol}{status} {server_name}") + print() + + def print_existing_summary(self): + cases = self.task_states.get("cases", {}) + completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + correct = sum(1 for c in completed_cases.values() if c.get("correct", False)) + total = len(completed_cases) + if total == 0: + print(f"{'='*60}") + print(f"Results: 0/0 correct (0.0%)") + print(f"{'='*60}") + else: + ci_lower, ci_upper = self.accuracy_ci() + print(f"{'='*60}") + print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]") + print(f"{'='*60}") + + def accuracy_ci(self) -> Tuple[float, float]: + """Compute Wilson score confidence interval from completed cases.""" + cases = self.task_states.get("cases", {}) + completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + correct = sum(1 for c in completed.values() if c.get("correct", False)) + total = len(completed) + return wilson_interval(correct, total) + +def normalize_number(s: str) -> Optional[int]: + match = re.match(r"\d+", s) # match digits from the start + if not match: + return None + return int(match.group(0)) + +class AimeDataset(BaseDataset): + def __init__(self, split: str = "train"): + self.split = split + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading AIME dataset (split: {self.split})...") + from datasets import load_dataset + + cache_path = cache_dir / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + else: + ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split) + + self.questions = [] + for row in ds: + question = dict(row) + question["dataset_type"] = "aime" + self.questions.append(question) + + print(f"AIME dataset loaded: {len(self.questions)} questions") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["problem"] if "problem" in question else question["question"] + + def get_answer(self, question: Dict) -> str: + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY[question["dataset_type"]].format( + question=self.get_question_text(question), + ) + +class Aime2025Dataset(BaseDataset): + def __init__(self): + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading AIME2025 dataset...") + from datasets import load_dataset + + config_name = "AIME2025-I" + cache_path = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path)) + else: + ds = load_dataset("opencompass/AIME2025", config_name, split="test") + + self.questions = [] + for row in ds: + question = dict(row) + question["dataset_type"] = "aime2025" + self.questions.append(question) + + print(f"AIME2025 dataset loaded: {len(self.questions)} questions") + + print(f"Loading AIME2025 dataset (part 2)...") + config_name_2 = "AIME2025-II" + cache_path_2 = cache_dir / "opencompass___AIME2025" / "default" / "0.0.0" + if cache_path_2.exists(): + print(f"Using cached dataset from {cache_path_2}") + ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test", cache_dir=str(cache_path_2)) + else: + ds_2 = load_dataset("opencompass/AIME2025", config_name_2, split="test") + + for row in ds_2: + question = dict(row) + question["dataset_type"] = "aime2025" + self.questions.append(question) + + print(f"AIME2025 dataset loaded: {len(self.questions)} questions (total)") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["question"] + + def get_answer(self, question: Dict) -> str: + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY["aime2025"].format( + question=self.get_question_text(question), + ) + +class Gsm8kDataset(BaseDataset): + def __init__(self, split: str = "test"): + self.split = split + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading GSM8K dataset (split: {self.split})...") + from datasets import load_dataset + + cache_path = cache_dir / "openai___gsm8k" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = load_dataset("openai/gsm8k", "main", split=self.split, cache_dir=str(cache_path)) + else: + ds = load_dataset("openai/gsm8k", "main", split=self.split) + + self.questions = [] + for row in ds: + question = dict(row) + question["dataset_type"] = "gsm8k" + + # Extract numeric answer from the answer field (already has #### prefix) + gold = question["answer"] + # Split by #### and take the last part + parts = gold.split("####") + if len(parts) > 1: + gold = parts[-1].strip() + # Extract the first number from the remaining text + normalized = normalize_number(gold) + question["gold"] = str(normalized) if normalized is not None else gold + + self.questions.append(question) + + print(f"GSM8K dataset loaded: {len(self.questions)} questions") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["problem"] if "problem" in question else question["question"] + + def get_answer(self, question: Dict) -> str: + # GSM8K has pre-extracted gold field, AIME uses answer field + if "gold" in question: + return question["gold"] + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY[question["dataset_type"]].format( + question=self.get_question_text(question), + ) + +class GpqaDataset(BaseDataset): + def __init__(self, variant: str = "diamond", seed: int = 1234): + self.variant = variant + self.seed = seed + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading GPQA dataset (variant: {self.variant})...") + import pandas as pd + + url = f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{self.variant}.csv" + df = pd.read_csv(url) + + rng = random.Random(self.seed) + + self.questions = [] + for _, row in df.iterrows(): + question = row.to_dict() + question["dataset_type"] = "gpqa" + + # Shuffle the answer options + correct_answer = question["Correct Answer"] + incorrect_answers = [ + question["Incorrect Answer 1"], + question["Incorrect Answer 2"], + question["Incorrect Answer 3"] + ] + + # Create list of (answer, is_correct) tuples + options = [(ans, ans == correct_answer) for ans in incorrect_answers] + options.append((correct_answer, True)) + + # Shuffle the options + rng.shuffle(options) + + # Extract shuffled answers and determine correct letter + shuffled_answers = [ans for ans, _ in options] + correct_letter = chr(ord('A') + options.index((correct_answer, True))) + + # Store shuffled answers and correct letter + question["shuffled_answers"] = shuffled_answers + question["correct_letter"] = correct_letter + + self.questions.append(question) + + print(f"GPQA dataset loaded: {len(self.questions)} questions") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_question_text(self, question: Dict) -> str: + """Get question string""" + return question["Question"] + + def get_answer(self, question: Dict) -> str: + # GPQA returns the correct letter (A, B, C, or D) + return question["correct_letter"] + + def get_prompt(self, question: Dict) -> str: + """Get formatted prompt for the question""" + return TEMPLATE_REGISTRY["gpqa"].format( + Question=self.get_question_text(question), + A=question["shuffled_answers"][0], + B=question["shuffled_answers"][1], + C=question["shuffled_answers"][2], + D=question["shuffled_answers"][3] + ) + +class Grader: + def __init__( + self, + grader_type: str = "llm", + grader_script: Optional[str] = None, + grader_model_name: Optional[str] = None, + grader_server_url: str = "", + dataset_type: str = "aime" + ): + self.grader_type = grader_type + self.grader_script = grader_script + self.grader_model_name = grader_model_name + self.grader_server_url = grader_server_url + self.dataset_type = dataset_type + self.pattern = self._get_pattern() + + def _get_pattern(self) -> Optional[str]: + if self.grader_type == "regex": + return GRADER_PATTERNS.get(self.dataset_type) # Use dataset_type as key + return None + + def _extract_answer_regex(self, pred: str) -> Optional[str]: + """Extract answer using regex pattern""" + if not self.pattern: + return None + + # For AIME datasets, prioritize boxed answers + if self.dataset_type in ["aime", "aime2025"]: + boxed_pattern = r'\\boxed{([^}]+)}' + boxed_matches = re.findall(boxed_pattern, pred, re.IGNORECASE) + if boxed_matches: + # Return the last boxed answer found (most likely the final answer) + return boxed_matches[-1].strip() + + # For other datasets, search for numbers from the end of the text + # This prioritizes numbers that appear later in the response + matches = re.findall(self.pattern, pred, re.IGNORECASE) + if not matches: + return None + + # Process matches from end to start + for match in reversed(matches): + if isinstance(match, tuple): + match = match[0] if match[0] else match[1] + answer = match.strip() + if answer: + return answer + return None + + def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]: + """Grade using regex pattern matching""" + answer = self._extract_answer_regex(pred) + if answer is None: + return False, None + is_correct = answer.strip() == gold.strip() + return is_correct, answer + + def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]: + """Grade using external CLI script""" + if not self.grader_script: + raise ValueError("CLI grader requires --grader-script") + + script_path = Path(self.grader_script) + if not script_path.exists(): + raise FileNotFoundError(f"Grader script not found: {self.grader_script}") + + try: + result = subprocess.run( + [str(script_path), "--answer", pred, "--expected", gold], + capture_output=True, + text=True, + timeout=30 + ) + is_correct = result.returncode == 0 + answer = pred if is_correct else None + return is_correct, answer + except subprocess.TimeoutExpired: + return False, None + except Exception as e: + return False, None + + def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]: + """Grade using LLM-based extraction with few-shot examples""" + sample_answers = SAMPLE_ANSWERS.get(self.dataset_type, []) + sample_examples = "\n".join([ + f"Example {i+1}: {ans}" for i, ans in enumerate(sample_answers) + ]) + + system_prompt = f"""You are an answer extraction system. Your task is to extract the answer from the model's response. + +Here are some examples of extracted answers to demonstrate what you are supposed to output: + +{sample_examples} + +When extracting the answer, provide only the extracted answer itself, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'.""" + + user_prompt = f"""Extract the answer from the following response: + +"{pred}" + +Please provide only the extracted answer, nothing else. If there is no clear answer that can be extracted from the response, reply with 'no answer'.""" + + url = f"{self.grader_server_url}/v1/chat/completions" + headers = {"Content-Type": "application/json"} + data = { + "model": self.grader_model_name, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + "temperature": 0, + } + #print(json.dumps(data, indent=2)) + + try: + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + answer = response.json()["choices"][0]["message"]["content"].strip() + is_correct = answer.strip().lower() == gold.strip().lower() + return is_correct, answer + except Exception as e: + return False, None + + def _truncate_response(self, response: str, max_lines: int = 6) -> str: + """Keep only last N lines of response""" + lines = response.split('\n') + return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response + + def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]: + """Grade the response""" + if self.grader_type == "regex": + return self._grade_regex(gold, pred) + elif self.grader_type == "cli": + return self._grade_cli(gold, pred) + elif self.grader_type == "llm": + return self._grade_llm(gold, pred, problem) + else: + raise ValueError(f"Unknown grader type: {self.grader_type}") + +class Processor: + def __init__( + self, + server_configs: List[ServerConfig], + grader: Grader, + model_name: Optional[str] = None, + n_predict: int = -1 + ): + self.server_configs = server_configs + self.grader = grader + self.model_name = model_name + self.n_predict = n_predict + + @staticmethod + def _check_server(server_config: ServerConfig) -> List[str]: + url = f"{server_config.url}/v1/models" + try: + response = requests.get(url) + response.raise_for_status() + models = [m["id"] for m in response.json().get("data", [])] + return models + except Exception as e: + print(f"Error: Cannot reach server {server_config.name} ({server_config.url}): {e}", file=sys.stderr) + sys.exit(1) + + def _make_request( + self, server_config: ServerConfig, eval_state: EvalState, prompt: str + ) -> Tuple[Dict[str, Any], int, Optional[float], Optional[float], str]: + url = f"{server_config.url}/v1/chat/completions" + headers = {"Content-Type": "application/json"} + data = { + "model": self.model_name if self.model_name else "llama", + "messages": [{"role": "user", "content": prompt}], + "n_predict": self.n_predict + } + if eval_state.sampling_config.get("temperature") is not None: + data["temperature"] = eval_state.sampling_config["temperature"] + if eval_state.sampling_config.get("top_k") is not None: + data["top_k"] = eval_state.sampling_config["top_k"] + if eval_state.sampling_config.get("top_p") is not None: + data["top_p"] = eval_state.sampling_config["top_p"] + if eval_state.sampling_config.get("min_p") is not None: + data["min_p"] = eval_state.sampling_config["min_p"] + + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + result = response.json() + tokens = result.get("usage", {}).get("completion_tokens", 0) + timings = result.get("timings", {}) + tps_gen = timings.get("predicted_per_second") if timings else None + t_gen_ms = timings.get("predicted_ms") if timings else None + finish_reason = result.get("choices", [{}])[0].get("finish_reason", "stop") + return result, tokens, tps_gen, t_gen_ms, finish_reason + + def _process_single_case( + self, server_config: ServerConfig, eval_state: EvalState, i: int, task_id: str + ) -> TaskState: + question_text, prompt, expected = eval_state.get_case(i) + + task_state = TaskState( + task_id=task_id, + prompt=prompt, + expected=expected, + question_text=question_text, + server_name=server_config.name + ) + + try: + response, tokens, tps_gen, t_gen_ms, finish_reason = self._make_request(server_config, eval_state, prompt) + result = response["choices"][0]["message"]["content"] + reasoning_content = response["choices"][0].get("message", {}).get("reasoning_content") + task_state.response = result + task_state.tokens = tokens + task_state.tps_gen = tps_gen + task_state.t_gen_ms = t_gen_ms + task_state.reasoning_content = reasoning_content + + if finish_reason != "stop": + task_state.status = f"error: finish_reason={finish_reason}" + eval_state.add_result( + task_id, prompt, expected, result, None, + {"finish_reason": finish_reason}, False, task_state.status, + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + ) + eval_state.dump() + return task_state + + result_truncated = self.grader._truncate_response(result, max_lines=10) + is_correct, answer = self.grader.grade(expected, result_truncated, prompt) + + grader_log = { + "pred": result_truncated, + "grader_type": self.grader.grader_type + } + if self.grader.grader_type == "regex" and self.grader.pattern: + grader_log["pattern"] = self.grader.pattern + + task_state.correct = is_correct + task_state.answer = answer + task_state.grader_log = grader_log + task_state.status = "ok" + + eval_state.add_result( + task_id, prompt, expected, result, answer, + grader_log, is_correct, "ok", + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + ) + + eval_state.dump() + + except Exception as e: + task_state.status = f"error: {str(e)}" + + return task_state + + @staticmethod + def _worker( + server_config: ServerConfig, + processor: "Processor", + eval_state: EvalState, + task_queue: Queue, + results_queue: Queue, + ): + """Worker that pulls tasks from a shared queue and sends them to its server.""" + while True: + task = task_queue.get() + if task is None: # sentinel + task_queue.task_done() + break + try: + i, task_id = task + result = processor._process_single_case(server_config, eval_state, i, task_id) + results_queue.put(result) + finally: + task_queue.task_done() + + def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False): + total_tasks = len(eval_state.tasks) + eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks + eval_state.processed = 0 + start_time = time.time() + + # Check servers and list models + server_models = [self._check_server(sc) for sc in self.server_configs] + + # Print server info + print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...") + print(f"Servers ({len(self.server_configs)}):") + for i, sc in enumerate(self.server_configs): + models_str = ", ".join(server_models[i]) if server_models[i] else "(none)" + print(f" {i+1}. {sc.name} — {sc.url} ({sc.threads} threads) [{models_str}]") + print(f"Model: {self.model_name}") + print(f"Grader: {self.grader.grader_type}") + print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}") + print() + + # Shared task queue: all workers compete for tasks + task_queue: Queue = Queue() + for i, task_id in eval_state.tasks: + task_queue.put((i, task_id)) + + # Results queue: workers push completed TaskStates here + results_queue: Queue = Queue() + + # Total worker threads across all servers + total_threads = sum(sc.threads for sc in self.server_configs) + + # Add one sentinel per worker so every worker exits cleanly + for _ in range(total_threads): + task_queue.put(None) + + # Launch workers: one ThreadPoolExecutor per server + executors: List[ThreadPoolExecutor] = [] + worker_futures: List[Any] = [] + for server_config in self.server_configs: + executor = ThreadPoolExecutor(max_workers=server_config.threads) + executors.append(executor) + for _ in range(server_config.threads): + future = executor.submit( + self._worker, server_config, self, eval_state, + task_queue, results_queue + ) + worker_futures.append(future) + + # Drain results as they complete + n_correct = 0 + session_time = 0.0 + completed_count = 0 + + while completed_count < total_tasks: + task_state = results_queue.get() + eval_state.processed += 1 + completed_count += 1 + if task_state.correct: + n_correct += 1 + elapsed = time.time() - start_time + eval_state.total_time += elapsed + session_time += elapsed + start_time = time.time() + eval_state.print_progress(task_state, total_tasks, n_correct) + + if verbose: + print(f"\nCase {eval_state.processed}: {task_state.correct}") + print(f" Expected: {task_state.expected}") + if task_state.response: + print(f" Response: {task_state.response}") + if task_state.answer: + print(f" Answer: {task_state.answer}") + print(f" Status: {task_state.status}") + + # Wait for all workers to finish and shut down executors + for future in worker_futures: + future.result() + for executor in executors: + executor.shutdown(wait=True) + + print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s") + eval_state.print_summary() + eval_state.dump() + +def main(): + parser = argparse.ArgumentParser( + description="Simplified evaluation tool for llama.cpp" + ) + parser.add_argument( + "--server", + type=str, + default="http://localhost:8033", + help="Comma-separated llama-server URLs (default: http://localhost:8033)" + ) + parser.add_argument( + "--server-name", + type=str, + default="", + help="Comma-separated display names for servers (default: use URLs)" + ) + parser.add_argument( + "--dataset", + type=str, + default="aime", + choices=["aime", "aime2025", "gsm8k", "gpqa"], + help="Dataset type (default: aime)" + ) + parser.add_argument( + "--n_cases", + type=int, + default=None, + help="Number of cases to evaluate (default: all)" + ) + parser.add_argument( + "--seed", + type=int, + default=1234, + help="Random seed for shuffling (default: 1234)" + ) + parser.add_argument( + "--n_predict", + type=int, + default=-1, + help="Max tokens to predict per prompt (default: -1, infinite)" + ) + parser.add_argument( + "--temperature", + type=float, + default=None, + help="Sampling temperature (default: not passed)" + ) + parser.add_argument( + "--top-k", + type=int, + default=None, + help="Top K sampling (default: not passed)" + ) + parser.add_argument( + "--top-p", + type=float, + default=None, + help="Top P sampling (default: not passed)" + ) + parser.add_argument( + "--min-p", + type=float, + default=None, + help="Min P sampling (default: not passed)" + ) + parser.add_argument( + "--threads", + type=str, + default="32", + help="Comma-separated thread counts per server (default: 32)" + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model name to append as query parameter (e.g., gpt-oss-20b-hf)" + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Show detailed output for each case" + ) + parser.add_argument( + "--output", + type=Path, + default=Path("llama-eval-state.json"), + help="Output file for eval state (default: llama-eval-state.json)" + ) + parser.add_argument( + "--grader-type", + type=str, + default="llm", + choices=["regex", "cli", "llm"], + help="Grader type: regex, cli, or llm (default: llm)" + ) + parser.add_argument( + "--grader-script", + type=str, + default=None, + help="CLI grader script path (required for --grader-type cli)" + ) + parser.add_argument( + "--grader-server", + type=str, + default="", + help="Server URL for LLM grader (default: same as main server)" + ) + parser.add_argument( + "--grader-model", + type=str, + default="", + help="Model name for LLM grader (default: same as main model)" + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from existing eval state" + ) + + args = parser.parse_args() + + # Parse server URLs and thread counts + server_urls = [u.strip() for u in args.server.split(",") if u.strip()] + thread_counts = [int(t.strip()) for t in args.threads.split(",") if t.strip()] + + if len(server_urls) != len(thread_counts): + print(f"Error: --server ({len(server_urls)} URLs) and --threads ({len(thread_counts)} values) must have the same count") + sys.exit(1) + + # Parse server names (optional, defaults to URLs) + if args.server_name: + server_names = [n.strip() for n in args.server_name.split(",") if n.strip()] + if len(server_names) != len(server_urls): + print(f"Error: --server-name ({len(server_names)} names) and --server ({len(server_urls)} URLs) must have the same count") + sys.exit(1) + else: + server_names = server_urls # fallback to URLs + + server_configs = [ + ServerConfig(url=url, threads=threads, name=name) + for url, threads, name in zip(server_urls, thread_counts, server_names) + ] + + if args.dataset == "gpqa" and args.grader_type != "llm": + print("Error: GPQA dataset requires --grader-type llm") + parser.print_help() + sys.exit(1) + + if args.output.exists(): + print(f"Loading existing eval state from {args.output}") + eval_state = EvalState.load(args.output) + + # Verify model matches + if eval_state.model_name is not None and args.model != eval_state.model_name: + print(f"Error: Model mismatch. State has '{eval_state.model_name}', but --model is '{args.model}'") + sys.exit(1) + + eval_state.print_all_tasks() + eval_state.print_existing_summary() + + if eval_state.is_complete(): + return + + print() + + if not args.resume: + print(f"Evaluation incomplete. Run with --resume to continue.") + return + + pending_tasks = eval_state.get_pending_tasks() + print(f"Resuming from {len(pending_tasks)} pending tasks") + + existing_cases = eval_state.task_states.get("cases", {}) + + eval_state.tasks = pending_tasks + eval_state.task_states["cases"] = existing_cases + + grader_server_url = args.grader_server if args.grader_server else server_configs[0].url + grader_model_name = args.grader_model if args.grader_model else args.model + if args.grader_type == "llm" and not grader_model_name: + print("Error: --grader-type llm requires --grader-model or --model") + sys.exit(1) + grader = Grader( + grader_type=args.grader_type, + grader_script=args.grader_script, + grader_model_name=grader_model_name, + grader_server_url=grader_server_url, + dataset_type=eval_state.dataset_type + ) + resume = True + else: + if args.resume: + print("Error: No existing eval state found to resume") + sys.exit(1) + + grader_server_url = args.grader_server if args.grader_server else server_configs[0].url + grader_model_name = args.grader_model if args.grader_model else args.model + if args.grader_type == "llm" and not grader_model_name: + print("Error: --grader-type llm requires --grader-model or --model") + sys.exit(1) + + grader = Grader( + grader_type=args.grader_type, + grader_script=args.grader_script, + grader_model_name=grader_model_name, + grader_server_url=grader_server_url, + dataset_type=args.dataset + ) + + if args.grader_type == "llm" and not args.grader_server: + print("Warning: Using same server for LLM grader (no --grader-server specified)") + + sampling_config = {} + if args.temperature is not None: + sampling_config["temperature"] = args.temperature + if args.top_k is not None: + sampling_config["top_k"] = args.top_k + if args.top_p is not None: + sampling_config["top_p"] = args.top_p + if args.min_p is not None: + sampling_config["min_p"] = args.min_p + + eval_state = EvalState( + dataset_type=args.dataset, + sampling_config=sampling_config, + output_file=args.output, + model_name=args.model + ) + eval_state.load_dataset(seed=args.seed) + eval_state.setup_tasks(n_cases=args.n_cases, seed=args.seed) + eval_state.dump() + resume = False + + eval_state.print_all_tasks() + + processor = Processor( + server_configs=server_configs, + grader=grader, + model_name=args.model, + n_predict=args.n_predict + ) + + processor.evaluate(eval_state, verbose=args.verbose, resume=resume) + print(f"\nEval state dumped to {args.output}") + +if __name__ == "__main__": + main() diff --git a/examples/llama-eval/llama-server-simulator.py b/examples/llama-eval/llama-server-simulator.py new file mode 100755 index 000000000..2f9cdc545 --- /dev/null +++ b/examples/llama-eval/llama-server-simulator.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 + +import argparse +import json +import random +import re +import time +import sys +import os +import threading +from http.server import HTTPServer, BaseHTTPRequestHandler +from typing import Dict, List, Optional +from dataclasses import dataclass +from pathlib import Path + +import datasets + +# Set cache directory for HuggingFace datasets +cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" +cache_dir.mkdir(parents=True, exist_ok=True) +os.environ["HF_DATASETS_CACHE"] = str(cache_dir) + +def dice(s1: str, s2: str) -> float: + """Calculate Dice coefficient between two strings based on bigram overlap.""" + if not s1 and not s2: + return 1.0 + + def _bigrams(s: str): + return [s[i : i + 2] for i in range(len(s) - 1)] + + bigrams1 = _bigrams(s1) + bigrams2 = _bigrams(s2) + + if not bigrams1 and not bigrams2: + return 1.0 + + from collections import Counter + + freq1 = Counter(bigrams1) + freq2 = Counter(bigrams2) + + intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1) + dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2)) + return dice_coeff + +def debug_log(message: str): + """Log debug messages to both stdout and a file""" + print(message, file=sys.stderr) + with open("/tmp/simulator-debug.log", "a") as f: + f.write(message + "\n") + +simulator: Optional["Simulator"] = None + +@dataclass +class EvalState: + id: str + tasks: List[str] + task_states: Dict[str, Dict] + sampling_config: Dict + +def normalize_number(s: str) -> Optional[int]: + match = re.match(r"\d+", s) # match digits from the start + if not match: + return None + return int(match.group(0)) + +class AimeDataset: + def __init__(self, split: str = "train"): + self.split = split + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading AIME dataset (split: {self.split})...") + + cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + else: + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) + + self.questions = list(ds) + print(f"AIME dataset loaded: {len(self.questions)} questions") + + def find_question(self, request_text: str) -> Optional[Dict]: + best_match = None + best_distance = -1 + best_index = -1 + + for i, question in enumerate(self.questions): + question_text = question["problem"] + request_lower = request_text.lower() + question_lower = question_text.lower() + + # Exact match + if question_lower == request_lower: + debug_log(f"DEBUG: Found exact match at index {i}") + return question + + # Remove LaTeX formatting for more flexible matching + question_no_latex = re.sub(r'\$[^$]+\$', '', question_text) + if question_no_latex.lower() == request_lower: + debug_log(f"DEBUG: Found match (no LaTeX) at index {i}") + return question + + # Calculate Dice coefficient for partial matches + # Only consider if request is at least 50% of question length + if len(request_lower) >= len(question_lower) * 0.5: + distance = dice(question_lower, request_lower) + + if distance > best_distance: + best_distance = distance + best_match = question + best_index = i + + if best_match and best_distance > 0.3: # Threshold for partial match + debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}") + return best_match + + debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...") + return None + + def get_answer(self, question: Dict) -> str: + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + +class Simulator: + def __init__( + self, + port: int = 8033, + host: str = "localhost", + success_rate: float = 0.8, + dataset_split: str = "train" + ): + self.port = port + self.host = host + self.success_rate = success_rate + self.dataset = AimeDataset(dataset_split) + self.eval_state = EvalState( + id="aime-2025", + tasks=["aime"], + task_states={}, + sampling_config={"temperature": 0, "max_tokens": 2048} + ) + + def _generate_response( + self, + question: Dict, + should_be_correct: bool + ) -> Dict: + expected_answer = self.dataset.get_answer(question) + + if should_be_correct: + response_text = expected_answer + else: + response_text = self._generate_wrong_answer(question) + + return { + "id": f"chatcmpl-{int(time.time())}", + "object": "chat.completion", + "created": int(time.time()), + "model": "llama", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response_text + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150 + } + } + + def _generate_wrong_answer(self, question: Dict) -> str: + expected_answer = self.dataset.get_answer(question) + + if expected_answer.isdigit(): + wrong_answer = str(int(expected_answer) + 1) + else: + wrong_answer = expected_answer + " (wrong)" + + return wrong_answer + + def _process_request(self, request_data: Dict) -> Dict: + messages = request_data.get("messages", []) + if not messages: + return {"error": "No messages in request"} + + request_text = messages[0].get("content", "") + debug_log(f"DEBUG: Received request with content: {request_text[:150]}...") + + question = self.dataset.find_question(request_text) + if not question: + debug_log(f"DEBUG: find_question returned None") + return {"error": "No matching question found"} + + should_be_correct = random.random() < self.success_rate + + response = self._generate_response(question, should_be_correct) + + task_id = "aime" + self.eval_state.task_states[task_id] = { + "correct": should_be_correct, + "expected": self.dataset.get_answer(question), + "predicted": response["choices"][0]["message"]["content"] + } + + return response + +class RequestHandler(BaseHTTPRequestHandler): + def do_POST(self): + if self.path != "/v1/chat/completions": + self._send_json({"error": "Not found"}, 404) + return + + try: + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) + request_data = json.loads(body) if body else None + + if not request_data: + self._send_json({"error": "Invalid JSON"}, 400) + return + + if simulator is None: + self._send_json({"error": "Simulator not initialized"}, 500) + return + + response = simulator._process_request(request_data) + self._send_json(response, 200) + + except json.JSONDecodeError: + self._send_json({"error": "Invalid JSON"}, 400) + except Exception as e: + print(f"Error processing request: {e}") + self._send_json({"error": str(e)}, 500) + + def _send_json(self, data: dict, status: int = 200): + body = json.dumps(data).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, format, *args): + # Suppress default request logging + pass + + +def main(): + parser = argparse.ArgumentParser( + description="llama-server simulator for testing eval scripts" + ) + parser.add_argument( + "--port", + type=int, + default=8033, + help="Server port (default: 8033)" + ) + parser.add_argument( + "--host", + type=str, + default="localhost", + help="Server host (default: localhost)" + ) + parser.add_argument( + "--success-rate", + type=float, + default=0.8, + help="Success rate 0-1 (default: 0.8)" + ) + parser.add_argument( + "--dataset-split", + type=str, + default="train", + help="AIME dataset split to use (default: train)" + ) + + args = parser.parse_args() + + global simulator + simulator = Simulator( + port=args.port, + host=args.host, + success_rate=args.success_rate, + dataset_split=args.dataset_split + ) + + server = HTTPServer((args.host, args.port), RequestHandler) + server_thread = threading.Thread(target=server.serve_forever, daemon=True) + server_thread.start() + + print("\n=== llama-server-simulator ===") + print(f"Server running on http://{args.host}:{args.port}") + print(f"Success rate: {args.success_rate}") + print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions") + print("\nPress Ctrl+C to stop\n") + + try: + server_thread.join() + except KeyboardInterrupt: + print("\nShutting down...") + server.shutdown() + +if __name__ == "__main__": + main() diff --git a/examples/llama-eval/test-simulator.sh b/examples/llama-eval/test-simulator.sh new file mode 100755 index 000000000..f3ddf3e95 --- /dev/null +++ b/examples/llama-eval/test-simulator.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +set -e + +# Get the directory where this script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +echo "=== llama-server-simulator Test Script ===" +echo "" + +PORT=8033 +SUCCESS_RATE=0.8 +TEST_PORT=8034 + +echo "Starting simulator on port $PORT with success rate $SUCCESS_RATE..." +source "$SCRIPT_DIR/venv/bin/activate" +python3 "$SCRIPT_DIR/llama-server-simulator.py" --port $PORT --success-rate $SUCCESS_RATE > /tmp/simulator-test.log 2>&1 & +SIMULATOR_PID=$! + +echo "Waiting for simulator to start..." +sleep 5 + +# Helper function to make a request and extract the answer +make_request() { + local question="$1" + curl -s -X POST http://localhost:$PORT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"llama\", + \"messages\": [ + {\"role\": \"user\", \"content\": \"$question\"} + ], + \"temperature\": 0, + \"max_tokens\": 2048 + }" | python3 -c "import sys, json; data = json.load(sys.stdin); print(data.get('choices', [{}])[0].get('message', {}).get('content', data.get('error', 'No response')))" +} + +# Test question (repeated in multiple tests) +TEST_QUESTION="Quadratic polynomials P(x) and Q(x) have leading coefficients 2 and -2, respectively. The graphs of both polynomials pass through the two points (16,54) and (20,53). Find P(0) + Q(0)." + +echo "" +echo "=== Test 1: Correct Answer ===" +echo "Sending request with known question..." +answer=$(make_request "$TEST_QUESTION") +echo "Answer: $answer" +echo "Expected: 116" +echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")" + +echo "" +echo "=== Test 2: Wrong Answer ===" +echo "Sending request with known question (success rate 0.0)..." +answer=$(make_request "$TEST_QUESTION") +echo "Answer: $answer" +echo "Expected: 116" +echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")" + +echo "" +echo "=== Test 3: No Matching Question ===" +echo "Sending request with non-matching text..." +response=$(make_request "What is the capital of France?") +echo "Response: $response" +echo "Expected: No matching question found" +echo "Correct: $([ "$response" == "No matching question found" ] && echo "Yes" || echo "No")" + +echo "" +echo "=== Test 4: Success Rate Verification ===" +echo "Sending 10 requests to test success rate..." +correct_count=0 +for i in {1..10}; do + answer=$(make_request "$TEST_QUESTION") + if [ "$answer" == "116" ]; then + correct_count=$((correct_count + 1)) + fi + echo " Request $i: Answer = $answer" +done +echo "Correct answers: $correct_count/10" +echo "Expected: ~8/10 (80% success rate)" +echo "Success rate: $(echo "scale=1; $correct_count * 10" | bc)%" + +echo "" +echo "=== Test Complete ===" +echo "Stopping simulator..." +kill $SIMULATOR_PID 2>/dev/null +wait $SIMULATOR_PID 2>/dev/null || true + +echo "Simulator stopped."