From b94050e8968f2ea7766d9d47058bf0013ade9333 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 17 Apr 2026 23:24:21 +0800 Subject: [PATCH 01/24] CUDA: use LRU based eviction for cuda graphs (#21611) * CUDA: use a ring-buffer for cuda graphs * bump limit to 128 * use LRU eviction * better naming * do periodic clean-up --- ggml/src/ggml-cuda/common.cuh | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 66ed02d29..ddf50baf4 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1187,6 +1187,7 @@ struct ggml_cuda_graph { bool disable_due_to_gpu_arch = false; bool warmup_complete = false; uint64_t uid = 0; + int64_t last_used_time = 0; struct node_properties { ggml_tensor node; void * node_src_data_ptrs[GGML_MAX_SRC]; @@ -1368,12 +1369,28 @@ struct ggml_backend_cuda_context { // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) std::unordered_map> cuda_graphs; + int64_t last_graph_eviction_sweep = 0; + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + const int64_t time_now = ggml_time_us(); + + // sweep every 5s, evicting cuda graphs unused for >=10s + if (time_now - last_graph_eviction_sweep >= 5'000'000) { + last_graph_eviction_sweep = time_now; + for (auto it = cuda_graphs.begin(); it != cuda_graphs.end(); ) { + if (time_now - it->second->last_used_time >= 10'000'000) { + it = cuda_graphs.erase(it); + } else { + ++it; + } + } + } + auto it = cuda_graphs.find(first_node_ptr); if (it == cuda_graphs.end()) { - cuda_graphs[first_node_ptr] = std::make_unique(); - return cuda_graphs[first_node_ptr].get(); + it = cuda_graphs.emplace(first_node_ptr, std::make_unique()).first; } + it->second->last_used_time = time_now; return it->second.get(); } From 45cac7ca703fb9085eae62b9121fca01d20177f6 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 17 Apr 2026 09:17:11 -0700 Subject: [PATCH 02/24] ggml-webgpu: fix compiler warnings and refactor FlashAttention encoding (#21052) * Update workflows to remove dependence on llvmpipe * Try setting Dawn_DIR * remove c++20 initializers * Move to proper guid * Try avoiding segfaults on vulkan backend process exit * Remove compiler warnings on parameter casting * Fix soft_max and update reg_tile accumulation to f32 for better precision * Refactor flash_attn a bit * remove c++20 initializers and format * Increase div precision for NVIDIA * revert div precision and comment out ggml-ci node for now * Formatting * Try debugging on a failing CI node * Revert "Try debugging on a failing CI node" This reverts commit 1971e33cba919915e12bcfd5828abfbd54ca942e. --- .github/workflows/build-self-hosted.yml | 30 + .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 583 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1498 +++++++---------- .../wgsl-shaders/flash_attn_vec_blk.wgsl | 12 +- 4 files changed, 947 insertions(+), 1176 deletions(-) diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml index 0efe87716..52624a46d 100644 --- a/.github/workflows/build-self-hosted.yml +++ b/.github/workflows/build-self-hosted.yml @@ -97,6 +97,36 @@ jobs: vulkaninfo --summary GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + # TODO: investigate slight precision issues in some operations for test-backend-ops on the WebGPU backend. + #ggml-ci-nvidia-webgpu: + # runs-on: [self-hosted, Linux, NVIDIA] + + # steps: + # - name: Clone + # id: checkout + # uses: actions/checkout@v6 + + # - name: Dawn Dependency + # id: dawn-depends + # run: | + # DAWN_VERSION="v20260317.182325" + # DAWN_OWNER="google" + # DAWN_REPO="dawn" + # DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-ubuntu-latest-Release" + # echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz" + # curl -L -o artifact.tar.gz \ + # "https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz" + # mkdir dawn + # tar -xvf artifact.tar.gz -C dawn --strip-components=1 + + # - name: Test + # id: ggml-ci + # run: | + # GG_BUILD_WEBGPU=1 \ + # GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \ + # GG_BUILD_WEBGPU_DAWN_DIR="$GITHUB_WORKSPACE/dawn/lib64/cmake/Dawn" \ + # bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + # TODO: provision AMX-compatible machine #ggml-ci-cpu-amx: # runs-on: [self-hosted, Linux, CPU, AMX] diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3de6258c7..7d9a4403f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -390,12 +390,11 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - bool use_vec; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec; + uses_logit_softcap == other.uses_logit_softcap; } }; @@ -409,47 +408,37 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.use_vec); return seed; } }; -struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_webgpu_flash_attn_pipeline_key key; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; -}; - -struct ggml_webgpu_flash_attn_shader_decisions { +struct ggml_webgpu_flash_attn_decisions { uint32_t q_tile = 0; uint32_t kv_tile = 0; uint32_t wg_size = 0; }; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - // Keep conservative defaults unless this is the f16 vec-split shape family. - if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { - return 1u; - } +struct ggml_webgpu_flash_attn_vec_decisions { + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; - // Head-dim specializations used by the tuned vec f16 path. - switch (key.head_dim_qk) { - case 64: - return 2u; - case 96: - return 4u; - case 128: - return 1u; - case 192: - return 2u; - case 576: - return 2u; - default: - return 1u; - } +inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( + const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.kv_type = context.src1->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = kv_direct; + key.has_mask = has_mask; + key.has_sinks = has_sinks; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { @@ -471,79 +460,20 @@ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lh return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; } -struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context { - ggml_webgpu_flash_attn_vec_reduce_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_reduce"; - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - variant += std::string("_wg") + std::to_string(context.max_wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - struct ggml_webgpu_flash_attn_blk_pipeline_key { - uint32_t q_tile; uint32_t kv_tile; - bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { - return q_tile == other.q_tile && kv_tile == other.kv_tile; - } + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; } }; struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_tile); ggml_webgpu_hash_combine(seed, key.kv_tile); return seed; } }; -struct ggml_webgpu_flash_attn_blk_shader_lib_context { - ggml_webgpu_flash_attn_blk_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_blk"; - - defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile)); - variant += std::string("_qt") + std::to_string(context.key.q_tile); - - defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile)); - variant += std::string("_kvt") + std::to_string(context.key.kv_tile); - - uint32_t wg_size = 1; - while ((wg_size << 1) <= context.max_wg_size) { - wg_size <<= 1; - } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - variant += std::string("_wg") + std::to_string(wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -568,6 +498,41 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_pipeline_key & key) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!key.kv_direct) { + bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + } + if (key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +} + +inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile)); + kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + return kv_tile; +} + /** Matrix Multiplication **/ struct ggml_webgpu_legacy_mul_mat_pipeline_key { @@ -802,6 +767,8 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_pipelines; std::unordered_map @@ -849,10 +816,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_row_norm_pipeline_key key = { - .op = context.dst->op, - .inplace = context.inplace, - }; + ggml_webgpu_row_norm_pipeline_key key = {}; + key.op = context.dst->op; + key.inplace = context.inplace; auto it = row_norm_pipelines.find(key); if (it != row_norm_pipelines.end()) { @@ -908,9 +874,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, - .vec4 = context.src0->ne[0] % 4 == 0, - .i64_idx = context.src1->type == GGML_TYPE_I64 }; + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -955,7 +922,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; + ggml_webgpu_set_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; auto it = set_pipelines.find(key); if (it != set_pipelines.end()) { @@ -1062,10 +1031,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; - ggml_webgpu_get_rows_pipeline_key key = { - .src_type = context.src0->type, - .vectorized = (int) vectorized, - }; + ggml_webgpu_get_rows_pipeline_key key = {}; + key.src_type = context.src0->type; + key.vectorized = (int) vectorized; auto it = get_rows_pipelines.find(key); if (it != get_rows_pipelines.end()) { @@ -1115,8 +1083,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (key.src_type) - { + switch (key.src_type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1136,9 +1103,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC_TYPE=") + type_str); - } + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } } defines.push_back("BYTE_HELPERS"); @@ -1181,7 +1148,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; + ggml_webgpu_scale_pipeline_key key = {}; + key.inplace = context.inplace; auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { @@ -1208,11 +1176,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_solve_tri_pipeline_key key = { - .type = context.dst->type, - .n = (int) context.src0->ne[0], - .k = (int) context.src1->ne[0], - }; + ggml_webgpu_solve_tri_pipeline_key key = {}; + key.type = context.dst->type; + key.n = (int) context.src0->ne[0]; + key.k = (int) context.src1->ne[0]; auto it = solve_tri_pipelines.find(key); if (it != solve_tri_pipelines.end()) { @@ -1250,10 +1217,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_ssm_conv_pipeline_key key = { - .type = context.dst->type, - .vectorized = context.src1->ne[0] == 4, - }; + ggml_webgpu_ssm_conv_pipeline_key key = {}; + key.type = context.dst->type; + key.vectorized = context.src1->ne[0] == 4; auto it = ssm_conv_pipelines.find(key); if (it != ssm_conv_pipelines.end()) { @@ -1293,11 +1259,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_gated_delta_net_pipeline_key key = { - .type = context.dst->type, - .s_v = (int) context.src2->ne[0], - .kda = context.src3->ne[0] == context.src2->ne[0], - }; + ggml_webgpu_gated_delta_net_pipeline_key key = {}; + key.type = context.dst->type; + key.s_v = (int) context.src2->ne[0]; + key.kda = context.src3->ne[0] == context.src2->ne[0]; auto it = gated_delta_net_pipelines.find(key); if (it != gated_delta_net_pipelines.end()) { @@ -1330,7 +1295,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + ggml_webgpu_pad_pipeline_key key = {}; + key.circular = ggml_get_op_params_i32(context.dst, 8) != 0; auto it = pad_pipelines.find(key); if (it != pad_pipelines.end()) { @@ -1357,15 +1323,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_vec_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - }; + ggml_webgpu_mul_mat_vec_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1451,15 +1415,14 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - .use_subgroup_matrix = context.supports_subgroup_matrix - }; + ggml_webgpu_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -1578,8 +1541,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, - .src1_type = context.src1->type }; + ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_legacy_pipelines.find(key); if (it != mul_mat_legacy_pipelines.end()) { @@ -1621,8 +1585,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (context.src0->type) - { + switch (context.src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1642,9 +1605,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } } defines.push_back("BYTE_HELPERS"); @@ -1689,10 +1652,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_id_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - }; + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -1782,13 +1744,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; - ggml_webgpu_unary_pipeline_key key = { - .type = context.dst->type, - .op = op, - .is_unary = is_unary, - .inplace = context.inplace, - .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), - }; + ggml_webgpu_unary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = op; + key.is_unary = is_unary; + key.inplace = context.inplace; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); auto it = unary_pipelines.find(key); if (it != unary_pipelines.end()) { @@ -1853,13 +1814,12 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, - .src_overlap = context.src_overlap, - }; + ggml_webgpu_binary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = context.dst->op; + key.inplace = context.inplace; + key.overlap = context.overlap; + key.src_overlap = context.src_overlap; auto it = binary_pipelines.find(key); if (it != binary_pipelines.end()) { @@ -1908,9 +1868,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_concat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_concat_pipeline_key key = {}; + key.type = context.dst->type; auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -1945,9 +1904,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_repeat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_repeat_pipeline_key key = {}; + key.type = context.dst->type; auto it = repeat_pipelines.find(key); if (it != repeat_pipelines.end()) { @@ -1985,16 +1943,16 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { - auto it = flash_attn_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; } - std::vector defines; std::string variant = "flash_attn"; - switch (context.key.kv_type) { + switch (key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; @@ -2010,111 +1968,206 @@ class ggml_webgpu_shader_lib { default: GGML_ABORT("Unsupported KV type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(context.key.kv_type); + variant += std::string("_") + ggml_type_name(key.kv_type); - if (context.key.has_mask) { + if (key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (context.key.has_sinks) { + if (key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (context.key.uses_logit_softcap) { + if (key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - if (context.key.kv_direct) { + if (key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } - if (context.key.has_mask && context.key.use_vec) { - defines.push_back("BLK"); - variant += "_blk"; - } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.use_vec) { - q_tile = 1; - kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); - kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; - const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); - defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - } - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + auto decisions = std::make_shared(); + decisions->q_tile = context.sg_mat_m; + + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } } - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + decisions->kv_tile = kv_tile; + decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - uint32_t wg_size = 0; - if (context.key.use_vec) { - wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); - } else { - wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); - const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - pipeline.context = decisions; - flash_attn_pipelines[context.key] = pipeline; - return flash_attn_pipelines[context.key]; + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; } - webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - auto it = flash_attn_blk_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "flash_attn_vec"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + defines.push_back("BLK"); + variant += "_mask_blk"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + defines.push_back("Q_TILE=1"); + + auto decisions = std::make_shared(); + decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + uint32_t vec_ne = 1u; + + // Keep conservative defaults unless this is the f16 vec-split shape family. + if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) { + switch (key.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_blk_pipeline_key key = {}; + key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + auto it = flash_attn_blk_pipelines.find(key); if (it != flash_attn_blk_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_blk_pipelines[context.key] = pipeline; - return flash_attn_blk_pipelines[context.key]; + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile)); + variant += std::string("_kvt") + std::to_string(key.kv_tile); + + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant); + flash_attn_blk_pipelines[key] = pipeline; + return flash_attn_blk_pipelines[key]; } - webgpu_pipeline get_flash_attn_vec_reduce_pipeline( - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - auto it = flash_attn_vec_reduce_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.wg_size = context.max_wg_size; + auto it = flash_attn_vec_reduce_pipelines.find(key); if (it != flash_attn_vec_reduce_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_vec_reduce_pipelines[context.key] = pipeline; - return flash_attn_vec_reduce_pipelines[context.key]; + std::vector defines; + std::string variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant); + flash_attn_vec_reduce_pipelines[key] = pipeline; + return flash_attn_vec_reduce_pipelines[key]; } webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_cpy_pipeline_key key = { - .src_type = context.src0->type, - .dst_type = context.dst->type, - }; + ggml_webgpu_cpy_pipeline_key key = {}; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; auto it = cpy_pipelines.find(key); if (it != cpy_pipelines.end()) { @@ -2166,11 +2219,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_glu_pipeline_key key = { - .glu_op = ggml_get_glu_op(context.dst), - .type = context.dst->type, - .split = (context.src1 != nullptr), - }; + ggml_webgpu_glu_pipeline_key key = {}; + key.glu_op = ggml_get_glu_op(context.dst); + key.type = context.dst->type; + key.split = (context.src1 != nullptr); auto it = glu_pipelines.find(key); if (it != glu_pipelines.end()) { @@ -2239,11 +2291,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_rope_pipeline_key key = { - .type = context.dst->type, - .inplace = context.inplace, - .has_ff = (context.src2 != nullptr), - }; + ggml_webgpu_rope_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; + key.has_ff = (context.src2 != nullptr); auto it = rope_pipelines.find(key); if (it != rope_pipelines.end()) { @@ -2288,12 +2339,11 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_soft_max_pipeline_key key = { - .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32, - .has_mask = (context.src1 != nullptr), - .has_sink = (context.src2 != nullptr), - .inplace = context.inplace, - }; + ggml_webgpu_soft_max_pipeline_key key = {}; + key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32; + key.has_mask = (context.src1 != nullptr); + key.has_sink = (context.src2 != nullptr); + key.inplace = context.inplace; auto it = soft_max_pipelines.find(key); if (it != soft_max_pipelines.end()) { @@ -2359,25 +2409,6 @@ class ggml_webgpu_shader_lib { pipeline_desc.layout = nullptr; // nullptr means auto layout return { device.CreateComputePipeline(&pipeline_desc), label }; } - - static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; - } }; #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 01637e2dd..e7bda817a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -41,6 +41,12 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim wg_x = CEIL_DIV(total_wg, wg_y); } +static inline uint32_t ggml_webgpu_u32_from_f32(float value) { + uint32_t bits; + memcpy(&bits, &value, sizeof(bits)); + return bits; +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -369,6 +375,96 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} + +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint32_t k_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + + return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); +} + +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); +} + +// Used to determine if two tensors are the same for in-place operations +static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); +} + +// Used to determine if two tensors share the same buffer and their byte ranges overlap, +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + +struct binary_overlap_flags { + bool inplace; // src0 == dst + bool overlap; // src1 == dst + bool src_overlap; +}; + +static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = {}; + flags.inplace = ggml_webgpu_tensor_equal(src0, dst); + flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + + return flags; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, + wgpu::Buffer buffer, + uint64_t offset, + uint64_t size) { + wgpu::BindGroupEntry entry = {}; + entry.binding = binding; + entry.buffer = std::move(buffer); + entry.offset = offset; + entry.size = size; + return entry; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_tensor_bind_group_entry(webgpu_context & ctx, + uint32_t binding, + ggml_tensor * tensor) { + return ggml_webgpu_make_bind_group_entry(binding, ggml_webgpu_tensor_buf(tensor), + ggml_webgpu_tensor_align_offset(ctx, tensor), + ggml_webgpu_tensor_binding_size(ctx, tensor)); +} + /** End WebGPU object initializations */ /** WebGPU Actions */ @@ -480,10 +576,8 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); - entries.push_back({ .binding = params_binding_num, - .buffer = ctx->param_arena.buffer, - .offset = param_offset, - .size = ctx->param_arena.slot_size }); + entries.push_back(ggml_webgpu_make_bind_group_entry(params_binding_num, ctx->param_arena.buffer, param_offset, + ctx->param_arena.slot_size)); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); @@ -502,13 +596,17 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & #ifdef GGML_WEBGPU_GPU_PROFILE for (size_t i = 0; i < dispatches.size(); i++) { GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); - const uint32_t query_begin = ctx->profile_timestamp_query_count++; - const uint32_t query_end = ctx->profile_timestamp_query_count++; - wgpu::PassTimestampWrites ts_writes = { .querySet = ctx->profile_timestamp_query_set, - .beginningOfPassWriteIndex = query_begin, - .endOfPassWriteIndex = query_end }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + + wgpu::PassTimestampWrites ts_writes = {}; + ts_writes.querySet = ctx->profile_timestamp_query_set; + ts_writes.beginningOfPassWriteIndex = query_begin; + ts_writes.endOfPassWriteIndex = query_end; + wgpu::ComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &ts_writes; + + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); pass.SetPipeline(dispatches[i].pipeline.pipeline); pass.SetBindGroup(0, bind_groups[i]); @@ -544,17 +642,19 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector entries = { - { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } - }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); - entries.push_back( - { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES }); + wgpu::BindGroupEntry params_entry = {}; + params_entry.binding = 1; + params_entry.buffer = ctx->memset_params_buf; + params_entry.offset = 0; + params_entry.size = WEBGPU_PARAMS_BUF_SIZE_BYTES; + entries.push_back(params_entry); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); @@ -632,65 +732,11 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { delete backend; } -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - -static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; -} - -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); -} - -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); -} - -// Used to determine if two tensors share the same buffer and their byte ranges overlap, -static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && - ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); -} - -struct binary_overlap_flags { - bool inplace; // src0 == dst - bool overlap; // src1 == dst - bool src_overlap; -}; - -static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = {}; - flags.inplace = ggml_webgpu_tensor_equal(src0, dst); - flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); - - return flags; -} - static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); @@ -712,14 +758,8 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -732,13 +772,12 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); @@ -772,29 +811,21 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, std::vector entries; uint32_t binding_index = 0; if (!inplace) { - entries.push_back({ .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); binding_index++; } - entries.push_back({ .binding = binding_index, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); - entries.push_back({ .binding = binding_index + 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index + 1, dst)); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); @@ -832,14 +863,8 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -850,13 +875,12 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); @@ -888,18 +912,9 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); @@ -911,12 +926,11 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -944,18 +958,9 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); @@ -971,15 +976,14 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, ggml_tensor * src4, ggml_tensor * src5, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .src3 = src3, - .src4 = src4, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = src3; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); @@ -1015,34 +1019,10 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(src3), - .offset = ggml_webgpu_tensor_align_offset(ctx, src3), - .size = ggml_webgpu_tensor_binding_size(ctx, src3) }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(src4), - .offset = ggml_webgpu_tensor_align_offset(ctx, src4), - .size = ggml_webgpu_tensor_binding_size(ctx, src4) }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(src5), - .offset = ggml_webgpu_tensor_align_offset(ctx, src5), - .size = ggml_webgpu_tensor_binding_size(ctx, src5) }, - { .binding = 6, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, dst), }; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); @@ -1058,12 +1038,11 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct return std::nullopt; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = idx, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = idx; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); @@ -1086,25 +1065,14 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; if (decisions->i64_idx) { - entries.push_back({ .binding = 3, - .buffer = ctx->set_rows_dev_error_buf, - .offset = 0, - .size = ctx->set_rows_dev_error_buf.GetSize() }); + entries.push_back(ggml_webgpu_make_bind_group_entry(3, ctx->set_rows_dev_error_buf, 0, + ctx->set_rows_dev_error_buf.GetSize())); } uint32_t threads; @@ -1131,12 +1099,11 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = WEBGPU_MAX_WG_SIZE, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1160,20 +1127,9 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst) }; uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); @@ -1225,17 +1181,16 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline webgpu_pipeline pipeline; @@ -1270,18 +1225,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Build bind group entries std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; // Calculate workgroup dimensions @@ -1333,13 +1279,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; // Get or create pipeline webgpu_pipeline gather_pipeline, main_pipeline; @@ -1380,22 +1325,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id_gather.wgsl std::vector gather_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -1427,30 +1364,18 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id.wgsl std::vector main_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(4, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(5, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; // Calculate workgroup dimensions @@ -1486,11 +1411,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * mask, ggml_tensor * sinks, ggml_tensor * dst) { - float scale = *(float *) dst->op_params; - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float logit_softcap; - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + float scale = ggml_get_op_params_f32(dst, 0); + float max_bias = ggml_get_op_params_f32(dst, 1); + float logit_softcap = ggml_get_op_params_f32(dst, 2); if (logit_softcap != 0.0f) { scale /= logit_softcap; } @@ -1522,86 +1445,53 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) - *(uint32_t *) &max_bias, - *(uint32_t *) &logit_softcap, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(logit_softcap), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; uint32_t binding_index = 3; if (has_mask) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } if (has_sinks) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); } - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); + webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : + ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && - (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + if (!use_vec) { + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + } - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .use_vec = use_vec, - }; - - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - - auto * decisions = static_cast(pipeline.context.get()); - - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1609,197 +1499,162 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - if (use_vec) { - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - const bool use_vec_reduce = nwg > 1u; - GGML_ASSERT(nrows <= UINT32_MAX); + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); - uint64_t tmp_stats_base = 0; - uint64_t tmp_size_bytes = 0; - wgpu::Buffer tmp_buf = {}; - uint64_t tmp_bind_offset = 0; - uint64_t tmp_bind_size = 0; - const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); - if (use_vec_reduce) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - tmp_stats_base = tmp_data_elems; - tmp_size_bytes = - ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - GGML_ASSERT(tmp_stats_base <= UINT32_MAX); - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = scratch_offset; - tmp_bind_size = tmp_size_bytes; - scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); - } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); - } - - webgpu_pipeline blk_pipeline; - std::vector blk_params; - std::vector blk_entries; - if (use_blk) { - GGML_ASSERT(has_mask); - - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { - .key = - { - .q_tile = decisions->q_tile, - .kv_tile = decisions->kv_tile, - }, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); - - blk_params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) K->ne[1], // seq_len_kv - stride_mask3, // stride_mask3 - blk_nblk0, // nblk0 - blk_nblk1, // nblk1 - }; - blk_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, - { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, - }; - scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); - } - - std::vector split_params = params; - if (use_blk) { - split_params.push_back(0u); // blk_base - split_params.push_back(blk_nblk0); // blk_nblk0 - split_params.push_back(blk_nblk1); // blk_nblk1 - } - split_params.push_back(0u); // tmp_data_base - split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base - split_params.push_back(nwg); // nwg - - std::vector split_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, - }; - uint32_t split_binding_index = 3; - if (has_mask) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); - } - if (has_sinks) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - if (use_blk) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = blk_buf, - .offset = blk_entries[1].offset, - .size = blk_size_bytes }); - } - split_entries.push_back( - { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - webgpu_pipeline reduce_pipeline; - std::vector reduce_params; - std::vector reduce_entries; - if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, - std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = - { - .head_dim_v = (uint32_t) V->ne[0], - .wg_size = reduce_wg_size, - }, - .max_wg_size = reduce_wg_size, - }; - reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); - - reduce_params = { - (uint32_t) nrows, // nrows - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) Q->ne[2], // n_heads - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst - nwg, // nwg - 0u, // tmp_data_base - (uint32_t) tmp_stats_base, // tmp_stats_base - }; - - reduce_entries = { - { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; - } - - const uint64_t split_wg_total = (uint64_t) wg_x * nwg; - GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector dispatches; - - if (use_blk) { - dispatches.push_back({ - blk_pipeline, - std::move(blk_params), - std::move(blk_entries), - { blk_nblk0, blk_nblk1 * blk_batch_count } - }); - } - dispatches.push_back({ - pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } - }); - if (use_vec_reduce) { - dispatches.push_back({ - reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } - }); - } - - return ggml_backend_webgpu_build_multi(ctx, dispatches); + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (has_mask) { + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = (uint32_t) Q->ne[1]; + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 + }; + blk_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask)), + ggml_webgpu_make_bind_group_entry(1, blk_buf, scratch_offset, blk_size_bytes), + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (has_mask) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), + ggml_webgpu_tensor_binding_size(ctx, Q)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V)), + }; + uint32_t split_binding_index = 3; + if (has_mask) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask))); + } + if (has_sinks) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), + ggml_webgpu_tensor_align_offset(ctx, sinks), + ggml_webgpu_tensor_binding_size(ctx, sinks))); + } + if (has_mask) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); + } + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, tmp_buf, tmp_bind_offset, tmp_bind_size)); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(dst), + ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst))); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + reduce_shader_ctx.max_wg_size = reduce_wg_size; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; + + reduce_entries = { + ggml_webgpu_make_bind_group_entry(0, tmp_buf, tmp_bind_offset, tmp_size_bytes), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + } + + uint32_t wg_x = Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + + std::vector dispatches; + + if (has_mask) { + dispatches.push_back({ + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } + }); + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); + } + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #endif // __EMSCRIPTEN__ @@ -1807,13 +1662,12 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); @@ -1844,10 +1698,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor float alpha_p = ggml_get_op_params_f32(dst, 2); float beta = ggml_get_op_params_f32(dst, 3); float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast(&alpha_n)); - params.push_back(*reinterpret_cast(&alpha_p)); - params.push_back(*reinterpret_cast(&beta)); - params.push_back(*reinterpret_cast(&eps)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_n)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_p)); + params.push_back(ggml_webgpu_u32_from_f32(beta)); + params.push_back(ggml_webgpu_u32_from_f32(eps)); break; } default: @@ -1856,25 +1710,19 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor } else if (dst->op == GGML_OP_CLAMP) { float clamp_min = ggml_get_op_params_f32(dst, 0); float clamp_max = ggml_get_op_params_f32(dst, 1); - params.push_back(*reinterpret_cast(&clamp_min)); - params.push_back(*reinterpret_cast(&clamp_max)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_min)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_max)); } else if (dst->op == GGML_OP_FILL) { float fill_val = ggml_get_op_params_f32(dst, 0); - params.push_back(*reinterpret_cast(&fill_val)); + params.push_back(ggml_webgpu_u32_from_f32(fill_val)); effective_src = dst; // fill simply fills dst } std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(effective_src), - .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src), - .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, effective_src), }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -1887,15 +1735,14 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = flags.inplace, - .overlap = flags.overlap, - .src_overlap = flags.src_overlap, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = flags.inplace; + shader_lib_ctx.overlap = flags.overlap; + shader_lib_ctx.src_overlap = flags.src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1944,38 +1791,18 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = merged_offset, - .size = merged_end - merged_offset, - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = src0_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = src1_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), + src0_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src0))); + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), + src1_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src1))); if (!flags.inplace && !flags.overlap) { - entries.push_back({ - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } @@ -2012,26 +1839,16 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2059,21 +1876,14 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * (uint32_t) (dst->ne[2]) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2097,28 +1907,19 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); @@ -2129,14 +1930,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); @@ -2187,41 +1987,27 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, (uint32_t) src0->ne[2], (uint32_t) n_dims, (uint32_t) mode, - *(uint32_t *) &theta_scale, - *(uint32_t *) &attn_factor, - *(uint32_t *) &freq_scale, - *(uint32_t *) &ext_factor, - *(uint32_t *) &corr_dims[0], - *(uint32_t *) &corr_dims[1], + ggml_webgpu_u32_from_f32(theta_scale), + ggml_webgpu_u32_from_f32(attn_factor), + ggml_webgpu_u32_from_f32(freq_scale), + ggml_webgpu_u32_from_f32(ext_factor), + ggml_webgpu_u32_from_f32(corr_dims[0]), + ggml_webgpu_u32_from_f32(corr_dims[1]), (uint32_t) sections[0], (uint32_t) sections[1], (uint32_t) sections[2], (uint32_t) sections[3] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - uint32_t dst_binding = 2; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1) }; + uint32_t dst_binding = 2; if (has_freq_factor) { dst_binding = 3; - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); } if (!inplace) { - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2232,12 +2018,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); @@ -2265,29 +2050,20 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) ((int32_t *) dst->op_params)[1], // swapped - *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai - *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 2)), // alpha, for swiglu_oai + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; uint32_t dst_binding = 1; if (split) { dst_binding = 2; - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); } - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -2296,13 +2072,12 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2321,23 +2096,15 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &dst->op_params[1] // bias + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 1)) // bias }; // bindgroups unchanged - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2349,25 +2116,23 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_mask = (src1 != nullptr); - const int has_sink = (src2 != nullptr); - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -2389,39 +2154,29 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[2], has_mask ? (uint32_t) src1->ne[2] : 0, has_mask ? (uint32_t) src1->ne[3] : 0, - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &max_bias, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) } - }; - uint32_t binding_num = 1; + std::vector entries = { ggml_webgpu_make_bind_group_entry( + 0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)) }; + uint32_t binding_num = 1; if (has_mask) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_bind_group_entry(binding_num, ggml_webgpu_tensor_buf(src1), + ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1))); binding_num++; } if (has_sink) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, src2)); binding_num++; } if (!inplace) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, dst)); } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); @@ -2432,20 +2187,13 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); @@ -2455,13 +2203,12 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); @@ -2527,11 +2274,8 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * const uint32_t wg_x_init = std::min(total_wg_init, max_wg); const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); std::vector init_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) }; dispatches.push_back({ @@ -2580,12 +2324,9 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * nrows }; std::vector merge_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in }, - { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), align_in, size_in), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; const uint32_t total_wg_merge = nm * nrows; @@ -2607,23 +2348,14 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); @@ -2641,20 +2373,13 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor total_sum ? 1 : (uint32_t) src->ne[1], total_sum ? 1 : (uint32_t) src->ne[2] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); @@ -3133,40 +2858,24 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * mask = tensor->src[3]; const ggml_tensor * sinks = tensor->src[4]; if (Q && K && V) { - GGML_UNUSED(sinks); - const bool kv_direct = (K->type == GGML_TYPE_F16) && - (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && - kv_vec_type_supported && (V->type == K->type); - if (use_vec) { - const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - const size_t limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!kv_direct) { - bytes_per_kv += std::max(Q->ne[0], V->ne[0]); - } - if (mask != nullptr) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; - if (kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= sg_mat_n; - } - } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = const_cast(Q); + shader_lib_ctx.src1 = const_cast(K); + shader_lib_ctx.src2 = const_cast(V); + shader_lib_ctx.src3 = const_cast(mask); + shader_lib_ctx.src4 = const_cast(sinks); + shader_lib_ctx.dst = const_cast(tensor); + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) { + const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx); const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); @@ -3271,8 +2980,9 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct } static ggml_guid_t ggml_backend_webgpu_guid(void) { - static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast((void *) guid_str); + static ggml_guid guid = { 0x67, 0xc7, 0xa4, 0xb1, 0x78, 0x74, 0x4f, 0x51, + 0x9d, 0x65, 0x44, 0x6d, 0xe4, 0x1b, 0x82, 0x9a }; + return &guid; } static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { @@ -3931,20 +3641,23 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - static ggml_backend_webgpu_reg_context ctx; - static ggml_backend_reg reg = { + // Intentionally leak the global registry context to avoid crashing inside + // Dawn/Vulkan static teardown during process exit. + static ggml_backend_webgpu_reg_context * ctx = new ggml_backend_webgpu_reg_context(); + + static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, + /* .context = */ ctx, }; - ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 0; + ctx->name = GGML_WEBGPU_NAME; + ctx->device_count = 0; // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend // registry. Recreating it on repeated registry lookups can invalidate // adapter/device references that are still held by the backend/device layer. - if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx != nullptr && ctx->webgpu_global_ctx->instance != nullptr) { return ® } @@ -3961,17 +3674,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); - ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); - ctx.webgpu_global_ctx->instance = std::move(inst); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx->webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx->webgpu_global_ctx->instance = std::move(inst); // Probe for adapter support wgpu::Adapter adapter; - if (ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx->instance != nullptr) { wgpu::RequestAdapterOptions options = {}; - ctx.webgpu_global_ctx->instance.WaitAny( - ctx.webgpu_global_ctx->instance.RequestAdapter( + // probe for adapter support + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { @@ -3984,7 +3698,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } if (adapter != nullptr) { - ctx.device_count = 1; + ctx->device_count = 1; } return ® diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index 82d072be7..61107c6a9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -1,7 +1,6 @@ diagnostic(off, subgroup_uniformity); enable f16; -#define Q_TILE 1 #define KV_TILE 32 #define WG_SIZE 32 @@ -11,7 +10,7 @@ struct Params { seq_len_kv: u32, stride_mask3: u32, // Number of KV blocks and Q blocks per batch. - // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q. nblk0: u32, nblk1: u32, }; @@ -40,7 +39,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, return; } - let q_start = q_blk * Q_TILE; + let q_start = q_blk; let k_start = kv_blk * KV_TILE; let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); @@ -54,11 +53,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var local_max = -MASK_MAX; var local_any = 0u; - for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) { - let q_row = q_start + q_rel; - if (q_row >= params.seq_len_q) { - continue; - } + let q_row = q_start; + if (q_row < params.seq_len_q) { let row_base = mask_batch_base + q_row * params.seq_len_kv; for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { let k_col = k_start + k_rel; From fd1c0ec3f037f47f8ed76d99d0fd3cacdcc0baab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 18 Apr 2026 08:16:04 +0200 Subject: [PATCH 03/24] llama: fit ctx size for CPU only (#21568) --- src/llama.cpp | 157 +++++++++++++++++++++++++++++++------------------- 1 file changed, 99 insertions(+), 58 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 484372d8d..7d3a34e98 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -91,12 +91,16 @@ static std::vector llama_get_device_memory_data( throw std::runtime_error("failed to create llama_context from model"); } - std::vector ret(model->devices.size()); + const size_t nd = model->n_devices(); + std::vector ret(nd + 1); std::map memory_breakdown = ctx->memory_breakdown(); for (const auto & [buft, mb] : memory_breakdown) { if (ggml_backend_buft_is_host(buft)) { + ret.back().mb.model += mb.model; + ret.back().mb.context += mb.context; + ret.back().mb.compute += mb.compute; continue; } @@ -104,7 +108,7 @@ static std::vector llama_get_device_memory_data( if (!dev) { continue; } - for (size_t i = 0; i < ret.size(); i++) { + for (size_t i = 0; i < nd; i++) { if (model->devices[i].dev == dev) { ret[i].mb.model += mb.model; ret[i].mb.context += mb.context; @@ -113,7 +117,19 @@ static std::vector llama_get_device_memory_data( } } } - for (size_t i = 0; i < ret.size(); i++) { + + { + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + size_t free; + size_t total; + ggml_backend_dev_memory(cpu_dev, &free, &total); + ret.back().free = free; + ret.back().total = total; + } + for (size_t i = 0; i < nd; i++) { size_t free; size_t total; ggml_backend_dev_memory(model->devices[i].dev, &free, &total); @@ -122,11 +138,8 @@ static std::vector llama_get_device_memory_data( // have any to report. in this case, we will use the host memory as a fallback // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 if (free == 0 && total == 0) { - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - ggml_backend_dev_memory(cpu_dev, &free, &total); + free = ret.back().free; + total = ret.back().total; } ret[i].free = free; ret[i].total = total; @@ -180,15 +193,15 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__); const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); const size_t nd = devs.size(); // number of devices - if (nd == 0) { - LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__); - return; - } std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits margins.reserve(nd); - for (size_t id = 0; id < nd; id++) { - margins.push_back(margins_s[id]); + if (nd == 0) { + margins.push_back(margins_s[0]); + } else { + for (size_t id = 0; id < nd; id++) { + margins.push_back(margins_s[id]); + } } std::vector dev_names; @@ -215,46 +228,59 @@ static void llama_params_fit_impl( std::vector projected_free_per_device; projected_free_per_device.reserve(nd); - if (nd > 1) { - LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); - } - for (size_t id = 0; id < nd; id++) { - const llama_device_memory_data & dmd = dmds_full[id]; - - const int64_t projected_used = dmd.mb.total(); - const int64_t projected_free = dmd.free - projected_used; - projected_free_per_device.push_back(projected_free); - - sum_free += dmd.free; - sum_projected_used += projected_used; - sum_projected_free += projected_free; - sum_projected_model += dmd.mb.model; - - if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); - } - } - assert(sum_free >= 0 && sum_projected_used >= 0); - LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", - __func__, sum_projected_used/MiB, sum_free/MiB); - if (nd == 1) { - if (projected_free_per_device[0] >= margins[0]) { - LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); + if (nd == 0) { + sum_projected_used = dmds_full.back().mb.total(); + sum_free = dmds_full.back().total; + sum_projected_free = sum_free - sum_projected_used; + LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n", + __func__, sum_projected_used/MiB, sum_free/MiB); + if (sum_projected_free >= margins[0]) { + LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n", + __func__, sum_projected_free/MiB, margins[0]/MiB); return; } } else { - bool changes_needed = false; + if (nd > 1) { + LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); + } for (size_t id = 0; id < nd; id++) { - if (projected_free_per_device[id] < margins[id]) { - changes_needed = true; - break; + const llama_device_memory_data & dmd = dmds_full[id]; + + const int64_t projected_used = dmd.mb.total(); + const int64_t projected_free = dmd.free - projected_used; + projected_free_per_device.push_back(projected_free); + + sum_free += dmd.free; + sum_projected_used += projected_used; + sum_projected_free += projected_free; + sum_projected_model += dmd.mb.model; + + if (nd > 1) { + LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", + __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); } } - if (!changes_needed) { - LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); - return; + assert(sum_free >= 0 && sum_projected_used >= 0); + LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", + __func__, sum_projected_used/MiB, sum_free/MiB); + if (nd == 1) { + if (projected_free_per_device[0] >= margins[0]) { + LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", + __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); + return; + } + } else { + bool changes_needed = false; + for (size_t id = 0; id < nd; id++) { + if (projected_free_per_device[id] < margins[id]) { + changes_needed = true; + break; + } + } + if (!changes_needed) { + LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); + return; + } } } @@ -262,11 +288,15 @@ static void llama_params_fit_impl( { int64_t global_surplus = sum_projected_free; - for (size_t id = 0; id < nd; id++) { - global_surplus -= margins[id]; + if (nd == 0) { + global_surplus -= margins[0]; + } else { + for (size_t id = 0; id < nd; id++) { + global_surplus -= margins[id]; + } } if (global_surplus < 0) { - if (nd == 1) { + if (nd <= 1) { LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", __func__, margins[0]/MiB, -global_surplus/MiB); } else { @@ -277,8 +307,12 @@ static void llama_params_fit_impl( if (cparams->n_ctx == 0) { if (hp_nct > n_ctx_min) { int64_t sum_used_target = sum_free; - for (size_t id = 0; id < nd; id++) { - sum_used_target -= margins[id]; + if (nd == 0) { + sum_used_target -= margins[0]; + } else { + for (size_t id = 0; id < nd; id++) { + sum_used_target -= margins[id]; + } } if (nd > 1) { // for multiple devices we need to be more conservative in terms of how much context we think can fit: @@ -293,8 +327,12 @@ static void llama_params_fit_impl( int64_t sum_projected_used_min_ctx = 0; cparams->n_ctx = n_ctx_min; const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const auto & dmd : dmds_min_ctx) { - sum_projected_used_min_ctx += dmd.mb.total(); + if (nd == 0) { + sum_projected_used_min_ctx = dmds_min_ctx.back().mb.total(); + } else { + for (size_t id = 0; id < nd; id++) { + sum_projected_used_min_ctx += dmds_min_ctx[id].mb.total(); + } } if (sum_used_target > sum_projected_used_min_ctx) { // linear interpolation between minimum and maximum context size: @@ -306,7 +344,7 @@ static void llama_params_fit_impl( const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - if (nd == 1) { + if (nd <= 1) { LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); return; } @@ -329,6 +367,9 @@ static void llama_params_fit_impl( } } } + if (nd == 0) { + throw llama_params_fit_exception("was unable to fit model into system memory by reducing context, abort"); + } if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); @@ -476,8 +517,8 @@ static void llama_params_fit_impl( std::vector ret; ret.reserve(nd); - for (const llama_device_memory_data & dmd : dmd_nl) { - ret.push_back(dmd.mb.total()); + for (size_t id = 0; id < nd; id++) { + ret.push_back(dmd_nl[id].mb.total()); } return ret; }; From 89a5474f0e7d96450adf1764a2c79f4e2d55fa8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 18 Apr 2026 09:36:41 +0200 Subject: [PATCH 04/24] convert : fix (ignore for now) typings errors (#22002) --- convert_hf_to_gguf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 42d559dfe..2df5e94fe 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10912,14 +10912,14 @@ class NemotronHModel(GraniteHybridModel): vocab_size = -(vocab_size // -pad_vocab) * pad_vocab self.hparams["vocab_size"] = vocab_size - assert max(tokenizer.vocab.values()) < vocab_size + assert max(tokenizer.vocab.values()) < vocab_size # ty: ignore[unresolved-attribute] tokpre = self.get_vocab_base_pre(tokenizer) - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} - added_vocab = tokenizer.get_added_vocab() + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} # ty: ignore[unresolved-attribute] + added_vocab = tokenizer.get_added_vocab() # ty: ignore[unresolved-attribute] - added_tokens_decoder = tokenizer.added_tokens_decoder + added_tokens_decoder = tokenizer.added_tokens_decoder # ty: ignore[unresolved-attribute] for i in range(vocab_size): if i not in reverse_vocab: @@ -10930,7 +10930,7 @@ class NemotronHModel(GraniteHybridModel): if token in added_vocab: if not added_tokens_decoder[i].normalized: previous_token = token - token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) # ty: ignore[unresolved-attribute, invalid-assignment] if previous_token != token: logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") From 83d58e02fcf7131fa601e366bf2460ecd9db5a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 18 Apr 2026 09:37:30 +0200 Subject: [PATCH 05/24] ci : free disk space for rocm release (#22012) --- .github/workflows/release.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8a49715b3..f1cc12cd4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -687,6 +687,11 @@ jobs: with: fetch-depth: 0 + - name: Free up disk space + uses: ggml-org/free-disk-space@v1.3.1 + with: + tool-cache: true + - name: ccache uses: ggml-org/ccache-action@v1.2.21 with: From 59accc8863f0e715442496300540692da54ab0c5 Mon Sep 17 00:00:00 2001 From: SamareshSingh <97642706+ssam18@users.noreply.github.com> Date: Sat, 18 Apr 2026 03:04:51 -0500 Subject: [PATCH 06/24] ggml-backend-meta: add multi-segment read support in get_tensor (#22063) --- ggml/src/ggml-backend-meta.cpp | 40 +++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 1ee3eeb4d..24f6bc063 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1270,7 +1270,45 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co GGML_ASSERT(ggml_is_contiguous(tensor)); const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - GGML_ASSERT(split_state.n_segments == 1); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[1] == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[2] == size); + return; + } switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: From 23b8cc49911d502d8e9710e98ad5dee5b637cf79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 18 Apr 2026 11:19:40 +0200 Subject: [PATCH 07/24] android : libcommon -> libllama-common (#22076) --- examples/llama.android/lib/src/main/cpp/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama.android/lib/src/main/cpp/CMakeLists.txt b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt index 7862c61a3..20c9e3b2c 100644 --- a/examples/llama.android/lib/src/main/cpp/CMakeLists.txt +++ b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt @@ -51,6 +51,6 @@ target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE target_link_libraries(${CMAKE_PROJECT_NAME} llama - common + llama-common android log) From 4f02d4733934179386cbc15b3454be26237940bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 18 Apr 2026 20:12:00 +0200 Subject: [PATCH 08/24] model : refactor bias tensor variable names (#22079) * refactor bias tensor variable names * use create_tensor_qkv for jina-bert-v2 --- src/llama-graph.cpp | 18 +++---- src/llama-model.cpp | 88 ++++++++++++++++------------------ src/llama-model.h | 9 +--- src/models/apertus.cpp | 2 +- src/models/arcee.cpp | 2 +- src/models/bailingmoe.cpp | 2 +- src/models/bailingmoe2.cpp | 2 +- src/models/bert.cpp | 2 +- src/models/bitnet.cpp | 4 +- src/models/bloom.cpp | 2 +- src/models/codeshell.cpp | 2 +- src/models/cohere2-iswa.cpp | 2 +- src/models/command-r.cpp | 2 +- src/models/deci.cpp | 2 +- src/models/deepseek.cpp | 2 +- src/models/dots1.cpp | 2 +- src/models/dream.cpp | 2 +- src/models/exaone.cpp | 2 +- src/models/gpt2.cpp | 2 +- src/models/gptneox.cpp | 2 +- src/models/granite-hybrid.cpp | 2 +- src/models/granite.cpp | 2 +- src/models/grok.cpp | 2 +- src/models/grovemoe.cpp | 2 +- src/models/hunyuan-dense.cpp | 2 +- src/models/hunyuan-moe.cpp | 2 +- src/models/internlm2.cpp | 2 +- src/models/jais.cpp | 2 +- src/models/jais2.cpp | 2 +- src/models/llama.cpp | 2 +- src/models/llama4.cpp | 2 +- src/models/maincoder.cpp | 2 +- src/models/mistral3.cpp | 2 +- src/models/mpt.cpp | 2 +- src/models/nemotron-h.cpp | 2 +- src/models/nemotron.cpp | 2 +- src/models/openai-moe-iswa.cpp | 2 +- src/models/paddleocr.cpp | 2 +- src/models/pangu-embedded.cpp | 2 +- src/models/phi2.cpp | 2 +- src/models/phi3.cpp | 2 +- src/models/qwen2.cpp | 2 +- src/models/qwen2moe.cpp | 2 +- src/models/qwen2vl.cpp | 2 +- src/models/qwen3.cpp | 2 +- src/models/qwen3moe.cpp | 2 +- src/models/qwen3vl-moe.cpp | 2 +- src/models/qwen3vl.cpp | 2 +- src/models/rnd1.cpp | 2 +- src/models/seed-oss.cpp | 2 +- src/models/smallthinker.cpp | 2 +- src/models/smollm3.cpp | 2 +- src/models/starcoder.cpp | 2 +- src/models/starcoder2.cpp | 2 +- src/models/t5.cpp | 2 +- 55 files changed, 104 insertions(+), 117 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 66cffa461..2ff23f87c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1077,9 +1077,9 @@ llm_graph_qkv llm_graph_context::build_qkv( // fused QKV path ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s); cb(qkv, "wqkv", il); - if (layer.bqkv) { - qkv = ggml_add(ctx0, qkv, layer.bqkv); - cb(qkv, "bqkv", il); + if (layer.wqkv_b) { + qkv = ggml_add(ctx0, qkv, layer.wqkv_b); + cb(qkv, "wqkv_b", il); } if (hparams.f_clamp_kqv > 0.0f) { qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); @@ -1097,8 +1097,8 @@ llm_graph_qkv llm_graph_context::build_qkv( // separate Q/K/V path Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); cb(Qcur, "Qcur", il); - if (layer.bq) { - Qcur = ggml_add(ctx0, Qcur, layer.bq); + if (layer.wq_b) { + Qcur = ggml_add(ctx0, Qcur, layer.wq_b); cb(Qcur, "Qcur", il); } if (hparams.f_clamp_kqv > 0.0f) { @@ -1107,8 +1107,8 @@ llm_graph_qkv llm_graph_context::build_qkv( } Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); cb(Kcur, "Kcur", il); - if (layer.bk) { - Kcur = ggml_add(ctx0, Kcur, layer.bk); + if (layer.wk_b) { + Kcur = ggml_add(ctx0, Kcur, layer.wk_b); cb(Kcur, "Kcur", il); } if (hparams.f_clamp_kqv > 0.0f) { @@ -1117,8 +1117,8 @@ llm_graph_qkv llm_graph_context::build_qkv( } Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); cb(Vcur, "Vcur", il); - if (layer.bv) { - Vcur = ggml_add(ctx0, Vcur, layer.bv); + if (layer.wv_b) { + Vcur = ggml_add(ctx0, Vcur, layer.wv_b); cb(Vcur, "Vcur", il); } if (hparams.f_clamp_kqv > 0.0f) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d9781d7d2..4ded484dd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3098,14 +3098,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); if (layer.wqkv) { - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); } else { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); } }; @@ -3138,7 +3138,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3201,7 +3201,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -3336,9 +3336,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); } - // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); if (n_ff > 0) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3558,10 +3557,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3602,8 +3601,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3719,23 +3718,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; // JinaBertLayer - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3783,10 +3775,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3819,10 +3811,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); @@ -3889,7 +3881,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4068,7 +4060,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -4127,7 +4119,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); @@ -4291,10 +4283,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4329,7 +4321,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4646,7 +4638,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4890,7 +4882,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } // feed forward (w/ optional biases) @@ -5152,10 +5144,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5570,10 +5562,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5612,10 +5604,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // attention biases - all have shape n_embd (output dimension of projections) - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5918,7 +5910,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5987,7 +5979,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; @@ -6808,7 +6800,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6890,7 +6882,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // attention layers (with optional bias) create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); @@ -7026,7 +7018,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); @@ -7191,7 +7183,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // optional bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); @@ -7422,7 +7414,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // bias tensors - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); diff --git a/src/llama-model.h b/src/llama-model.h index 67349e2d6..5f101bd63 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -246,6 +246,8 @@ struct llama_layer { struct ggml_tensor * wkv_b = nullptr; struct ggml_tensor * wk_b = nullptr; struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wqkv_b = nullptr; + struct ggml_tensor * wo_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; @@ -256,13 +258,6 @@ struct llama_layer { struct ggml_tensor * wo_enc = nullptr; struct ggml_tensor * wqkv_gate = nullptr; - // attention bias - struct ggml_tensor * bq = nullptr; - struct ggml_tensor * bk = nullptr; - struct ggml_tensor * bv = nullptr; - struct ggml_tensor * bo = nullptr; - struct ggml_tensor * bqkv = nullptr; - // relative position bias struct ggml_tensor * attn_rel_b = nullptr; struct ggml_tensor * attn_rel_b_enc = nullptr; diff --git a/src/models/apertus.cpp b/src/models/apertus.cpp index 80e63e3b4..af44cea60 100644 --- a/src/models/apertus.cpp +++ b/src/models/apertus.cpp @@ -50,7 +50,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/arcee.cpp b/src/models/arcee.cpp index 948df17d8..2e71f5d9e 100644 --- a/src/models/arcee.cpp +++ b/src/models/arcee.cpp @@ -55,7 +55,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/bailingmoe.cpp b/src/models/bailingmoe.cpp index 4a6969b97..67a7120d6 100644 --- a/src/models/bailingmoe.cpp +++ b/src/models/bailingmoe.cpp @@ -48,7 +48,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } diff --git a/src/models/bailingmoe2.cpp b/src/models/bailingmoe2.cpp index 016072a96..497b4babd 100644 --- a/src/models/bailingmoe2.cpp +++ b/src/models/bailingmoe2.cpp @@ -48,7 +48,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 57916c8ae..7e046cfd2 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -72,7 +72,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/src/models/bitnet.cpp b/src/models/bitnet.cpp index 257cf4ca4..71526354c 100644 --- a/src/models/bitnet.cpp +++ b/src/models/bitnet.cpp @@ -57,8 +57,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(cur, "attn_sub_norm", il); cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); + if (model.layers[il].wo_b) { + cur = ggml_add(ctx0, cur, model.layers[il].wo_b); } cb(cur, "attn_out", il); } diff --git a/src/models/bloom.cpp b/src/models/bloom.cpp index cf188211d..f3b0999bf 100644 --- a/src/models/bloom.cpp +++ b/src/models/bloom.cpp @@ -33,7 +33,7 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/codeshell.cpp b/src/models/codeshell.cpp index 5efa087e7..3ceb5835b 100644 --- a/src/models/codeshell.cpp +++ b/src/models/codeshell.cpp @@ -47,7 +47,7 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/cohere2-iswa.cpp b/src/models/cohere2-iswa.cpp index bf39edc0d..670b08e7d 100644 --- a/src/models/cohere2-iswa.cpp +++ b/src/models/cohere2-iswa.cpp @@ -58,7 +58,7 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/command-r.cpp b/src/models/command-r.cpp index fb10eac9c..067961caa 100644 --- a/src/models/command-r.cpp +++ b/src/models/command-r.cpp @@ -54,7 +54,7 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/deci.cpp b/src/models/deci.cpp index ed52d2b99..30272eabd 100644 --- a/src/models/deci.cpp +++ b/src/models/deci.cpp @@ -59,7 +59,7 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/deepseek.cpp b/src/models/deepseek.cpp index 73667cd66..671b72dfe 100644 --- a/src/models/deepseek.cpp +++ b/src/models/deepseek.cpp @@ -49,7 +49,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/dots1.cpp b/src/models/dots1.cpp index f1668fe62..5d1750fed 100644 --- a/src/models/dots1.cpp +++ b/src/models/dots1.cpp @@ -49,7 +49,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/dream.cpp b/src/models/dream.cpp index ad6608b56..8e7d9ae64 100644 --- a/src/models/dream.cpp +++ b/src/models/dream.cpp @@ -43,7 +43,7 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/exaone.cpp b/src/models/exaone.cpp index 626056e4d..4f845bf41 100644 --- a/src/models/exaone.cpp +++ b/src/models/exaone.cpp @@ -46,7 +46,7 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/gpt2.cpp b/src/models/gpt2.cpp index 22e7d7f41..f8dc53eb7 100644 --- a/src/models/gpt2.cpp +++ b/src/models/gpt2.cpp @@ -37,7 +37,7 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/gptneox.cpp b/src/models/gptneox.cpp index 87010841a..0016ddede 100644 --- a/src/models/gptneox.cpp +++ b/src/models/gptneox.cpp @@ -46,7 +46,7 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/granite-hybrid.cpp b/src/models/granite-hybrid.cpp index d6e0e8d93..e983742be 100644 --- a/src/models/granite-hybrid.cpp +++ b/src/models/granite-hybrid.cpp @@ -92,7 +92,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/src/models/granite.cpp b/src/models/granite.cpp index 7b42142c0..6ea902852 100644 --- a/src/models/granite.cpp +++ b/src/models/granite.cpp @@ -101,7 +101,7 @@ ggml_tensor * llm_build_granite::build_attention_layer( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/src/models/grok.cpp b/src/models/grok.cpp index 69eccb94b..b8f35afdc 100644 --- a/src/models/grok.cpp +++ b/src/models/grok.cpp @@ -50,7 +50,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/grovemoe.cpp b/src/models/grovemoe.cpp index 7806a02c4..151108a2a 100644 --- a/src/models/grovemoe.cpp +++ b/src/models/grovemoe.cpp @@ -50,7 +50,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/src/models/hunyuan-dense.cpp b/src/models/hunyuan-dense.cpp index 97f5da8ee..e4e837eb4 100644 --- a/src/models/hunyuan-dense.cpp +++ b/src/models/hunyuan-dense.cpp @@ -64,7 +64,7 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/hunyuan-moe.cpp b/src/models/hunyuan-moe.cpp index 0e32b7d5e..ffe1664b0 100644 --- a/src/models/hunyuan-moe.cpp +++ b/src/models/hunyuan-moe.cpp @@ -65,7 +65,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/internlm2.cpp b/src/models/internlm2.cpp index 5f688840e..83be2ca0a 100644 --- a/src/models/internlm2.cpp +++ b/src/models/internlm2.cpp @@ -50,7 +50,7 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/jais.cpp b/src/models/jais.cpp index 0f817c1d8..31101f3c1 100644 --- a/src/models/jais.cpp +++ b/src/models/jais.cpp @@ -27,7 +27,7 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/jais2.cpp b/src/models/jais2.cpp index 30abe8bc0..507e04fa4 100644 --- a/src/models/jais2.cpp +++ b/src/models/jais2.cpp @@ -51,7 +51,7 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 3f8caeef8..ddaa6c40f 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -70,7 +70,7 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); if (model.layers[il].wo_s) { cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); diff --git a/src/models/llama4.cpp b/src/models/llama4.cpp index d40d37a92..4e4bfb43f 100644 --- a/src/models/llama4.cpp +++ b/src/models/llama4.cpp @@ -84,7 +84,7 @@ llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_gr cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/maincoder.cpp b/src/models/maincoder.cpp index 1e25d50fa..8a76931c0 100644 --- a/src/models/maincoder.cpp +++ b/src/models/maincoder.cpp @@ -56,7 +56,7 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp index 8e0e13a74..b5ae72a2e 100644 --- a/src/models/mistral3.cpp +++ b/src/models/mistral3.cpp @@ -67,7 +67,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/mpt.cpp b/src/models/mpt.cpp index 7a7169a75..8596bbb20 100644 --- a/src/models/mpt.cpp +++ b/src/models/mpt.cpp @@ -56,7 +56,7 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 66eb0bdb9..dc07d43df 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -70,7 +70,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/src/models/nemotron.cpp b/src/models/nemotron.cpp index 09ec2936b..054b16fe0 100644 --- a/src/models/nemotron.cpp +++ b/src/models/nemotron.cpp @@ -51,7 +51,7 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/openai-moe-iswa.cpp b/src/models/openai-moe-iswa.cpp index e7b7a2bc8..50992b8d5 100644 --- a/src/models/openai-moe-iswa.cpp +++ b/src/models/openai-moe-iswa.cpp @@ -48,7 +48,7 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); cb(cur, "attn_out", il); diff --git a/src/models/paddleocr.cpp b/src/models/paddleocr.cpp index 4bc74c175..56cb1d94c 100644 --- a/src/models/paddleocr.cpp +++ b/src/models/paddleocr.cpp @@ -55,7 +55,7 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/src/models/pangu-embedded.cpp b/src/models/pangu-embedded.cpp index 8046750d0..53464f21d 100644 --- a/src/models/pangu-embedded.cpp +++ b/src/models/pangu-embedded.cpp @@ -49,7 +49,7 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/phi2.cpp b/src/models/phi2.cpp index 8181afd34..0fb3ffa2e 100644 --- a/src/models/phi2.cpp +++ b/src/models/phi2.cpp @@ -51,7 +51,7 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/phi3.cpp b/src/models/phi3.cpp index e00a517c7..39af285d3 100644 --- a/src/models/phi3.cpp +++ b/src/models/phi3.cpp @@ -60,7 +60,7 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ cb(Qcur, "Qcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen2.cpp b/src/models/qwen2.cpp index f0c0553d3..2892dd750 100644 --- a/src/models/qwen2.cpp +++ b/src/models/qwen2.cpp @@ -50,7 +50,7 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen2moe.cpp b/src/models/qwen2moe.cpp index 166a8fb2f..5f0a6861b 100644 --- a/src/models/qwen2moe.cpp +++ b/src/models/qwen2moe.cpp @@ -50,7 +50,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen2vl.cpp b/src/models/qwen2vl.cpp index 47dfc92a1..da7937c76 100644 --- a/src/models/qwen2vl.cpp +++ b/src/models/qwen2vl.cpp @@ -53,7 +53,7 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/qwen3.cpp b/src/models/qwen3.cpp index 68149bfca..e6f1fc81d 100644 --- a/src/models/qwen3.cpp +++ b/src/models/qwen3.cpp @@ -56,7 +56,7 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (model.layers[il].wo_s) { cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); diff --git a/src/models/qwen3moe.cpp b/src/models/qwen3moe.cpp index 533e64b43..dc554b5b3 100644 --- a/src/models/qwen3moe.cpp +++ b/src/models/qwen3moe.cpp @@ -56,7 +56,7 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (model.layers[il].wo_s) { cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); diff --git a/src/models/qwen3vl-moe.cpp b/src/models/qwen3vl-moe.cpp index fe5ef578f..29ee8278a 100644 --- a/src/models/qwen3vl-moe.cpp +++ b/src/models/qwen3vl-moe.cpp @@ -62,7 +62,7 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/qwen3vl.cpp b/src/models/qwen3vl.cpp index 333dba6ea..faa5f2ef3 100644 --- a/src/models/qwen3vl.cpp +++ b/src/models/qwen3vl.cpp @@ -62,7 +62,7 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/src/models/rnd1.cpp b/src/models/rnd1.cpp index b53c075f5..a917c19f2 100644 --- a/src/models/rnd1.cpp +++ b/src/models/rnd1.cpp @@ -58,7 +58,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/seed-oss.cpp b/src/models/seed-oss.cpp index 82c71d8df..6db8d9781 100644 --- a/src/models/seed-oss.cpp +++ b/src/models/seed-oss.cpp @@ -52,7 +52,7 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/smallthinker.cpp b/src/models/smallthinker.cpp index 5d9cc82f8..55d09ec32 100644 --- a/src/models/smallthinker.cpp +++ b/src/models/smallthinker.cpp @@ -59,7 +59,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, cb(Kcur, "Kcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/smollm3.cpp b/src/models/smollm3.cpp index 6600abcda..83636dbf5 100644 --- a/src/models/smollm3.cpp +++ b/src/models/smollm3.cpp @@ -55,7 +55,7 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/src/models/starcoder.cpp b/src/models/starcoder.cpp index be4af1f5a..cf9fe95c3 100644 --- a/src/models/starcoder.cpp +++ b/src/models/starcoder.cpp @@ -36,7 +36,7 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_gr n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/starcoder2.cpp b/src/models/starcoder2.cpp index 1fa50b985..b6d4d5aac 100644 --- a/src/models/starcoder2.cpp +++ b/src/models/starcoder2.cpp @@ -50,7 +50,7 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/src/models/t5.cpp b/src/models/t5.cpp index 7675532b2..9f9dfef40 100644 --- a/src/models/t5.cpp +++ b/src/models/t5.cpp @@ -41,7 +41,7 @@ llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_par ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); cur = build_attn(inp_attn_self, - model.layers[il].wo, model.layers[il].bo, model.layers[il].wo_s, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } From 9e5647affa54ea724196db15ec9b76c4abd16d4a Mon Sep 17 00:00:00 2001 From: Cetarthoriphros Date: Sat, 18 Apr 2026 19:27:17 -0300 Subject: [PATCH 09/24] server: Expose `media_tag` on /props endpoint. (#22028) --- tools/server/README.md | 1 + tools/server/server-context.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/tools/server/README.md b/tools/server/README.md index b30309bf3..84a29cba0 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -806,6 +806,7 @@ By default, it is read-only. To make POST request to change global properties, y "modalities": { "vision": false }, + "media_marker": "<__media_YoNhud46VdDqbuFmKYEO9PY7A4ARzRfg__>", "build_info": "b(build number)-(build commit hash)", "is_sleeping": false } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 4b899ecf0..2488f81b8 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3537,6 +3537,7 @@ void server_routes::init_routes() { {"vision", meta->has_inp_image}, {"audio", meta->has_inp_audio}, } }, + { "media_marker", get_media_marker() }, { "endpoint_slots", params.endpoint_slots }, { "endpoint_props", params.endpoint_props }, { "endpoint_metrics", params.endpoint_metrics }, From 91fef9536293dcc782594d142bee46648f86fb3c Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Sun, 19 Apr 2026 10:21:53 +0300 Subject: [PATCH 10/24] rpc : refactor the RPC transport (#21998) * rpc : refactor the RPC transport Move all transport related code into a separate file and use the socket_t interface to hide all transport implementation details. * fix win32 * better socket_t construction --- ggml/src/ggml-rpc/CMakeLists.txt | 1 + ggml/src/ggml-rpc/ggml-rpc.cpp | 806 +++---------------------------- ggml/src/ggml-rpc/transport.cpp | 683 ++++++++++++++++++++++++++ ggml/src/ggml-rpc/transport.h | 34 ++ 4 files changed, 782 insertions(+), 742 deletions(-) create mode 100644 ggml/src/ggml-rpc/transport.cpp create mode 100644 ggml/src/ggml-rpc/transport.h diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index 8671ce5ce..40e11fead 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -2,6 +2,7 @@ message(STATUS "Using RPC backend") ggml_add_backend_library(ggml-rpc ggml-rpc.cpp + transport.cpp ) if (WIN32) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 017ef0af3..2ded73978 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include "transport.h" #include #include @@ -12,35 +13,11 @@ #include #include #include -#ifdef _WIN32 -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -# include -# include -# include -# include -#endif #include #include #include #include -#ifdef GGML_RPC_RDMA -# include -# include -# ifndef _WIN32 -# include -# endif -#endif // GGML_RPC_RDMA - static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); #define LOG_DBG(...) \ @@ -49,128 +26,6 @@ static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); namespace fs = std::filesystem; -static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB - -#ifdef _WIN32 -typedef SOCKET sockfd_t; -using ssize_t = __int64; -#else -typedef int sockfd_t; -#endif - -// cross-platform socket - -#ifdef GGML_RPC_RDMA -static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) -static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB -static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes -using rdma_gid_t = std::array; - -struct rdma_conn { - struct ibv_context * ctx = nullptr; - struct ibv_pd * pd = nullptr; - struct ibv_cq * scq = nullptr; // send completions - struct ibv_cq * rcq = nullptr; // recv completions - struct ibv_qp * qp = nullptr; - - void * tx_buf = nullptr; - struct ibv_mr * tx_mr = nullptr; - - void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous - struct ibv_mr * rx_mr = nullptr; - int rx_head = 0; - - uint32_t max_inline = 0; - - uint8_t * rx_slot(int i) const { - return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; - } - - bool post_rx(int i) { - struct ibv_sge sge = {}; - sge.addr = (uintptr_t)rx_slot(i); - sge.length = RDMA_CHUNK; - sge.lkey = rx_mr->lkey; - struct ibv_recv_wr wr = {}, * bad = nullptr; - wr.wr_id = (uint64_t)i; - wr.sg_list = &sge; - wr.num_sge = 1; - return ibv_post_recv(qp, &wr, &bad) == 0; - } - - ~rdma_conn() { - if (tx_mr) ibv_dereg_mr(tx_mr); - if (rx_mr) ibv_dereg_mr(rx_mr); - free(tx_buf); - free(rx_buf); - if (qp) ibv_destroy_qp(qp); - if (scq) ibv_destroy_cq(scq); - if (rcq) ibv_destroy_cq(rcq); - if (pd) ibv_dealloc_pd(pd); - if (ctx) ibv_close_device(ctx); - } -}; - -// Local RDMA parameters captured during the probe phase and later consumed -// by rdma_activate() after the remote side's caps arrive via HELLO. -struct rdma_local_info { - uint32_t qpn = 0; - uint32_t psn = 0; - uint8_t gid[RDMA_GID_SIZE] = {}; - uint8_t ib_port = 0; - int gid_idx = 0; - enum ibv_mtu path_mtu = IBV_MTU_1024; -}; -#endif // GGML_RPC_RDMA - -// conn_caps size for transport-agnostic capability exchange -static constexpr size_t RPC_CONN_CAPS_SIZE = 24; - -// conn_caps RDMA layout helper -#ifdef GGML_RPC_RDMA -struct rdma_caps { - uint32_t qpn; - uint32_t psn; - uint8_t gid[RDMA_GID_SIZE]; -}; -static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); -#endif // GGML_RPC_RDMA - -// Forward declarations for transport function pointers -struct socket_t; -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size); -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size); - -struct socket_t { - sockfd_t fd; - bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; - bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl; -#ifdef GGML_RPC_RDMA - std::unique_ptr rdma; - rdma_local_info rdma_local = {}; -#endif // GGML_RPC_RDMA - socket_t(sockfd_t fd) : fd(fd) {} - ~socket_t() { -#ifdef GGML_RPC_RDMA - rdma.reset(); -#endif // GGML_RPC_RDMA - LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); -#ifdef _WIN32 - if (fd != INVALID_SOCKET) closesocket(this->fd); -#else - if (fd >= 0) close(this->fd); -#endif - } - - // Advertise local transport capabilities into conn_caps. - // May probe RDMA and store the probe on this socket for update_caps. - void get_caps(uint8_t * caps); - - // Activate transport upgrade based on remote conn_caps using the probe - // previously stored by get_caps. - void update_caps(const uint8_t * remote_caps); -}; - // macro for nicer error messages on server crash #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") @@ -403,540 +258,27 @@ static uint64_t fnv_hash(const uint8_t * data, size_t len) { return hash; } -static std::shared_ptr make_socket(sockfd_t fd) { -#ifdef _WIN32 - if (fd == INVALID_SOCKET) { - return nullptr; - } -#else - if (fd < 0) { - return nullptr; - } -#endif - return std::make_shared(fd); -} - -static bool set_no_delay(sockfd_t sockfd) { - int flag = 1; - // set TCP_NODELAY to disable Nagle's algorithm - int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static bool set_reuse_addr(sockfd_t sockfd) { - int flag = 1; - int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static std::shared_ptr socket_connect(const char * host, int port) { - struct sockaddr_in addr; - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock_ptr = make_socket(sockfd); - if (sock_ptr == nullptr) { - return nullptr; - } - if (!set_no_delay(sockfd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - struct hostent * server = gethostbyname(host); - if (server == NULL) { - GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); - return nullptr; - } - memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); - if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { - return nullptr; - } - return sock_ptr; -} - -static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { - auto client_socket_fd = accept(srv_sockfd, NULL, NULL); - auto client_socket = make_socket(client_socket_fd); - if (client_socket == nullptr) { - return nullptr; - } - if (!set_no_delay(client_socket_fd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - return client_socket; -} - -static std::shared_ptr create_server_socket(const char * host, int port) { - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock = make_socket(sockfd); - if (sock == nullptr) { - return nullptr; - } - if (!set_reuse_addr(sockfd)) { - GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); - return nullptr; - } - if (inet_addr(host) == INADDR_NONE) { - GGML_LOG_ERROR("Invalid host address: %s\n", host); - return nullptr; - } - struct sockaddr_in serv_addr; - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = inet_addr(host); - serv_addr.sin_port = htons(port); - - if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { - return nullptr; - } - if (listen(sockfd, 1) < 0) { - return nullptr; - } - return sock; -} - -static bool send_data(sockfd_t sockfd, const void * data, size_t size) { - size_t bytes_sent = 0; - while (bytes_sent < size) { - size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); - ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0); - if (n < 0) { - GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", - bytes_sent, size_to_send); - return false; - } - bytes_sent += (size_t)n; - } - return true; -} - -static bool recv_data(sockfd_t sockfd, void * data, size_t size) { - size_t bytes_recv = 0; - while (bytes_recv < size) { - size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); - ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0); - if (n < 0) { - GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", - bytes_recv, size_to_recv); - return false; - } - if (n == 0) { - LOG_DBG("recv returned 0 (peer closed?)\n"); - return false; - } - bytes_recv += (size_t)n; - } - return true; -} - -// TCP transport implementations (for function-pointer dispatch) - -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) { - return send_data(sock->fd, data, size); -} - -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) { - return recv_data(sock->fd, data, size); -} - -// RDMA transport (performance-optimized, auto-negotiated) - -#ifdef GGML_RPC_RDMA - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size); -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size); - -static inline bool tcp_peer_closed(int fd) { - if (fd < 0) return false; -#ifndef _WIN32 - struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; - int r = poll(&pfd, 1, 0); - return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); -#else - return false; -#endif -} - -static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) { - for (uint64_t s = 0; ; s++) { - int n = ibv_poll_cq(cq, 1, wc); - if (n > 0) { - if (wc->status != IBV_WC_SUCCESS) { - GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", - wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); - } - return wc->status == IBV_WC_SUCCESS; - } - if (n < 0) return false; - if ((s & 0xFFFFF) == 0 && s > 0) { - if (tcp_peer_closed(tcp_fd)) { - return false; - } - } - } -} - -static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) { - const uint8_t * src = (const uint8_t *)data; - size_t rem = size; - while (rem > 0) { - size_t chunk = std::min(rem, RDMA_CHUNK); - - struct ibv_sge sge = {}; - struct ibv_send_wr wr = {}, * bad = nullptr; - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - - if (chunk <= c->max_inline) { - sge.addr = (uintptr_t)src; - sge.length = chunk; - wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; - } else { - memcpy(c->tx_buf, src, chunk); - sge.addr = (uintptr_t)c->tx_buf; - sge.length = chunk; - sge.lkey = c->tx_mr->lkey; - wr.send_flags = IBV_SEND_SIGNALED; - } - - if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; - struct ibv_wc wc; - if (!rdma_poll(c->scq, &wc, tcp_fd)) return false; - - src += chunk; - rem -= chunk; - } - return true; -} - - -static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) { - uint8_t * dst = (uint8_t *)data; - size_t rem = size; - while (rem > 0) { - struct ibv_wc wc; - if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false; - - int slot = (int)wc.wr_id; - size_t got = wc.byte_len; - memcpy(dst, c->rx_slot(slot), got); - - if (!c->post_rx(slot)) return false; - - dst += got; - rem -= got; - } - return true; -} - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) { - return rdma_send(sock->rdma.get(), data, size, sock->fd); -} - -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) { - return rdma_recv(sock->rdma.get(), data, size, sock->fd); -} - -// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. -// Used to match the socket's local IP against the kernel's GID table so that -// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: -// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) -// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) -// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is -// Returns std::nullopt on unsupported family or getsockname failure. -static std::optional rdma_build_target_gid(sockfd_t tcp_fd) { - sockaddr_storage addr = {}; - socklen_t addr_len = sizeof(addr); - if (getsockname(tcp_fd, reinterpret_cast(&addr), &addr_len) != 0) { - return std::nullopt; - } - rdma_gid_t target = {}; - if (addr.ss_family == AF_INET) { - const auto * a = reinterpret_cast(&addr); - target[10] = 0xff; - target[11] = 0xff; - memcpy(&target[12], &a->sin_addr, 4); - return target; - } - if (addr.ss_family == AF_INET6) { - const auto * a = reinterpret_cast(&addr); - memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); - return target; - } - return std::nullopt; -} - -static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) { - const char * dev_env = std::getenv("GGML_RDMA_DEV"); - const char * gid_env = std::getenv("GGML_RDMA_GID"); - - auto target_gid = rdma_build_target_gid(tcp_fd); - if (!target_gid) { - return nullptr; - } - - const uint8_t ib_port = 1; - int num_devs = 0; - ibv_device ** devs = ibv_get_device_list(&num_devs); - if (!devs || num_devs == 0) return nullptr; - - ibv_context * ibctx = nullptr; - const char * matched_dev = nullptr; - int gid_idx = gid_env ? atoi(gid_env) : -1; - int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB - - for (int d = 0; d < num_devs; d++) { - const char * dn = ibv_get_device_name(devs[d]); - if (dev_env && strcmp(dev_env, dn) != 0) continue; - - ibv_context * ctx = ibv_open_device(devs[d]); - if (!ctx) continue; - - ibv_port_attr pa; - if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } - - int found_gid = gid_idx; - int found_version = IBV_GID_TYPE_IB; - if (found_gid < 0) { - // Find a GID on this port whose bytes equal the local TCP address - // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 - // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths - // are avoided. ibv_query_gid_ex returns gid+type in one call. - int v2_idx = -1; - int v1_idx = -1; - for (int i = 0; i < pa.gid_tbl_len; i++) { - ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; - if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; - if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { - v2_idx = i; - } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { - v1_idx = i; - } - } - if (v2_idx >= 0) { - found_gid = v2_idx; - found_version = IBV_GID_TYPE_ROCE_V2; - } else if (v1_idx >= 0) { - found_gid = v1_idx; - found_version = IBV_GID_TYPE_ROCE_V1; - } - } else { - // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. - ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { - found_version = entry.gid_type; - } - } - if (found_gid >= 0) { - ibctx = ctx; - gid_idx = found_gid; - gid_version = found_version; - matched_dev = dn; - out->path_mtu = pa.active_mtu; - break; - } - ibv_close_device(ctx); - } - ibv_free_device_list(devs); - if (!ibctx) return nullptr; - - out->ib_port = ib_port; - out->gid_idx = gid_idx; - - // unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(), - // so each failure path is a plain `return nullptr;`. - auto c = std::make_unique(); - c->ctx = ibctx; - - c->pd = ibv_alloc_pd(ibctx); - if (!c->pd) return nullptr; - - c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); - c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); - if (!c->scq || !c->rcq) return nullptr; - - ibv_qp_init_attr qia = {}; - qia.send_cq = c->scq; - qia.recv_cq = c->rcq; - qia.qp_type = IBV_QPT_RC; - qia.cap.max_send_wr = 4; - qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; - qia.cap.max_send_sge = 1; - qia.cap.max_recv_sge = 1; - qia.cap.max_inline_data = 256; - - c->qp = ibv_create_qp(c->pd, &qia); - if (!c->qp) return nullptr; - c->max_inline = qia.cap.max_inline_data; - - c->tx_buf = aligned_alloc(4096, RDMA_CHUNK); - c->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); - if (!c->tx_buf || !c->rx_buf) return nullptr; - - c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); - c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); - if (!c->tx_mr || !c->rx_mr) return nullptr; - - ibv_gid local_gid; - if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr; - - out->qpn = c->qp->qp_num; - out->psn = c->qp->qp_num & 0xffffff; - memcpy(out->gid, &local_gid, RDMA_GID_SIZE); - - const char * ver_str = ""; - if (gid_version == IBV_GID_TYPE_ROCE_V2) { - ver_str = " RoCEv2"; - } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { - ver_str = " RoCEv1"; - } - GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", - matched_dev, gid_idx, ver_str, out->qpn, c->max_inline); - return c.release(); -} - -// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. -// On success, the connection is live and ready for rdma_send/rdma_recv. -static bool rdma_activate(rdma_conn * c, const rdma_local_info * local, - uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { - // RESET -> INIT - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_INIT; - a.port_num = local->ib_port; - a.pkey_index = 0; - a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - return false; - } - } - - for (int i = 0; i < RDMA_RX_DEPTH; i++) { - if (!c->post_rx(i)) return false; - } - - // INIT -> RTR - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTR; - a.path_mtu = local->path_mtu; - a.dest_qp_num = remote_qpn; - a.rq_psn = remote_psn; - a.max_dest_rd_atomic = 1; - a.min_rnr_timer = 1; - a.ah_attr.is_global = 1; - memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); - a.ah_attr.grh.hop_limit = 1; - a.ah_attr.grh.sgid_index = local->gid_idx; - a.ah_attr.dlid = 0; - a.ah_attr.port_num = local->ib_port; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | - IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { - return false; - } - } - - // RTR -> RTS - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTS; - a.timeout = 14; - a.retry_cnt = 7; - a.rnr_retry = 7; - a.sq_psn = local->psn; - a.max_rd_atomic = 1; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | - IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { - return false; - } - } - - GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", - local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH); - return true; -} - -#endif // GGML_RPC_RDMA - -// --------------------------------------------------------------------------- -// socket_t transport capability methods -// --------------------------------------------------------------------------- - -void socket_t::get_caps(uint8_t * caps) { - memset(caps, 0, RPC_CONN_CAPS_SIZE); -#ifdef GGML_RPC_RDMA - rdma_local = {}; - rdma.reset(rdma_probe(fd, &rdma_local)); - if (rdma) { - rdma_caps rc = {}; - rc.qpn = rdma_local.qpn; - rc.psn = rdma_local.psn; - memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); - memcpy(caps, &rc, sizeof(rc)); - } -#endif // GGML_RPC_RDMA -} - -void socket_t::update_caps(const uint8_t * remote_caps) { -#ifdef GGML_RPC_RDMA - if (!rdma) { - return; - } - rdma_caps rc = {}; - memcpy(&rc, remote_caps, sizeof(rc)); - if (rc.qpn == 0) { - rdma.reset(); - return; - } - if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) { - fn_send = rdma_send_impl; - fn_recv = rdma_recv_impl; - } else { - GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); - rdma.reset(); - } -#else - (void)remote_caps; -#endif // GGML_RPC_RDMA -} - -// unified transport dispatch (via function pointers) - -static bool send_data(socket_t * sock, const void * data, size_t size) { - return sock->fn_send(sock, data, size); -} - -static bool recv_data(socket_t * sock, void * data, size_t size) { - return sock->fn_recv(sock, data, size); -} - -static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) { - if (!send_data(sock, &msg_size, sizeof(msg_size))) { +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { return false; } - return send_data(sock, msg, msg_size); + return sock->send_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) { +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sock, msg, msg_size); + return sock->recv_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, std::vector & input) { +static bool recv_msg(socket_ptr sock, std::vector & input) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } try { @@ -945,7 +287,7 @@ static bool recv_msg(socket_t * sock, std::vector & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sock, input.data(), size); + return sock->recv_data(input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -964,15 +306,15 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // No response -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) { + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock.get(), &input_size, sizeof(input_size))) { + if (!sock->send_data(&input_size, sizeof(input_size))) { return false; } - if (!send_data(sock.get(), input, input_size)) { + if (!sock->send_data(input, input_size)) { return false; } return true; @@ -980,18 +322,18 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } uint64_t out_size; - if (!recv_data(sock.get(), &out_size, sizeof(out_size))) { + if (!sock->recv_data(&out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock.get(), output, output_size)) { + if (!sock->recv_data(output, output_size)) { return false; } return true; @@ -1025,7 +367,6 @@ static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); static std::unordered_map> sockets; - static bool initialized = false; auto it = sockets.find(endpoint); if (it != sockets.end()) { @@ -1040,19 +381,10 @@ static std::shared_ptr get_socket(const std::string & endpoint) { return nullptr; } -#ifdef _WIN32 - if (!initialized) { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - initialized = true; + if (!rpc_transport_init()) { + return nullptr; } -#else - GGML_UNUSED(initialized); -#endif - auto sock = socket_connect(host.c_str(), port); + auto sock = socket_t::connect(host.c_str(), port); if (sock == nullptr) { return nullptr; } @@ -2110,10 +1442,10 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - socket_t * sockfd) { + socket_ptr sock) { rpc_server server(backends, cache_dir); uint8_t cmd; - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { return; } if (cmd != RPC_CMD_HELLO) { @@ -2123,7 +1455,7 @@ static void rpc_serve_client(const std::vector & backends, const // Read input_size and validate protocol version uint64_t hello_input_size; - if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) { + if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) { return; } @@ -2134,24 +1466,22 @@ static void rpc_serve_client(const std::vector & backends, const } rpc_msg_hello_req req = {}; - if (!recv_data(sockfd, &req, sizeof(req))) { + if (!sock->recv_data(&req, sizeof(req))) { return; } rpc_msg_hello_rsp rsp = {}; server.hello(rsp); - // Advertise server transport capabilities based on client's caps - sockfd->get_caps(rsp.conn_caps); - - if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + sock->get_caps(rsp.conn_caps); + if (!send_msg(sock, &rsp, sizeof(rsp))) { return; } // Activate transport upgrade using client's caps - sockfd->update_caps(req.conn_caps); + sock->update_caps(req.conn_caps); while (true) { - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { break; } if (cmd >= RPC_CMD_COUNT) { @@ -2165,115 +1495,115 @@ static void rpc_serve_client(const std::vector & backends, const return; } case RPC_CMD_DEVICE_COUNT: { - if (!recv_msg(sockfd, nullptr, 0)) { + if (!recv_msg(sock, nullptr, 0)) { return; } rpc_msg_device_count_rsp response; response.device_count = backends.size(); - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; if (!server.alloc_buffer(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALLOC_SIZE: { rpc_msg_get_alloc_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alloc_size_rsp response; if (!server.get_alloc_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALIGNMENT: { rpc_msg_get_alignment_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; if (!server.get_alignment(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { rpc_msg_get_max_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; if (!server.get_max_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_BUFFER_GET_BASE: { rpc_msg_buffer_get_base_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_buffer_get_base_rsp response; if (!server.buffer_get_base(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_FREE_BUFFER: { rpc_msg_free_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.free_buffer(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_BUFFER_CLEAR: { rpc_msg_buffer_clear_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.buffer_clear(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_SET_TENSOR: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.set_tensor(input)) { @@ -2283,62 +1613,62 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_SET_TENSOR_HASH: { rpc_msg_set_tensor_hash_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; if (!server.set_tensor_hash(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_INIT_TENSOR: { rpc_msg_init_tensor_req request; - if (!recv_msg(sockfd, &request,sizeof(request))) { + if (!recv_msg(sock, &request,sizeof(request))) { return; } if (!server.init_tensor(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } std::vector response; if (!server.get_tensor(request, response)) { return; } - if (!send_msg(sockfd, response.data(), response.size())) { + if (!send_msg(sock, response.data(), response.size())) { return; } break; } case RPC_CMD_COPY_TENSOR: { rpc_msg_copy_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_copy_tensor_rsp response; if (!server.copy_tensor(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GRAPH_COMPUTE: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.graph_compute(input)) { @@ -2348,7 +1678,7 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GRAPH_RECOMPUTE: { rpc_msg_graph_recompute_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.graph_recompute(request)) { @@ -2358,14 +1688,14 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GET_DEVICE_MEMORY: { rpc_msg_get_device_memory_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_device_memory_rsp response; if (!server.get_device_memory(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; @@ -2424,36 +1754,28 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir #else printf(" transport : TCP\n"); #endif // GGML_RPC_RDMA -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - fprintf(stderr, "WSAStartup failed: %d\n", res); - return; - } + if (!rpc_transport_init()) { + fprintf(stderr, "Failed to initialize RPC transport\n"); + return; } -#endif - auto server_socket = create_server_socket(host.c_str(), port); + auto server_socket = socket_t::create_server(host.c_str(), port); if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return; } while (true) { - auto client_socket = socket_accept(server_socket->fd); + auto client_socket = server_socket->accept(); if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return; } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket.get()); + rpc_serve_client(backends, cache_dir, client_socket); printf("Client connection closed\n"); fflush(stdout); } -#ifdef _WIN32 - WSACleanup(); -#endif + rpc_transport_shutdown(); for (auto backend : backends) { ggml_backend_free(backend); } diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp new file mode 100644 index 000000000..a72815242 --- /dev/null +++ b/ggml/src/ggml-rpc/transport.cpp @@ -0,0 +1,683 @@ +#include "transport.h" +#include "ggml-impl.h" + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include +#include +#include + +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; + +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; + +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); + +#endif // GGML_RPC_RDMA + +struct socket_t::impl { + impl(sockfd_t fd) : use_rdma(false), fd(fd) {} + ~impl(); + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + +#ifdef GGML_RPC_RDMA + bool tcp_peer_closed(); + std::optional rdma_build_target_gid(); + bool rdma_probe(); + bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid); + bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc); + bool rdma_send(const void * data, size_t size); + bool rdma_recv(void * data, size_t size); + + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA + bool use_rdma; + sockfd_t fd; +}; + +socket_t::impl::~impl() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA + LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + if (fd != INVALID_SOCKET) closesocket(this->fd); +#else + if (fd >= 0) close(this->fd); +#endif +} + +#ifdef GGML_RPC_RDMA + +bool socket_t::impl::tcp_peer_closed() { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +std::optional socket_t::impl::rdma_build_target_gid() { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +bool socket_t::impl::rdma_probe() { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(); + if (!target_gid) { + return false; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return false; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + rdma_local.path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return false; + + rdma_local.ib_port = ib_port; + rdma_local.gid_idx = gid_idx; + + rdma = std::make_unique(); + rdma->ctx = ibctx; + + rdma->pd = ibv_alloc_pd(ibctx); + if (!rdma->pd) return false; + + rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!rdma->scq || !rdma->rcq) return false; + + ibv_qp_init_attr qia = {}; + qia.send_cq = rdma->scq; + qia.recv_cq = rdma->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + rdma->qp = ibv_create_qp(rdma->pd, &qia); + if (!rdma->qp) return false; + rdma->max_inline = qia.cap.max_inline_data; + + rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + rdma->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!rdma->tx_buf || !rdma->rx_buf) return false; + + rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!rdma->tx_mr || !rdma->rx_mr) return false; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false; + + rdma_local.qpn = rdma->qp->qp_num; + rdma_local.psn = rdma->qp->qp_num & 0xffffff; + memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline); + return true; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = rdma_local.ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!rdma->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = rdma_local.path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = rdma_local.gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = rdma_local.ib_port; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = rdma_local.psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH); + return true; +} + +bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed()) { + return false; + } + } + } +} + +bool socket_t::impl::rdma_send(const void * data, size_t size) { + rdma_conn * c = rdma.get(); + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +bool socket_t::impl::rdma_recv(void * data, size_t size) { + rdma_conn * c = rdma.get(); + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +#endif // GGML_RPC_RDMA + +bool socket_t::impl::send_data(const void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_send(data, size); + } +#endif + size_t bytes_sent = 0; + while (bytes_sent < size) { + size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); + ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0); + if (n < 0) { + GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", + bytes_sent, size_to_send); + return false; + } + bytes_sent += (size_t)n; + } + return true; +} + +bool socket_t::impl::recv_data(void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_recv(data, size); + } +#endif + size_t bytes_recv = 0; + while (bytes_recv < size) { + size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); + ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0); + if (n < 0) { + GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", + bytes_recv, size_to_recv); + return false; + } + if (n == 0) { + LOG_DBG("recv returned 0 (peer closed?)\n"); + return false; + } + bytes_recv += (size_t)n; + } + return true; +} + +void socket_t::impl::get_caps(uint8_t * local_caps) { + memset(local_caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + if (rdma_probe()) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(local_caps, &rc, sizeof(rc)); + } else { + rdma.reset(); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::impl::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rc.qpn, rc.psn, rc.gid)) { + use_rdma = true; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + + +///////////////////////////////////////////////////////////////////////////// + +socket_t::socket_t(std::unique_ptr p) : pimpl(std::move(p)) {} + +socket_t::~socket_t() = default; + +bool socket_t::send_data(const void * data, size_t size) { + return pimpl->send_data(data, size); +} + +bool socket_t::recv_data(void * data, size_t size) { + return pimpl->recv_data(data, size); +} + +void socket_t::get_caps(uint8_t * local_caps) { + return pimpl->get_caps(local_caps); +} + +void socket_t::update_caps(const uint8_t * remote_caps) { + return pimpl->update_caps(remote_caps); +} + +static bool is_valid_fd(sockfd_t sockfd) { +#ifdef _WIN32 + return sockfd != INVALID_SOCKET; +#else + return sockfd >= 0; +#endif +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +socket_ptr socket_t::accept() { + auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL); + if (!is_valid_fd(client_socket_fd)) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 000000000..73b85cc53 --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct socket_t; +typedef std::shared_ptr socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr p); + std::unique_ptr pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); From 455d8e4be85aaae75195435a513a2a5dbef0136c Mon Sep 17 00:00:00 2001 From: Sascha Rogmann <59577610+srogmann@users.noreply.github.com> Date: Sun, 19 Apr 2026 09:24:06 +0200 Subject: [PATCH 11/24] server : speculative checkpointing (#19493) * server : speculative decoding using checkpoints * server : fix draft check with checkpoints * server : rename spec vars * server : log levels * server : refactored spec logic to speculative.cpp * server : renamed spec checkpoints option * server : fix spec checkpoints, logging * speculative : checkpoints with draft model, logging * server : n_tokens_cur and create_checkpoint in draft * server : fix server_speculative_callback (slot.id) * spec : fix ngram-map/begin idx_last_check * spec : init ckpt (begin() wasn't called) * chore: update webui build output * server : restore sampler in spec checkpoint and clear mem * cont : avoid --spec-use-checkpoints argument * cont : remove server_prompt_checkpoint_with_size * spec : rename (leave_draft_state) * cont : clean-up * cont : do not ignore partial drafts even if the are short * cont : spec callback owned by session * cont : simplify * cont : avoid empty speculative session * cont : simplify * cont : simplify * cont : enable mtmd speculative decoding * cont : keep the spec sampler alive * cont : simplify * cont : fix nullptr deref + draft checkpoints * cont : remove common_speculative_accept_response * cont : remove callback * cont : simplify * cont : minor * cont : simplify * cont : fix accepted number --------- Co-authored-by: Georgi Gerganov --- common/chat.cpp | 2 +- common/common.h | 4 +- common/ngram-map.cpp | 8 +- common/speculative.cpp | 162 ++++++++++++-- common/speculative.h | 14 +- tools/server/server-common.cpp | 16 +- tools/server/server-common.h | 4 +- tools/server/server-context.cpp | 376 ++++++++++++++++++++------------ tools/server/server-task.cpp | 4 +- tools/server/server-task.h | 11 + 10 files changed, 421 insertions(+), 180 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index e27b6c341..e424206af 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2334,7 +2334,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars ? input : params.generation_prompt + input; - LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str()); + //LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str()); common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT; if (params.debug) { diff --git a/common/common.h b/common/common.h index 81c269556..4c36e85e0 100644 --- a/common/common.h +++ b/common/common.h @@ -11,7 +11,6 @@ #include #include #include -#include #include #include @@ -303,7 +302,7 @@ struct common_params_speculative { // general-purpose speculative decoding parameters int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) @@ -312,6 +311,7 @@ struct common_params_speculative { uint16_t ngram_size_n = 12; // ngram size for lookup uint16_t ngram_size_m = 48; // mgram size for speculative tokens uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed + bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models std::shared_ptr ngram_mod; diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index ebf771a24..8e3978f7e 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -208,7 +208,7 @@ void common_ngram_map_begin( count_keys, count_keys_del, count_values_del, count_map_entries_upd); } - map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0; + map.idx_last_check = size_begin; map.size_last_begin = size_begin; } @@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map, GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len); } - if (map.idx_last_check > cur_len) { + if (map.idx_last_check > cur_len) { // Should not happen because of common_ngram_map_begin(). GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len); } @@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map, LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, curr_key.key_idx, key_offset, curr_key.key_num, draft.size()); - map.last_draft_created = false; + map.last_draft_created = true; map.last_draft_key_idx = key_offset; map.last_draft_value_idx = 0; // value 0 is used for simple mode return; @@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) { struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation. // update the value statistics - LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", + LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", n_accepted, curr_value.n_accepted); curr_value.n_accepted = n_accepted; } diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e68c38e4..1789560ee 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -144,10 +145,28 @@ struct common_speculative_state { virtual void accept(uint16_t n_accepted) = 0; }; +struct common_speculative_checkpoint { + llama_pos pos_min = 0; + llama_pos pos_max = 0; + + int64_t n_tokens = 0; + + std::vector data; + + size_t size() const { + return data.size(); + } + + size_t ckpt_size = 0; +}; + struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; + struct common_speculative_checkpoint ckpt; + bool use_checkpoint; + common_sampler * smpl; llama_batch batch; @@ -160,10 +179,12 @@ struct common_speculative_state_draft : public common_speculative_state { enum common_speculative_type type, llama_context * ctx_tgt, llama_context * ctx_dft, - const std::vector> & replacements) + const std::vector> & replacements, + bool use_checkpoint) : common_speculative_state(type) , ctx_tgt(ctx_tgt) , ctx_dft(ctx_dft) + , use_checkpoint(use_checkpoint) { batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); smpl = nullptr; @@ -218,7 +239,48 @@ struct common_speculative_state_draft : public common_speculative_state { } void begin(const llama_tokens & prompt) override { - GGML_UNUSED(prompt); + if (use_checkpoint && ckpt.size() > 0) { + // delete checkpoint + LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n", + __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024); + ckpt.pos_min = 0; + ckpt.pos_max = 0; + ckpt.n_tokens = 0; + ckpt.ckpt_size = 0; + ckpt.data.clear(); + } + } + + size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) { + int slot_id = 0; + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); + ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); + ckpt.n_tokens = n_tokens_prompt - n_tokens_batch; + ckpt.data.resize(checkpoint_size); + + const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } + + LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__, + ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024); + return n; + } + + size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) { + int slot_id = 0; + LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); + const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != ckpt_size_part_expected) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); + } + llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1); + + return n; } void draft( @@ -236,8 +298,8 @@ struct common_speculative_state_draft : public common_speculative_state { auto * mem_dft = llama_get_memory(ctx_dft); - int reuse_i = 0; - int reuse_n = 0; + int reuse_i = 0; // index of part to be reused in prompt_dft + int reuse_n = 0; // length of part to be reused in prompt_dft const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; @@ -287,18 +349,26 @@ struct common_speculative_state_draft : public common_speculative_state { } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", + __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); + if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) { + LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", + __func__, reuse_i, reuse_n); + reuse_i = 0; + reuse_n = 0; + } result.clear(); result.reserve(params.n_max); - if (reuse_n == 0) { + bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0; + if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) { llama_memory_clear(mem_dft, false); prompt_dft.clear(); } else { // this happens when a previous draft has been discarded (for example, due to being too small), but the // target model agreed with it. in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { result.push_back(prompt_dft[i]); @@ -310,19 +380,50 @@ struct common_speculative_state_draft : public common_speculative_state { return; } + bool do_restore = false; + if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) { + // This can happen after a partial acceptance (speculative decoding with checkpoints) + LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n", + __func__, prompt_dft.size(), prompt_cur.size()); + prompt_dft.resize(prompt_cur.size()); + do_restore = true; + } + if (reuse_i > 0) { - llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); + } llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); } - if (reuse_n < (int) prompt_dft.size()) { - llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + if (reuse_n < (int) prompt_dft.size() || do_restore) { + if (use_checkpoint) { + if (ckpt.n_tokens > (int64_t) prompt_dft.size()) { + LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n", + __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size()); + } + draft_restore_checkpoint(ckpt.ckpt_size); + reuse_n = ckpt.n_tokens; + prompt_dft.resize(reuse_n); + needs_ckpt = false; + } else { + bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + if (!is_removed) { + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", + __func__, reuse_n, prompt_dft.size()); + } + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } } } + if (needs_ckpt) { + ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens); + } + // prepare a batch to evaluate any new tokens in the prompt common_batch_clear(batch); @@ -337,7 +438,11 @@ struct common_speculative_state_draft : public common_speculative_state { if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", + __func__, ret, prompt_cur.size()); + } } const llama_pos n_past = prompt_dft.size(); @@ -351,7 +456,11 @@ struct common_speculative_state_draft : public common_speculative_state { LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - llama_decode(ctx_dft, batch); + int ret = llama_decode(ctx_dft, batch); + if (ret != 0 && ret != 1) { + LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, ret, prompt_cur.size(), prompt_dft.size()); + } common_sampler_reset(smpl); @@ -387,7 +496,11 @@ struct common_speculative_state_draft : public common_speculative_state { common_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch); + ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n", + __func__, i, ret, prompt_cur.size(), prompt_dft.size()); + } prompt_dft.push_back(id); } @@ -739,6 +852,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { struct common_speculative { std::vector> impls; // list of implementations to use and their states + common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) }; @@ -798,13 +912,13 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } -bool common_speculative_is_compat(llama_context * ctx_tgt) { +common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) { auto * mem = llama_get_memory(ctx_tgt); if (mem == nullptr) { - return false; + return COMMON_SPECULATIVE_COMPAT_TYPE_NO; } - bool res = true; + common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL; llama_memory_clear(mem, true); @@ -816,14 +930,14 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) { int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size())); if (ret != 0) { LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); - res = false; + res = COMMON_SPECULATIVE_COMPAT_TYPE_NO; goto done; } // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = false; + res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT; goto done; } @@ -909,9 +1023,10 @@ common_speculative * common_speculative_init( break; case COMMON_SPECULATIVE_TYPE_DRAFT: { impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.replacements + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ params.replacements, + /* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model! )); break; } @@ -966,7 +1081,8 @@ common_speculative * common_speculative_init( } auto * result = new common_speculative { - /* .impls = */ std::move(impls) + /* .impls = */ std::move(impls), + /* .curr_impl = */ nullptr, }; return result; diff --git a/common/speculative.h b/common/speculative.h index 876cde3d1..cbe6e5bdb 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -14,9 +14,15 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); +enum common_speculative_compat_type { + COMMON_SPECULATIVE_COMPAT_TYPE_NO = 0, + COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1, + COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2, +}; + // check if the llama_context is compatible for speculative decoding // note: clears the memory of the context -bool common_speculative_is_compat(llama_context * ctx_tgt); +common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt); common_speculative * common_speculative_init( common_params_speculative & params, @@ -39,3 +45,9 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); + +struct common_speculative_deleter { + void operator()(common_speculative * s) { common_speculative_free(s); } +}; + +typedef std::unique_ptr common_speculative_ptr; diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index f66b1f255..cae64884b 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -391,15 +391,25 @@ void server_tokens::push_back(server_tokens & tokens) { } void server_tokens::insert(const llama_tokens & inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); } -const llama_tokens & server_tokens::get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled +const llama_tokens & server_tokens::get_tokens() const { + GGML_ASSERT(!has_mtmd); return tokens; } +llama_tokens server_tokens::get_text_tokens() const { + llama_tokens res; + res.reserve(tokens.size()); + for (llama_token t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + res.push_back(t); + } + } + return res; +} + void server_tokens::set_token(llama_pos pos, llama_token id) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled tokens[pos] = id; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 57545aa53..093a43453 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -190,7 +190,9 @@ public: void insert(const llama_tokens & inp_tokens); // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens & get_text_tokens() const; + const llama_tokens & get_tokens() const; + + llama_tokens get_text_tokens() const; // for compatibility with speculative decoding void set_token(llama_pos pos, llama_token id); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 2488f81b8..70ebcc225 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,3 +1,4 @@ + #include "server-context.h" #include "server-common.h" #include "server-http.h" @@ -19,6 +20,7 @@ #include #include #include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -33,6 +35,31 @@ using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; +static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, int id, int64_t n_tokens, llama_pos pos_min = -1, llama_pos pos_max = -1) { + if (pos_min == -1) { + pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), id); + } + if (pos_max == -1) { + pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), id); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto cur = server_prompt_checkpoint { + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.n_tokens = */ n_tokens, + /*.data = */ std::vector(checkpoint_size), + }; + + const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != checkpoint_size) { + GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n); + } + + return cur; +} + // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, @@ -57,7 +84,12 @@ struct server_slot { // multimodal mtmd_context * mctx = nullptr; - common_speculative * spec = nullptr; + // speculative decoding + llama_tokens spec_draft; + std::vector spec_i_batch; + server_prompt_checkpoint spec_ckpt; + common_speculative_ptr spec; + // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 @@ -83,11 +115,6 @@ struct server_slot { std::string debug_generated_text; llama_tokens generated_tokens; - // idx of draft tokens in the main batch - // non-empty if we went to evaluate draft tokens - // ref: https://github.com/ggml-org/llama.cpp/pull/17808 - std::vector i_batch_dft; - std::vector generated_token_probs; bool has_next_token = true; @@ -147,8 +174,7 @@ struct server_slot { common_sampler_ptr smpl; - llama_token sampled; // in speculative mode, this is the last accepted token - llama_tokens drafted; + llama_token sampled; // in speculative mode, this is the last accepted token // stats size_t n_sent_text = 0; // number of sent text character @@ -178,8 +204,11 @@ struct server_slot { stopping_word = ""; n_sent_text = 0; - drafted.clear(); - i_batch_dft.clear(); + if (can_speculate()) { + spec_draft.clear(); + spec_i_batch.clear(); + spec_ckpt.clear(); + } generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -300,6 +329,85 @@ struct server_slot { return n_draft_max; } + void update_batch(llama_batch & batch) { + const int n_draft_max = get_n_draft_max(); + if (n_draft_max > 0) { + GGML_ASSERT(can_speculate()); + + // generate draft tokens in speculative decoding mode + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch + const llama_tokens & tokens = prompt.tokens.get_text_tokens(); + + const auto & params_spec = task->params.speculative; + + if (!spec_draft.empty()) { + // we have a previous (partial) draft to reuse + if (task->params.speculative.use_checkpoints) { + GGML_ASSERT(!spec_ckpt.empty()); + } + } else { + GGML_ASSERT(spec_i_batch.empty()); + + // generate a new draft + spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled); + + if (spec_draft.size() > (size_t) n_draft_max) { + SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max); + spec_draft.resize(n_draft_max); + } + + if (spec_draft.size() < (size_t) params_spec.n_min) { + SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min); + spec_draft.clear(); + } + + if (!spec_draft.empty() && params_spec.use_checkpoints) { + const auto n_tokens = prompt.tokens.size(); + + auto & ckpt = spec_ckpt; + + ckpt = server_get_checkpoint(ctx, this->id, n_tokens); + + SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", + ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024); + } + } + + GGML_ASSERT(spec_draft.size() <= (size_t) n_draft_max); + } + + if (spec_draft.empty()) { + // no speculative decoding + i_batch = batch.n_tokens; + + common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true); + + SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", + sampled, n_ctx, prompt.n_tokens(), truncated); + } else { + SLT_DBG(*this, "generate_draft: id=%d, #tokens=%zu, #draft=%zu, pos_next=%d\n", + sampled, prompt.tokens.size(), spec_draft.size(), prompt.tokens.pos_next()); + + GGML_ASSERT(spec_i_batch.empty()); + + spec_i_batch.push_back(batch.n_tokens); + for (size_t i = 0; i < spec_draft.size(); i++) { + spec_i_batch.push_back(batch.n_tokens + i + 1); + } + + auto pos0 = prompt.tokens.pos_next(); + + common_batch_add(batch, sampled, pos0++, { this->id }, true); + for (auto token : spec_draft) { + common_batch_add(batch, token, pos0++, { this->id }, true); + } + } + + prompt.tokens.push_back(sampled); + prompt.tokens.insert(spec_draft); + } + void release() { if (is_processing()) { GGML_ASSERT(task); @@ -400,7 +508,7 @@ struct server_slot { ); } - common_speculative_print_stats(spec); + common_speculative_print_stats(spec.get()); } json to_json(bool only_metrics = false) const { @@ -591,16 +699,17 @@ private: void destroy() { llama_init.reset(); + ctx = nullptr; model = nullptr; mtmd_free(mctx); mctx = nullptr; - // Clear any sampling context for (server_slot & slot : slots) { - common_speculative_free(slot.spec); - slot.spec = nullptr; + if (slot.can_speculate()) { + slot.spec.reset(); + } } llama_batch_free(batch); @@ -642,9 +751,6 @@ private: llama_init = common_init_from_params(params_base); - // propagate model-metadata sampling defaults back to caller - params.sampling = params_base.sampling; - model = llama_init->model(); ctx = llama_init->context(); @@ -660,6 +766,7 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); if (params_base.speculative.has_dft()) { + // TODO speculative: move to common/speculative.cpp? SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str()); const auto & params_spec = params_base.speculative; @@ -727,11 +834,6 @@ private: params_base.n_cache_reuse = 0; SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); } - - if (params_base.speculative.type != COMMON_SPECULATIVE_TYPE_NONE) { - params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; - SRV_WRN("%s\n", "speculative decoding is not supported by multimodal, it will be disabled"); - } } if (!llama_memory_can_shift(llama_get_memory(ctx))) { @@ -769,14 +871,23 @@ private: slots.clear(); - const bool can_spec = common_speculative_is_compat(ctx); - if (!can_spec) { + const auto spec_type = common_speculative_is_compat(ctx); + if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } + if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) { + SRV_WRN("%s", "speculative decoding will use checkpoints\n"); + params_base.speculative.use_checkpoints = true; + } + // initialize slots for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; + slots.emplace_back(); + } + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot & slot = slots[i]; slot.id = i; slot.ctx = ctx; @@ -786,16 +897,11 @@ private: slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - if (can_spec) { - slot.spec = common_speculative_init(params_base.speculative, slot.ctx); + if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) { + slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); + if (slot.spec) { - if (mctx) { - SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); - return false; - } SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } else { - SLT_INF(slot, "%s", "speculative decoding context not initialized\n"); } } @@ -806,8 +912,6 @@ private: }; slot.reset(); - - slots.push_back(std::move(slot)); } { @@ -854,6 +958,9 @@ private: model_aliases = params_base.model_alias; model_tags = params_base.model_tags; + // propagate new defaults back to caller + params = params_base; + if (!is_resume) { return init(); } @@ -1197,7 +1304,7 @@ private: backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0); + backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0); // TODO: getting post/pre sampling logits is not yet supported with backend sampling backend_sampling &= !need_logits; @@ -1703,6 +1810,26 @@ private: return true; } + // n_tokens_cur: the number of tokens added to the batch for the current slot + void create_checkpoint(server_slot & slot, const int64_t n_tokens_cur, llama_pos pos_min, llama_pos pos_max) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const auto & cur = slot.prompt.checkpoints.emplace_back(server_get_checkpoint(ctx, slot.id, slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max)); + + SLT_WRN(slot, + "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, + cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + } + void process_single_task(server_task && task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: @@ -1854,7 +1981,7 @@ private: std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); + const llama_tokens & tokens = slot->prompt.tokens.get_tokens(); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); @@ -2061,7 +2188,7 @@ private: { GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy + llama_tokens new_tokens = slot.prompt.tokens.get_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } @@ -2100,61 +2227,7 @@ private: continue; } - // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - const int n_draft_max = slot.get_n_draft_max(); - if (n_draft_max > 0) { - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - - const auto & params_spec = slot.task->params.speculative; - - llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); - - if (draft.size() > (size_t) n_draft_max) { - SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max); - draft.resize(n_draft_max); - } - - // add the sampled token to the batch - slot.i_batch_dft.push_back(batch.n_tokens); - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(slot.sampled); - - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - // fallback to normal decoding - slot.i_batch = slot.i_batch_dft[0]; - slot.drafted.clear(); - slot.i_batch_dft.clear(); - } else { - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // add all drafted tokens to the batch - for (size_t i = 0; i < draft.size(); i++) { - slot.i_batch_dft.push_back(batch.n_tokens); - common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(draft[i]); - } - slot.drafted = std::move(draft); - } - } else { - // no speculative decoding - slot.i_batch = batch.n_tokens; - - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - - slot.prompt.tokens.push_back(slot.sampled); - - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); - } + slot.update_batch(batch); } // process in chunks of params.n_batch @@ -2651,40 +2724,12 @@ private: // no need to create checkpoints that are too close together do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64); + SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max); // note: we create the checkpoint before calling llama_decode(), so the current batch is not // yet processed and therefore it is not part of the checkpoint. if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, - "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 - ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = - llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, - LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, - "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 - ", size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, - cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + create_checkpoint(slot, n_tokens_cur, pos_min, pos_max); } } @@ -2856,19 +2901,19 @@ private: slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); + common_speculative_begin(slot.spec.get(), slot.prompt.tokens.get_text_tokens()); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } - if (slot.i_batch_dft.size() > 0) { + if (slot.can_speculate() && !slot.spec_draft.empty()) { continue; // sample using speculative decoding } const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); slot.i_batch = -1; @@ -2889,7 +2934,7 @@ private: completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.task->params.sampling.n_probs > 0) { @@ -2909,43 +2954,86 @@ private: // speculative decoding - main model sample and accept for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) { + if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { continue; } - const size_t n_draft = slot.drafted.size(); + // save the original draft size + const size_t n_draft = slot.spec_draft.size(); - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted); - slot.i_batch_dft.clear(); - slot.drafted.clear(); + GGML_ASSERT(n_draft > 0); + + // verify and try to accept the draft + { + const auto & params_spec = slot.task->params.speculative; + + common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); + + GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + slot.spec_i_batch.clear(); + + SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size()); + + GGML_ASSERT(accepted.size() >= 1); + + // check for partial draft acceptance + if (accepted.size() < slot.spec_draft.size() + 1) { + if (params_spec.use_checkpoints) { + // partial acceptance is not supported by the context -> truncate the draft and restore the state + slot.spec_draft = std::move(accepted); + + auto & ckpt = slot.spec_ckpt; + + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); + + const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + if (n != ckpt.size()) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n); + } + + llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, ckpt.pos_max + 1, -1); + + slot.prompt.tokens.keep_first(ckpt.n_tokens); + slot.smpl = std::move(smpl_save); + + continue; + } + + LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size()); + } + + common_speculative_accept(slot.spec.get(), accepted.size() - 1); + + slot.spec_draft = std::move(accepted); + } const int64_t t_current = ggml_time_us(); - slot.n_decoded += ids.size(); + const auto ids = std::move(slot.spec_draft); + slot.n_decoded += ids.size(); slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - - // inform the speculative decoding about the number of accepted tokens - common_speculative_accept(slot.spec, ids.size() - 1); - - // rollback to the state before sampling the draft tokens - slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + slot.n_draft_total += n_draft; // add accepted tokens to the prompt + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - slot.sampled = ids.back(); // last accepted token - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); + + llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.text_to_send = common_token_to_piece(slot.ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // set later // TODO: set result.probs @@ -3665,7 +3753,7 @@ void server_routes::init_routes() { params.n_predict, meta->slot_n_ctx, params.spm_infill, - tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. + tokenized_prompts[0].get_tokens() // TODO: this could maybe be multimodal. ); std::vector files; // dummy diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 4fb953b49..2187b8d21 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -162,7 +162,7 @@ common_chat_msg task_result_state::update_chat_msg( bool filter_tool_calls) { generated_text += text_added; auto msg_prv_copy = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + //SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); auto new_msg = common_chat_parse( generated_text, is_partial, @@ -304,6 +304,8 @@ task_params server_task::params_from_json_cmpl( params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + params.speculative = defaults.speculative; + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 95f39207b..289e1fb8d 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -576,6 +576,17 @@ struct server_prompt_checkpoint { size_t size() const { return data.size(); } + + bool empty() const { + return data.empty(); + } + + void clear() { + pos_min = 0; + pos_max = 0; + n_tokens = 0; + data.clear(); + } }; struct server_prompt { From 09b4efa95f8223f6d168f046c80fe4529a68117b Mon Sep 17 00:00:00 2001 From: texasich <101962694+texasich@users.noreply.github.com> Date: Sun, 19 Apr 2026 02:25:05 -0500 Subject: [PATCH 12/24] cmake: remove CMP0194 policy to restore MSVC builds (#21934) #21630 added the CMP0194 NEW policy to silence a CMake warning, but on Windows runners it caused CMake to prefer the MinGW toolchain for ASM and broke MSVC builds. Reverting only that policy block restores the previous working behavior. The CMake 4.1+ warning comes back, but that is cosmetic and does not break any platform. Reported-by: oobabooga Refs: #21630 Co-authored-by: texasich --- ggml/CMakeLists.txt | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6b65ecd6e..a0eb9204e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,11 +1,5 @@ cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. -# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html -# MSVC is not a valid assembler for the ASM language. -# Set to NEW to avoid a warning on CMake 4.1+ with MSVC. -if (POLICY CMP0194) - cmake_policy(SET CMP0194 NEW) -endif() project("ggml" C CXX ASM) ### GGML Version From 8685e7b07582957924393d912fffe6f4588e235e Mon Sep 17 00:00:00 2001 From: Dowon Date: Sun, 19 Apr 2026 16:25:39 +0900 Subject: [PATCH 13/24] convert : support sentence-transformer 5.4 config files (#22087) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * convert : support sentence-transformer 5.4 config files * fix: embeddinggemma * fix: mapping Co-authored-by: Sigbjørn Skjæret * fix: pooling_mode Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2df5e94fe..5b4fb79fc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1850,20 +1850,28 @@ class TextModel(ModelBase): with open(module_path, encoding="utf-8") as f: modules = json.load(f) for mod in modules: - if mod["type"] == "sentence_transformers.models.Pooling": + if mod["type"].endswith("Pooling"): pooling_path = mod["path"] break + mode_mapping = { + "mean": gguf.PoolingType.MEAN, + "cls": gguf.PoolingType.CLS, + "lasttoken": gguf.PoolingType.LAST, + } + # get pooling type if pooling_path is not None: with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: pooling = json.load(f) - if pooling["pooling_mode_mean_tokens"]: + if pooling.get("pooling_mode_mean_tokens"): pooling_type = gguf.PoolingType.MEAN - elif pooling["pooling_mode_cls_token"]: + elif pooling.get("pooling_mode_cls_token"): pooling_type = gguf.PoolingType.CLS - elif pooling["pooling_mode_lasttoken"]: + elif pooling.get("pooling_mode_lasttoken"): pooling_type = gguf.PoolingType.LAST + elif (pooling_mode := pooling.get("pooling_mode")) in mode_mapping: + pooling_type = mode_mapping[pooling_mode] else: raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported") self.gguf_writer.add_pooling_type(pooling_type) @@ -7180,7 +7188,7 @@ class EmbeddingGemma(Gemma3Model): with open(modules_file, encoding="utf-8") as modules_json_file: mods = json.load(modules_json_file) for mod in mods: - if mod["type"] == "sentence_transformers.models.Dense": + if mod["type"].endswith("Dense"): mod_path = mod["path"] # check if model.safetensors file for Dense layer exists model_tensors_file = self.dir_model / mod_path / "model.safetensors" From 037bfe38d0297001869df87150286952ae94cb1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 19 Apr 2026 09:32:08 +0200 Subject: [PATCH 14/24] ci : install spirv-headers for vulkan-cross (#22109) --- .github/workflows/build-cross.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build-cross.yml b/.github/workflows/build-cross.yml index 74508129a..aef45afde 100644 --- a/.github/workflows/build-cross.yml +++ b/.github/workflows/build-cross.yml @@ -246,6 +246,7 @@ jobs: apt-get install -y --no-install-recommends \ build-essential \ glslc \ + spirv-headers \ gcc-14-loongarch64-linux-gnu \ g++-14-loongarch64-linux-gnu \ libvulkan-dev:loong64 From bcdcc1044ffea27ac531da14d947713ac4dad9fd Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 19 Apr 2026 15:18:35 +0530 Subject: [PATCH 15/24] ggml : reduce CPU overhead in meta backend (#22041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cache subgraph splits when cgraph is unchanged Skip per-call subgraph construction in ggml_backend_meta_graph_compute when the same ggml_cgraph is used consecutively. Assign uid to every sub-graph so that CUDA's fast uid check path hits too. * Address review comments * Keep the scope as is * Rename last_uid and last_n_subgraphs field. Remove last_max_tmp_size field. Refactor code. * Address review comments * Update ggml/src/ggml-backend-meta.cpp Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-backend-meta.cpp Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-backend-meta.cpp | 307 +++++++++++++++++---------------- 1 file changed, 160 insertions(+), 147 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 24f6bc063..39651adc1 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1456,6 +1456,8 @@ struct ggml_backend_meta_context { int max_nnodes = 0; size_t max_tmp_size = 0; size_t max_subgraphs = 0; + size_t n_subgraphs = 0; + uint64_t uid = 0; void * comm_ctx = nullptr; ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; @@ -1616,6 +1618,9 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, const size_t n_backends = ggml_backend_meta_n_backends(backend); ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend. + const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid); + bool max_nnodes_raised = false; if (cgraph->n_nodes > backend_ctx->max_nnodes) { for (size_t j = 0; j < n_backends; j++) { @@ -1625,173 +1630,181 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } backend_ctx->max_nnodes = cgraph->n_nodes; max_nnodes_raised = true; + assert(needs_rebuild); } - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. - // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. - bcj.nodes[i] = node; - continue; + if (needs_rebuild) { + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); } - bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); - GGML_ASSERT(bcj.nodes[i]); } - } - size_t n_subgraphs = 0; - size_t max_tmp_size = 0; - { - // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: - auto get_i_delayed = [&](const int i) -> int { - int id = i; // i_delayed - int idr = i; // i_delayed return, last safe return value + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value - ggml_tensor * node = cgraph->nodes[id]; - int32_t n_used = ggml_node_get_use_count(cgraph, id); - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_ADD_ID && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && - ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_MUL && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - - if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; return idr; - } - for (int32_t k = 0; k < n_used; k++) { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || - next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || - ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; - } - id++; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; - } - id++; - } - for (int32_t k = 0; k < n_used - 2; k++) { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; - } - id++; - } - idr = id; - return idr; - }; + }; - int i_start = 0; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - continue; - } - const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); - if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { - max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); - } - const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; - if (!new_subgraph) { - continue; - } + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); + } + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; + } - i = get_i_delayed(i); + i = get_i_delayed(i); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; + } + GGML_ASSERT(i_start == cgraph->n_nodes); + } + + backend_ctx->uid = cgraph->uid; + backend_ctx->n_subgraphs = n_subgraphs; + + if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.cgraphs[n_subgraphs].offset = i_start; + bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); } - n_subgraphs++; - i_start = i + 1; + backend_ctx->max_tmp_size = max_tmp_size; } - GGML_ASSERT(i_start == cgraph->n_nodes); - } - if (max_tmp_size > backend_ctx->max_tmp_size) { - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); - } - backend_ctx->max_tmp_size = max_tmp_size; - } - - - if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { - backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); - const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); - const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step - const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step - const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); - const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); - const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); - ggml_init_params params = { - /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - backend_ctx->ctx.reset(ggml_init(params)); - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i = 0; i < n_subgraphs; i++) { - bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); + const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step + const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); } } - backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { - backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); - } - backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { - backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); - } - } - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { - ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; - const size_t i_node_start = bcj.cgraphs[i_graph].offset; - const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; - cgraph_ij->n_nodes = i_node_stop - i_node_start; - ggml_hash_set_reset(&cgraph_ij->visited_hash_set); - for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { - ggml_tensor * node_ij = bcj.nodes[i_node]; - cgraph_ij->nodes[i_node - i_node_start] = node_ij; - const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); - const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); - cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + cgraph_ij->uid = ggml_graph_next_uid(); } } } @@ -1898,7 +1911,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, }; - for (size_t i = 0; i < n_subgraphs; i++) { + for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); @@ -1907,7 +1920,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } } - if (n_backends > 1 && i < n_subgraphs - 1) { + if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) { bool backend_allreduce_success = false; if (backend_ctx->comm_ctx) { std::vector nodes; From 19124078be13c2ba81f871c96a7ab587bb00d06a Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 19 Apr 2026 11:57:21 +0200 Subject: [PATCH 16/24] mtmd: add pos_0 to mtmd_image_tokens_get_decoder_pos (breaking change) (#22082) * mtmd: add pos_0 to mtmd_image_tokens_get_decoder_pos * fix build --- tests/test-mtmd-c-api.c | 2 +- tools/mtmd/mtmd-helper.cpp | 20 ++++++++++---------- tools/mtmd/mtmd-helper.h | 2 +- tools/mtmd/mtmd.cpp | 11 +++++++---- tools/mtmd/mtmd.h | 4 +++- 5 files changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c index 7a0ce593c..b49498c87 100644 --- a/tests/test-mtmd-c-api.c +++ b/tests/test-mtmd-c-api.c @@ -42,7 +42,7 @@ int main(void) { const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk); size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); // get position of the last token, which should be (nx - 1, ny - 1) - struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, n_tokens - 1); + struct mtmd_decoder_pos pos = mtmd_image_tokens_get_decoder_pos(image_tokens, 0, n_tokens - 1); size_t nx = pos.x + 1; size_t ny = pos.y + 1; const char * id = mtmd_image_tokens_get_id(image_tokens); diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 145b88cea..409407416 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -114,10 +114,10 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) { return n_pos; } -void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, mtmd_decoder_pos * out_pos) { +void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * chunks, llama_pos pos_0, mtmd_decoder_pos * out_pos) { size_t n_tokens = mtmd_image_tokens_get_n_tokens(chunks); for (size_t i = 0; i < n_tokens; i++) { - out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, i); + out_pos[i] = mtmd_image_tokens_get_decoder_pos(chunks, pos_0, i); } } @@ -163,15 +163,15 @@ struct decode_embd_batch { } // M-RoPE for image - void set_position_mrope_2d(llama_pos pos_0, const std::vector & rel_pos, llama_seq_id seq_id) { + void set_position_mrope_2d(const std::vector & rel_pos, llama_seq_id seq_id) { GGML_ASSERT(n_pos_per_embd == 4); GGML_ASSERT(!rel_pos.empty() && (int32_t)rel_pos.size() == batch.n_tokens); seq_id_0[0] = seq_id; for (int32_t i = 0; i < batch.n_tokens; i++) { - pos[i ] = pos_0 + rel_pos[i].t; - pos[i + batch.n_tokens ] = pos_0 + rel_pos[i].y; - pos[i + batch.n_tokens * 2] = pos_0 + rel_pos[i].x; - pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused + pos[i ] = rel_pos[i].t; + pos[i + batch.n_tokens ] = rel_pos[i].y; + pos[i + batch.n_tokens * 2] = rel_pos[i].x; + pos[i + batch.n_tokens * 3] = rel_pos[i].z; } for (int i = 0; i < batch.n_tokens; i++) { batch.n_seq_id[i] = 1; @@ -188,7 +188,7 @@ struct decode_embd_batch { pos[i ] = pos_0 + i; pos[i + batch.n_tokens ] = pos_0 + i; pos[i + batch.n_tokens * 2] = pos_0 + i; - pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused + pos[i + batch.n_tokens * 3] = pos_0 + i; } for (int i = 0; i < batch.n_tokens; i++) { batch.n_seq_id[i] = 1; @@ -268,8 +268,8 @@ int32_t mtmd_helper_decode_image_chunk( } const auto n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); std::vector rel_pos(n_tokens); - mtmd_helper_image_get_decoder_pos(image_tokens, rel_pos.data()); - batch_embd.set_position_mrope_2d(n_past, rel_pos, seq_id); + mtmd_helper_image_get_decoder_pos(image_tokens, n_past, rel_pos.data()); + batch_embd.set_position_mrope_2d(rel_pos, seq_id); } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { batch_embd.set_position_mrope_1d(n_past, seq_id); } else { diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h index ff34a4121..57da78a75 100644 --- a/tools/mtmd/mtmd-helper.h +++ b/tools/mtmd/mtmd-helper.h @@ -49,7 +49,7 @@ MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks); // helper to get the list of relative positions corresponding to the embedding tokens, to be used by M-RoPE // out_pos must have length == mtmd_helper_get_n_tokens(image) -MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, struct mtmd_decoder_pos * out_pos); +MTMD_API void mtmd_helper_image_get_decoder_pos(const mtmd_image_tokens * image, llama_pos pos_0, struct mtmd_decoder_pos * out_pos); // helper function that automatically: // 1. run llama_decode() on text chunks diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d0a0a4865..52fca4e81 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -1246,11 +1246,14 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { return image_tokens->ny; } -mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i) { +mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i) { mtmd_decoder_pos pos; - pos.t = 0; - pos.x = i % image_tokens->nx; - pos.y = i / image_tokens->nx; + // M-RoPE logic + // TODO: support other types of position encoding if needed + pos.t = pos_0; + pos.x = pos_0 + (i % image_tokens->nx); + pos.y = pos_0 + (i / image_tokens->nx); + pos.z = 0; // unused for now return pos; } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index a6fd8efa5..6e36cb8ec 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -196,11 +196,13 @@ struct mtmd_decoder_pos { uint32_t t; uint32_t x; uint32_t y; + uint32_t z; // unused for now, reserved for future use }; // get position for decoder attention, to be used by M-RoPE models // i is the index of the embedding token, ranging from 0 to mtmd_image_tokens_get_n_tokens() - 1 +// pos_0 is the absolute position of the first token // return relative position (for example, embedding 0 will have position (0, 0, 0); remember to adjust it to the current absolute position) -MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, size_t i); +MTMD_API struct mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i); // tokenize an input text prompt and a list of bitmaps (images/audio) // the prompt must have the input image marker (default: "<__media__>") in it From 471540ae8a760c9859bed07dd9c187d9bd8ab957 Mon Sep 17 00:00:00 2001 From: uvos Date: Sun, 19 Apr 2026 12:59:44 +0200 Subject: [PATCH 17/24] HIP: Remove unesscary NCCL_CHECK (#21914) --- ggml/src/ggml-cuda/vendors/hip.h | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 898fec31e..52c38908e 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -33,7 +33,6 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} -#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) From d5b780a676b755ec8f550e9b633256d2dea8b8c5 Mon Sep 17 00:00:00 2001 From: Aldehir Rojas Date: Sun, 19 Apr 2026 06:28:35 -0500 Subject: [PATCH 18/24] common/autoparser : allow space after tool call (#22073) --- common/chat-auto-parser-generator.cpp | 6 +++--- tests/test-chat.cpp | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index c6431b898..453559a4b 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -443,14 +443,14 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte if (!format.per_call_start.empty()) { auto wrapped_call = format.per_call_start + p.space() + tool_choice + p.space() + format.per_call_end; if (inputs.parallel_tool_calls) { - tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call)); + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call) + p.space()); } else { - tool_calls = p.trigger_rule("tool-call", wrapped_call); + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.space()); } if (!format.section_start.empty()) { tool_calls = p.trigger_rule("tool-calls", p.literal(format.section_start) + p.space() + tool_calls + p.space() + - (format.section_end.empty() ? p.end() : p.literal(format.section_end))); + (format.section_end.empty() ? p.end() : p.literal(format.section_end) + p.space())); } } else { std::string separator = ", "; // Default diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 3b8de5ce0..eee28271b 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1796,7 +1796,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "\n" "\n1\n\n" "\n" - "") + "\n") .enable_thinking(false) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ special_function_tool }) @@ -1809,7 +1809,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "\n" "\n1\n\n" "\n" - "") + "\n") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ special_function_tool }) .expect(message_assist_call_thoughts) @@ -1826,7 +1826,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "\n1\n\n" "\n2\n\n" "\n" - "") + "\n") .enable_thinking(false) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .parallel_tool_calls(true) @@ -1849,7 +1849,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "") + "\n") .enable_thinking(false) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .tools({ @@ -1892,7 +1892,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "" + "\n" ) .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_AUTO) @@ -1908,7 +1908,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "hello()\n" "\n" "\n" - "") + "\n") .expect_tool_calls({ { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, }) From 4eac5b45095a4e8a1ff1cce4f6d030e0872fb4ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 19 Apr 2026 18:26:59 +0200 Subject: [PATCH 19/24] CUDA: refactor mma data loading for AMD (#22051) * CUDA: refactor mma data loading for AMD * fix CDNA MMQ occupancy * fix CDNA3 mma * fix RDNA3 compile --- ggml/src/ggml-cuda/common.cuh | 4 - ggml/src/ggml-cuda/fattn-mma-f16.cuh | 57 ++----- ggml/src/ggml-cuda/mma.cuh | 245 +++++++++------------------ ggml/src/ggml-cuda/mmq.cuh | 201 ++-------------------- 4 files changed, 112 insertions(+), 395 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ddf50baf4..3aec1742e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) -#if defined(TURING_MMA_AVAILABLE) -#define LDMATRIX_TRANS_AVAILABLE -#endif // defined(TURING_MMA_AVAILABLE) - static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b613ae61f..e185449d4 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + // The minimum granularity is 16 bytes. + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; if constexpr (use_cp_async) { + static_assert(warp_size == 32, "bad warp_size"); static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); @@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // 6: max 1*16= 16 bytes, 8 half ggml_cuda_unroll<6>{}(load); } else { - // TODO use ggml_cuda_memcpy_1 + const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}; auto load = [&] __device__ (const int n) { - const int stride_k = warp_size >> n; - const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_k = 32 >> n; + const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { @@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4, + !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero); } } }; - // 1: max 32* 4=128 bytes, 64 half - // 2: max 16* 4= 64 bytes, 32 half - // 3: max 8* 4= 32 bytes, 16 half - // 4: max 4* 4= 16 bytes, 8 half - ggml_cuda_unroll<4>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } } @@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } -#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - T_A_VKQ A_identity; - make_identity_mat(A_identity); -#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { @@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. -#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); -#elif defined(AMD_MFMA_AVAILABLE) - // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. - // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. - // Load with transposed addressing: 4 strided half loads. - { - const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; - const half * xs0_h = (const half *) xs0; - const int stride_h = stride_tile_V * 2; // stride in half units - half * A_h = (half *) A.x; -#pragma unroll - for (int l = 0; l < 4; ++l) { - A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; - } - } -#else - // TODO: Try to transpose tile_V when loading gmem to smem. - // Use mma to transpose T_A_VKQ for RDNA. - T_A_VKQ A_trans; - load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - mma(A, A_trans, A_identity); -#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c91dd2d9a..b0f674635 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -86,17 +86,12 @@ namespace ggml_cuda_mma { // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - static constexpr bool is_i_major(const data_layout dl) { - return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED; - } - static constexpr __device__ data_layout get_input_data_layout() { -#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) } template @@ -113,7 +108,6 @@ namespace ggml_cuda_mma { T x[ne] = {0}; static constexpr __device__ bool supported() { - if (I == 64 && J == 2) return true; if (I == 16 && J == 8) return true; if (I == 32 && J == 4) return true; if (I == 16 && J == 16) return true; @@ -122,7 +116,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + if constexpr (I == 16 && J == 4) { return threadIdx.x % 16; } else if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; @@ -139,8 +133,8 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> - return (2 * ((threadIdx.x / 16) % 2) + l); + if constexpr (I == 16 && J == 4) { + return threadIdx.x / 16; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 4) { @@ -154,7 +148,7 @@ namespace ggml_cuda_mma { return -1; } } -#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#elif defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -283,7 +277,7 @@ namespace ggml_cuda_mma { static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -407,7 +401,7 @@ namespace ggml_cuda_mma { return -1; } } -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) }; template @@ -701,57 +695,12 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) - static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { -#if defined(RDNA4) - const int row = t.get_i(0); - const int left_right = t.get_j(0) / 4; - const int up_down = row / 8; - const int idx = row % 8; - reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; -#else - GGML_UNUSED_VARS(t); - NO_DEVICE_CODE; -#endif // defined(RDNA4) - } - template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { -#if defined(AMD_MFMA_AVAILABLE) - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } -#elif defined(AMD_WMMA_AVAILABLE) - // All wmma layout has contiguous data when i-major. - if constexpr (is_i_major(dl)) { - // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() - constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); - if constexpr (sizeof(t.x) > aligned_copy_bytes) { - static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); - constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; -#pragma unroll - for (int i = 0; i < aligned_copy_count; ++i) { - ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } - } else { -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } -#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } -#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -764,26 +713,37 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(t, xs0, stride); + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void load_ldmatrix( - tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2); +#else + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 4, "bad ne"); + ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -796,19 +756,26 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); -#else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#if 1 - // TODO: more generic handling - static_assert(sizeof(T) == 4, "bad type size"); +#elif defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 32, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4); #else - load_generic(t, xs0, stride); -#endif // 1 + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -827,23 +794,30 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void load_ldmatrix( tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE - int * xi = (int * ) t.x; + int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); +#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; @@ -1218,73 +1192,27 @@ namespace ggml_cuda_mma { using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) - #elif defined(AMD_WMMA_AVAILABLE) - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; - #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); - + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true); #endif // RDNA4 - #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1297,19 +1225,10 @@ namespace ggml_cuda_mma { using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) #else @@ -1329,7 +1248,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1344,12 +1263,12 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } static __device__ __forceinline__ void mma( tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1364,41 +1283,35 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA4) || defined(CDNA3) + const int64_t xA = uint32_t(A.x[0]); + const int64_t xB = uint32_t(B.x[0]); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) +#elif defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false); #endif // RDNA4 #else GGML_UNUSED(D); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 28b662df9..b1a319de9 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -104,7 +104,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -114,9 +114,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -1054,13 +1054,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); @@ -1295,13 +1295,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -1435,57 +1435,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1510,13 +1460,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -1742,74 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1834,13 +1717,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; @@ -2573,59 +2456,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -2651,13 +2482,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; From e365e658f07b63371489570dfde597f199b26c23 Mon Sep 17 00:00:00 2001 From: "Alessandro de Oliveira Faria (A.K.A.CABELO)" Date: Sun, 19 Apr 2026 19:41:43 -0300 Subject: [PATCH 20/24] vendor : update cpp-httplib to 0.42.0 (#21781) --- scripts/sync_vendor.py | 2 +- vendor/cpp-httplib/httplib.cpp | 98 +++++++----- vendor/cpp-httplib/httplib.h | 274 ++++++++++++++++++++------------- 3 files changed, 227 insertions(+), 147 deletions(-) diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 3f1e74f7c..7ce093323 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import os import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.40.0" +HTTPLIB_VERSION = "refs/tags/v0.42.0" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 8ff1da57b..1cd71c2ec 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1,7 +1,5 @@ #include "httplib.h" namespace httplib { -// httplib::any — type-erased value container (C++11 compatible) -// On C++17+ builds, thin wrappers around std::any are provided. /* * Implementation that will be part of the .cc file if split into .h + .cc. @@ -1877,7 +1875,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service, } return ret; -#elif TARGET_OS_MAC +#elif TARGET_OS_MAC && defined(__clang__) if (!node) { return EAI_NONAME; } // macOS implementation using CFHost API for asynchronous DNS resolution CFStringRef hostname_ref = CFStringCreateWithCString( @@ -5836,6 +5834,17 @@ std::string Request::get_param_value(const std::string &key, return std::string(); } +std::vector +Request::get_param_values(const std::string &key) const { + auto rng = params.equal_range(key); + std::vector values; + values.reserve(static_cast(std::distance(rng.first, rng.second))); + for (auto it = rng.first; it != rng.second; ++it) { + values.push_back(it->second); + } + return values; +} + size_t Request::get_param_value_count(const std::string &key) const { auto r = params.equal_range(key); return static_cast(std::distance(r.first, r.second)); @@ -7013,6 +7022,15 @@ Server &Server::set_keep_alive_timeout(time_t sec) { return *this; } +template +Server &Server::set_keep_alive_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t /*usec*/) { + set_keep_alive_timeout(sec); + }); + return *this; +} + Server &Server::set_read_timeout(time_t sec, time_t usec) { read_timeout_sec_ = sec; read_timeout_usec_ = usec; @@ -9119,20 +9137,21 @@ bool ClientImpl::redirect(Request &req, Response &res, Error &error) { auto location = res.get_header_value("location"); if (location.empty()) { return false; } - thread_local const std::regex re( - R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + detail::UrlComponents uc; + if (!detail::parse_url(location, uc)) { return false; } - std::smatch m; - if (!std::regex_match(location, m, re)) { return false; } + // Only follow http/https redirects + if (!uc.scheme.empty() && uc.scheme != "http" && uc.scheme != "https") { + return false; + } auto scheme = is_ssl() ? "https" : "http"; - auto next_scheme = m[1].str(); - auto next_host = m[2].str(); - if (next_host.empty()) { next_host = m[3].str(); } - auto port_str = m[4].str(); - auto next_path = m[5].str(); - auto next_query = m[6].str(); + auto next_scheme = std::move(uc.scheme); + auto next_host = std::move(uc.host); + auto port_str = std::move(uc.port); + auto next_path = std::move(uc.path); + auto next_query = std::move(uc.query); auto next_port = port_; if (!port_str.empty()) { @@ -9145,7 +9164,7 @@ bool ClientImpl::redirect(Request &req, Response &res, Error &error) { if (next_host.empty()) { next_host = host_; } if (next_path.empty()) { next_path = "/"; } - auto path = decode_query_component(next_path, true) + next_query; + auto path = decode_path_component(next_path) + next_query; // Same host redirect - use current client if (next_scheme == scheme && next_host == host_ && next_port == port_) { @@ -10869,12 +10888,9 @@ Client::Client(const std::string &scheme_host_port) Client::Client(const std::string &scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { - const static std::regex re( - R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); - - std::smatch m; - if (std::regex_match(scheme_host_port, m, re)) { - auto scheme = m[1].str(); + detail::UrlComponents uc; + if (detail::parse_url(scheme_host_port, uc) && !uc.host.empty()) { + auto &scheme = uc.scheme; #ifdef CPPHTTPLIB_SSL_ENABLED if (!scheme.empty() && (scheme != "http" && scheme != "https")) { @@ -10890,12 +10906,10 @@ Client::Client(const std::string &scheme_host_port, auto is_ssl = scheme == "https"; - auto host = m[2].str(); - if (host.empty()) { host = m[3].str(); } + auto host = std::move(uc.host); - auto port_str = m[4].str(); auto port = is_ssl ? 443 : 80; - if (!port_str.empty() && !detail::parse_port(port_str, port)) { return; } + if (!uc.port.empty() && !detail::parse_port(uc.port, port)) { return; } if (is_ssl) { #ifdef CPPHTTPLIB_SSL_ENABLED @@ -12466,6 +12480,18 @@ std::string Request::sni() const { */ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT +// These wrappers forward to deprecated APIs that will be removed by v1.0.0. +// Suppress C4996 / -Wdeprecated-declarations so that MSVC /sdl builds (which +// promote C4996 to an error) compile cleanly even though the wrappers +// themselves are also marked [[deprecated]]. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#elif defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif + SSL_CTX *Client::ssl_context() const { if (is_ssl_) { return static_cast(*cli_).ssl_context(); } return nullptr; @@ -12480,6 +12506,12 @@ long Client::get_verify_result() const { if (is_ssl_) { return static_cast(*cli_).get_verify_result(); } return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? } + +#if defined(_MSC_VER) +#pragma warning(pop) +#elif defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif #endif // CPPHTTPLIB_OPENSSL_SUPPORT /* @@ -16302,12 +16334,10 @@ bool WebSocket::is_open() const { return !closed_; } WebSocketClient::WebSocketClient( const std::string &scheme_host_port_path, const Headers &headers) : headers_(headers) { - const static std::regex re( - R"(([a-z]+):\/\/(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?(\/.*))"); - - std::smatch m; - if (std::regex_match(scheme_host_port_path, m, re)) { - auto scheme = m[1].str(); + detail::UrlComponents uc; + if (detail::parse_url(scheme_host_port_path, uc) && !uc.scheme.empty() && + !uc.host.empty() && !uc.path.empty()) { + auto &scheme = uc.scheme; #ifdef CPPHTTPLIB_SSL_ENABLED if (scheme != "ws" && scheme != "wss") { @@ -16323,14 +16353,12 @@ WebSocketClient::WebSocketClient( auto is_ssl = scheme == "wss"; - host_ = m[2].str(); - if (host_.empty()) { host_ = m[3].str(); } + host_ = std::move(uc.host); - auto port_str = m[4].str(); port_ = is_ssl ? 443 : 80; - if (!port_str.empty() && !detail::parse_port(port_str, port_)) { return; } + if (!uc.port.empty() && !detail::parse_port(uc.port, port_)) { return; } - path_ = m[5].str(); + path_ = std::move(uc.path); #ifdef CPPHTTPLIB_SSL_ENABLED is_ssl_ = is_ssl; diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 2967ddf5e..1d12f2d2b 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.40.0" -#define CPPHTTPLIB_VERSION_NUM "0x002800" +#define CPPHTTPLIB_VERSION "0.42.0" +#define CPPHTTPLIB_VERSION_NUM "0x002a00" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 @@ -333,13 +333,10 @@ using socket_t = int; #include #include #include -#if __cplusplus >= 201703L -#include -#endif // On macOS with a TLS backend, enable Keychain root certificates by default // unless the user explicitly opts out. -#if defined(__APPLE__) && \ +#if defined(__APPLE__) && defined(__clang__) && \ !defined(CPPHTTPLIB_DISABLE_MACOSX_AUTOMATIC_ROOT_CERTIFICATES) && \ (defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \ defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || \ @@ -358,7 +355,7 @@ using socket_t = int; #if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \ defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) -#if TARGET_OS_MAC +#if TARGET_OS_MAC && defined(__clang__) #include #include #endif @@ -701,9 +698,96 @@ inline bool parse_port(const std::string &s, int &port) { return parse_port(s.data(), s.size(), port); } +struct UrlComponents { + std::string scheme; + std::string host; + std::string port; + std::string path; + std::string query; +}; + +inline bool parse_url(const std::string &url, UrlComponents &uc) { + uc = {}; + size_t pos = 0; + + auto sep = url.find("://"); + if (sep != std::string::npos) { + uc.scheme = url.substr(0, sep); + + // Scheme must be [a-z]+ only + if (uc.scheme.empty()) { return false; } + for (auto c : uc.scheme) { + if (c < 'a' || c > 'z') { return false; } + } + + pos = sep + 3; + } else if (url.compare(0, 2, "//") == 0) { + pos = 2; + } + + auto has_authority_prefix = pos > 0; + auto has_authority = has_authority_prefix || (!url.empty() && url[0] != '/' && + url[0] != '?' && url[0] != '#'); + if (has_authority) { + if (pos < url.size() && url[pos] == '[') { + auto close = url.find(']', pos); + if (close == std::string::npos) { return false; } + uc.host = url.substr(pos + 1, close - pos - 1); + + // IPv6 host must be [a-fA-F0-9:]+ only + if (uc.host.empty()) { return false; } + for (auto c : uc.host) { + if (!((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') || + (c >= '0' && c <= '9') || c == ':')) { + return false; + } + } + + pos = close + 1; + } else { + auto end = url.find_first_of(":/?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.host = url.substr(pos, end - pos); + pos = end; + } + + if (pos < url.size() && url[pos] == ':') { + ++pos; + auto end = url.find_first_of("/?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.port = url.substr(pos, end - pos); + pos = end; + } + + // Without :// or //, the entire input must be consumed as host[:port]. + // If there is leftover (path, query, etc.), this is not a valid + // host[:port] string — clear and reparse as a plain path. + if (!has_authority_prefix && pos < url.size()) { + uc.host.clear(); + uc.port.clear(); + pos = 0; + } + } + + if (pos < url.size() && url[pos] != '?' && url[pos] != '#') { + auto end = url.find_first_of("?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.path = url.substr(pos, end - pos); + pos = end; + } + + if (pos < url.size() && url[pos] == '?') { + auto end = url.find('#', pos); + if (end == std::string::npos) { end = url.size(); } + uc.query = url.substr(pos, end - pos); + } + + return true; +} + } // namespace detail -enum SSLVerifierResponse { +enum class SSLVerifierResponse { // no decision has been made, use the built-in certificate verifier NoDecisionMade, // connection certificate is verified and accepted @@ -797,38 +881,15 @@ using Match = std::smatch; using DownloadProgress = std::function; using UploadProgress = std::function; - -#if __cplusplus >= 201703L - -using any = std::any; -using bad_any_cast = std::bad_any_cast; - -template T any_cast(const any &a) { return std::any_cast(a); } -template T any_cast(any &a) { return std::any_cast(a); } -template T any_cast(any &&a) { - return std::any_cast(std::move(a)); -} -template const T *any_cast(const any *a) noexcept { - return std::any_cast(a); -} -template T *any_cast(any *a) noexcept { - return std::any_cast(a); -} - -#else // C++11/14 implementation - -class bad_any_cast : public std::bad_cast { -public: - const char *what() const noexcept override { return "bad any_cast"; } -}; - +/* + * detail: type-erased storage used by UserData. + * ABI-stable regardless of C++ standard — always uses this custom + * implementation instead of std::any. + */ namespace detail { using any_type_id = const void *; -// Returns a unique per-type ID without RTTI. -// The static address is stable across TUs because function templates are -// implicitly inline and the ODR merges their statics into one. template any_type_id any_typeid() noexcept { static const char id = 0; return &id; @@ -851,89 +912,60 @@ template struct any_value final : any_storage { } // namespace detail -class any { - std::unique_ptr storage_; - +class UserData { public: - any() noexcept = default; - any(const any &o) : storage_(o.storage_ ? o.storage_->clone() : nullptr) {} - any(any &&) noexcept = default; - any &operator=(const any &o) { - storage_ = o.storage_ ? o.storage_->clone() : nullptr; - return *this; + UserData() = default; + UserData(UserData &&) noexcept = default; + UserData &operator=(UserData &&) noexcept = default; + + UserData(const UserData &o) { + for (const auto &e : o.entries_) { + if (e.second) { entries_[e.first] = e.second->clone(); } + } } - any &operator=(any &&) noexcept = default; - template < - typename T, typename D = typename std::decay::type, - typename std::enable_if::value, int>::type = 0> - any(T &&v) : storage_(new detail::any_value(std::forward(v))) {} - - template < - typename T, typename D = typename std::decay::type, - typename std::enable_if::value, int>::type = 0> - any &operator=(T &&v) { - storage_.reset(new detail::any_value(std::forward(v))); + UserData &operator=(const UserData &o) { + if (this != &o) { + entries_.clear(); + for (const auto &e : o.entries_) { + if (e.second) { entries_[e.first] = e.second->clone(); } + } + } return *this; } - bool has_value() const noexcept { return storage_ != nullptr; } - void reset() noexcept { storage_.reset(); } + template void set(const std::string &key, T &&value) { + using D = typename std::decay::type; + entries_[key].reset(new detail::any_value(std::forward(value))); + } - template friend T *any_cast(any *a) noexcept; - template friend const T *any_cast(const any *a) noexcept; + template T *get(const std::string &key) noexcept { + auto it = entries_.find(key); + if (it == entries_.end() || !it->second) { return nullptr; } + if (it->second->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(it->second.get())->value; + } + + template const T *get(const std::string &key) const noexcept { + auto it = entries_.find(key); + if (it == entries_.end() || !it->second) { return nullptr; } + if (it->second->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(it->second.get())->value; + } + + bool has(const std::string &key) const noexcept { + return entries_.find(key) != entries_.end(); + } + + void erase(const std::string &key) { entries_.erase(key); } + + void clear() noexcept { entries_.clear(); } + +private: + std::unordered_map> + entries_; }; -template T *any_cast(any *a) noexcept { - if (!a || !a->storage_) { return nullptr; } - if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } - return &static_cast *>(a->storage_.get())->value; -} - -template const T *any_cast(const any *a) noexcept { - if (!a || !a->storage_) { return nullptr; } - if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } - return &static_cast *>(a->storage_.get())->value; -} - -template T any_cast(const any &a) { - using U = - typename std::remove_cv::type>::type; - const U *p = any_cast(&a); -#ifndef CPPHTTPLIB_NO_EXCEPTIONS - if (!p) { throw bad_any_cast{}; } -#else - if (!p) { std::abort(); } -#endif - return static_cast(*p); -} - -template T any_cast(any &a) { - using U = - typename std::remove_cv::type>::type; - U *p = any_cast(&a); -#ifndef CPPHTTPLIB_NO_EXCEPTIONS - if (!p) { throw bad_any_cast{}; } -#else - if (!p) { std::abort(); } -#endif - return static_cast(*p); -} - -template T any_cast(any &&a) { - using U = - typename std::remove_cv::type>::type; - U *p = any_cast(&a); -#ifndef CPPHTTPLIB_NO_EXCEPTIONS - if (!p) { throw bad_any_cast{}; } -#else - if (!p) { std::abort(); } -#endif - return static_cast(std::move(*p)); -} - -#endif // __cplusplus >= 201703L - struct Response; using ResponseHandler = std::function; @@ -1261,6 +1293,7 @@ struct Request { bool has_param(const std::string &key) const; std::string get_param_value(const std::string &key, size_t id = 0) const; + std::vector get_param_values(const std::string &key) const; size_t get_param_value_count(const std::string &key) const; bool is_multipart_form_data() const; @@ -1293,7 +1326,7 @@ struct Response { // User-defined context — set by pre-routing/pre-request handlers and read // by route handlers to pass arbitrary data (e.g. decoded auth tokens). - std::map user_data; + UserData user_data; bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", @@ -1664,6 +1697,9 @@ public: Server &set_keep_alive_max_count(size_t count); Server &set_keep_alive_timeout(time_t sec); + template + Server & + set_keep_alive_timeout(const std::chrono::duration &duration); Server &set_read_timeout(time_t sec, time_t usec = 0); template @@ -2790,10 +2826,26 @@ public: "This function will be removed by v1.0.0.")]] SSL_CTX *ssl_context() const; + // Override of a deprecated virtual in ClientImpl. Suppress C4996 / + // -Wdeprecated-declarations on the override declaration itself so that + // MSVC /sdl builds compile cleanly. Will be removed together with the + // base virtual by v1.0.0. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#elif defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif [[deprecated("Use set_session_verifier(session_t) instead. " "This function will be removed by v1.0.0.")]] void set_server_certificate_verifier( std::function verifier) override; +#if defined(_MSC_VER) +#pragma warning(pop) +#elif defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif private: bool verify_host(X509 *server_cert) const; From 9d49acb2a7102db476c8e09e9322d4a9aaea3b4b Mon Sep 17 00:00:00 2001 From: Yes You Can Have Your Own <188969017+yychyo@users.noreply.github.com> Date: Mon, 20 Apr 2026 08:30:24 +0300 Subject: [PATCH 21/24] server: rename --clear-idle to --cache-idle-slots (#21741) --- common/arg.cpp | 8 ++++---- common/common.h | 2 +- tools/server/README.md | 2 +- tools/server/server-context.cpp | 12 ++++++------ tools/server/tests/unit/test_kv_keep_only_active.py | 2 +- tools/server/tests/utils.py | 6 +++--- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 6f22f7819..43fe5a25d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1316,13 +1316,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( - {"--clear-idle"}, - {"--no-clear-idle"}, + {"--cache-idle-slots"}, + {"--no-cache-idle-slots"}, "save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)", [](common_params & params, bool value) { - params.clear_idle = value; + params.cache_idle_slots = value; } - ).set_env("LLAMA_ARG_CLEAR_IDLE").set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_env("LLAMA_ARG_CACHE_IDLE_SLOTS").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, diff --git a/common/common.h b/common/common.h index 4c36e85e0..027339294 100644 --- a/common/common.h +++ b/common/common.h @@ -567,7 +567,7 @@ struct common_params { int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting bool cache_prompt = true; // whether to enable prompt caching - bool clear_idle = true; // save and clear idle slots upon starting a new task + bool cache_idle_slots = true; // save and clear idle slots upon starting a new task int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. diff --git a/tools/server/README.md b/tools/server/README.md index 84a29cba0..db1f27039 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -167,7 +167,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `-cpent, --checkpoint-every-n-tokens N` | create a checkpoint every n tokens during prefill (processing), -1 to disable (default: 8192)
(env: LLAMA_ARG_CHECKPOINT_EVERY_NT) | | `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | | `-kvu, --kv-unified, -no-kvu, --no-kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)
(env: LLAMA_ARG_KV_UNIFIED) | -| `--clear-idle, --no-clear-idle` | save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)
(env: LLAMA_ARG_CLEAR_IDLE) | +| `--cache-idle-slots, --no-cache-idle-slots` | save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)
(env: LLAMA_ARG_CACHE_IDLE_SLOTS) | | `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | | `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode | | `-sp, --special` | special tokens output enabled (default: false) | diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 70ebcc225..7ffe6a303 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -987,13 +987,13 @@ private: metrics.init(); - if (params_base.clear_idle) { + if (params_base.cache_idle_slots) { if (!params_base.kv_unified) { - SRV_WRN("%s: --clear-idle requires --kv-unified, disabling\n", __func__); - params_base.clear_idle = false; + SRV_WRN("%s: --cache-idle-slots requires --kv-unified, disabling\n", __func__); + params_base.cache_idle_slots = false; } else if (params_base.cache_ram_mib == 0) { - SRV_WRN("%s: --clear-idle requires --cache-ram, disabling\n", __func__); - params_base.clear_idle = false; + SRV_WRN("%s: --cache-idle-slots requires --cache-ram, disabling\n", __func__); + params_base.cache_idle_slots = false; } else { SRV_INF("%s: idle slots will be saved to prompt cache and cleared upon starting a new task\n", __func__); SRV_DBG("%s", "__TEST_TAG_CLEAR_IDLE_ENABLED__\n"); @@ -1886,7 +1886,7 @@ private: break; // drop the task } - if (params_base.clear_idle) { + if (params_base.cache_idle_slots) { for (auto & s : slots) { if (!s.is_processing()) { slot_save_and_clear(s); diff --git a/tools/server/tests/unit/test_kv_keep_only_active.py b/tools/server/tests/unit/test_kv_keep_only_active.py index da93d5001..f4b08b5dd 100644 --- a/tools/server/tests/unit/test_kv_keep_only_active.py +++ b/tools/server/tests/unit/test_kv_keep_only_active.py @@ -91,7 +91,7 @@ def test_clear_and_restore(): def test_disabled_with_flag(): global server - server.no_clear_idle = True + server.no_cache_idle_slots = True server.start() log = LogReader(server.log_path) diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 5ddac5be4..ddbb76c9a 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -103,7 +103,7 @@ class ServerProcess: media_path: str | None = None sleep_idle_seconds: int | None = None cache_ram: int | None = None - no_clear_idle: bool = False + no_cache_idle_slots: bool = False log_path: str | None = None webui_mcp_proxy: bool = False @@ -242,8 +242,8 @@ class ServerProcess: server_args.extend(["--sleep-idle-seconds", self.sleep_idle_seconds]) if self.cache_ram is not None: server_args.extend(["--cache-ram", self.cache_ram]) - if self.no_clear_idle: - server_args.append("--no-clear-idle") + if self.no_cache_idle_slots: + server_args.append("--no-cache-idle-slots") if self.webui_mcp_proxy: server_args.append("--webui-mcp-proxy") From 788fcbc5ddb16101e6388170bf84c0d10bd8c006 Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Mon, 20 Apr 2026 01:39:45 -0400 Subject: [PATCH 22/24] [SYCL] Fix reorder MMVQ assert on unaligned vocab sizes (#22035) * [SYCL] Fix reorder MMVQ assert on unaligned vocab sizes The reorder mul_mat_vec_q dispatchers for Q4_0, Q8_0, Q4_K, and Q6_K asserted that block_num_y was a multiple of 16 subgroups. Models with a vocab size not divisible by 16 (for example HY-MT at 120818) aborted on model load when the output projection tripped the assert. I replaced the assert with padding: block_num_y now rounds up to a whole number of subgroup-sized workgroups. The kernel already has the row bounds check (`if (row >= nrows) return;`) so the extra padded threads early-exit cleanly. Row values are uniform across a subgroup so the collective reduce stays safe. For aligned vocab sizes the padded block_num_y equals the old value, so the kernel launch is identical and there is no regression. Thanks to @arthw for flagging the relationship to #21527. Fixes #22020. AI assisted coding, tested on Intel B70 hardware. * sycl: use WARP_SIZE for num_subgroups in reorder MMVQ launches Replaces the hardcoded 16 with WARP_SIZE in the four reorder_mul_mat_vec launch helpers (Q4_0, Q8_0, Q4_K, Q6_K). Compile-time no-op on the Intel target where WARP_SIZE is 16, but makes the relationship to subgroup size explicit. Per review by @NeoZhangJianyu on #22035. Assisted by Claude. --- ggml/src/ggml-sycl/mmvq.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index af22b98dd..3a4577ecb 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -537,9 +537,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -682,9 +682,9 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK8_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -798,9 +798,9 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -842,9 +842,9 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); From de71b5f81c3b6b9f8bdaf1b2a21198e1eede3fda Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 20 Apr 2026 08:42:37 +0300 Subject: [PATCH 23/24] server : refactor "use checkpoint" logic (#22114) --- common/arg.cpp | 2 +- common/common.cpp | 38 +++++++++++++++++++- common/common.h | 25 ++++++++++--- common/hf-cache.cpp | 4 +-- common/speculative.cpp | 62 ++++++++------------------------- common/speculative.h | 10 ------ tools/server/server-context.cpp | 44 ++++++++++------------- 7 files changed, 93 insertions(+), 92 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 43fe5a25d..099f0aeab 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -292,7 +292,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa hf_tag = "default"; } - std::string model_endpoint = get_model_endpoint(); + std::string model_endpoint = common_get_model_endpoint(); auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini"; // prepare local path for caching diff --git a/common/common.cpp b/common/common.cpp index d3f1cee39..6cde71d81 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1382,7 +1382,7 @@ common_init_result_ptr common_init_from_params(common_params & params) { common_init_result::~common_init_result() = default; -std::string get_model_endpoint() { +std::string common_get_model_endpoint() { const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. const char * hf_endpoint_env = getenv("HF_ENDPOINT"); @@ -1397,6 +1397,42 @@ std::string get_model_endpoint() { return model_endpoint; } +common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) { + auto * mem = llama_get_memory(ctx); + if (mem == nullptr) { + return COMMON_CONTEXT_SEQ_RM_TYPE_NO; + } + + common_context_seq_rm_type res = COMMON_CONTEXT_SEQ_RM_TYPE_PART; + + llama_memory_clear(mem, true); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size())); + if (ret != 0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + goto done; + } + + // try to remove the last tokens + if (!llama_memory_seq_rm(mem, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + goto done; + } + +done: + llama_memory_clear(mem, true); + llama_synchronize(ctx); + + return res; +} + void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { std::vector loras; std::vector scales; diff --git a/common/common.h b/common/common.h index 027339294..cbcb3bb65 100644 --- a/common/common.h +++ b/common/common.h @@ -308,10 +308,9 @@ struct common_params_speculative { // ngram-based speculative decoding - uint16_t ngram_size_n = 12; // ngram size for lookup - uint16_t ngram_size_m = 48; // mgram size for speculative tokens - uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed - bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models + uint16_t ngram_size_n = 12; // ngram size for lookup + uint16_t ngram_size_m = 48; // mgram size for speculative tokens + uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed std::shared_ptr ngram_mod; @@ -847,7 +846,23 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); -std::string get_model_endpoint(); +// model endpoint from env +std::string common_get_model_endpoint(); + +// +// Context utils +// + +enum common_context_seq_rm_type { + COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module) + COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences + COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only +}; + +// check if the llama_context can remove sequences +// note: clears the memory of the context +common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx); + // // Batch utils diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp index 38a4c17a9..ea5b2150d 100644 --- a/common/hf-cache.cpp +++ b/common/hf-cache.cpp @@ -230,7 +230,7 @@ static nl::json api_get(const std::string & url, static std::string get_repo_commit(const std::string & repo_id, const std::string & token) { try { - auto endpoint = get_model_endpoint(); + auto endpoint = common_get_model_endpoint(); auto json = api_get(endpoint + "api/models/" + repo_id + "/refs", token); if (!json.is_object() || @@ -308,7 +308,7 @@ hf_files get_repo_files(const std::string & repo_id, hf_files files; try { - auto endpoint = get_model_endpoint(); + auto endpoint = common_get_model_endpoint(); auto json = api_get(endpoint + "api/models/" + repo_id + "/tree/" + commit + "?recursive=true", token); if (!json.is_array()) { diff --git a/common/speculative.cpp b/common/speculative.cpp index 1789560ee..daa2b5a8a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -164,8 +164,8 @@ struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt; // only used for retokenizing from ctx_dft llama_context * ctx_dft; + bool use_ckpt = false; struct common_speculative_checkpoint ckpt; - bool use_checkpoint; common_sampler * smpl; @@ -180,11 +180,11 @@ struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_tgt, llama_context * ctx_dft, const std::vector> & replacements, - bool use_checkpoint) + bool use_ckpt) : common_speculative_state(type) , ctx_tgt(ctx_tgt) , ctx_dft(ctx_dft) - , use_checkpoint(use_checkpoint) + , use_ckpt(use_ckpt) { batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); smpl = nullptr; @@ -239,7 +239,7 @@ struct common_speculative_state_draft : public common_speculative_state { } void begin(const llama_tokens & prompt) override { - if (use_checkpoint && ckpt.size() > 0) { + if (use_ckpt && ckpt.size() > 0) { // delete checkpoint LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n", __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024); @@ -351,7 +351,7 @@ struct common_speculative_state_draft : public common_speculative_state { LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); - if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) { + if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) { LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", __func__, reuse_i, reuse_n); reuse_i = 0; @@ -361,8 +361,8 @@ struct common_speculative_state_draft : public common_speculative_state { result.clear(); result.reserve(params.n_max); - bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0; - if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) { + bool needs_ckpt = use_ckpt && prompt_dft.size() > 0; + if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) { llama_memory_clear(mem_dft, false); prompt_dft.clear(); } else { @@ -400,7 +400,7 @@ struct common_speculative_state_draft : public common_speculative_state { } if (reuse_n < (int) prompt_dft.size() || do_restore) { - if (use_checkpoint) { + if (use_ckpt) { if (ckpt.n_tokens > (int64_t) prompt_dft.size()) { LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n", __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size()); @@ -912,42 +912,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } -common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) { - auto * mem = llama_get_memory(ctx_tgt); - if (mem == nullptr) { - return COMMON_SPECULATIVE_COMPAT_TYPE_NO; - } - - common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL; - - llama_memory_clear(mem, true); - - // eval 2 tokens to check if the context is compatible - std::vector tmp; - tmp.push_back(0); - tmp.push_back(0); - - int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size())); - if (ret != 0) { - LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); - res = COMMON_SPECULATIVE_COMPAT_TYPE_NO; - goto done; - } - - // try to remove the last tokens - if (!llama_memory_seq_rm(mem, 0, 1, -1)) { - LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT; - goto done; - } - -done: - llama_memory_clear(mem, true); - llama_synchronize(ctx_tgt); - - return res; -} - // initialization of the speculative decoding system // common_speculative * common_speculative_init( @@ -1022,11 +986,13 @@ common_speculative * common_speculative_init( case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT: { + const bool use_ckpt = common_context_can_seq_rm(ctx_dft) == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; + impls.push_back(std::make_unique(config.type, - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .replacements = */ params.replacements, - /* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model! + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ params.replacements, + /* .use_ckpt = */ use_ckpt )); break; } diff --git a/common/speculative.h b/common/speculative.h index cbe6e5bdb..bca78d32b 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -14,16 +14,6 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); -enum common_speculative_compat_type { - COMMON_SPECULATIVE_COMPAT_TYPE_NO = 0, - COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1, - COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2, -}; - -// check if the llama_context is compatible for speculative decoding -// note: clears the memory of the context -common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt); - common_speculative * common_speculative_init( common_params_speculative & params, llama_context * ctx_tgt); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 7ffe6a303..99856e6c3 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -78,9 +78,10 @@ enum server_state { struct server_slot { int id; - // TODO: change to unique_ptrs for consistency: llama_context * ctx = nullptr; + common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO; + // multimodal mtmd_context * mctx = nullptr; @@ -90,7 +91,6 @@ struct server_slot { server_prompt_checkpoint spec_ckpt; common_speculative_ptr spec; - // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 std::unique_ptr task; @@ -343,7 +343,7 @@ struct server_slot { if (!spec_draft.empty()) { // we have a previous (partial) draft to reuse - if (task->params.speculative.use_checkpoints) { + if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { GGML_ASSERT(!spec_ckpt.empty()); } } else { @@ -362,15 +362,13 @@ struct server_slot { spec_draft.clear(); } - if (!spec_draft.empty() && params_spec.use_checkpoints) { + if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { const auto n_tokens = prompt.tokens.size(); - auto & ckpt = spec_ckpt; - - ckpt = server_get_checkpoint(ctx, this->id, n_tokens); + spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens); SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n", - ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024); + spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024); } } @@ -871,14 +869,13 @@ private: slots.clear(); - const auto spec_type = common_speculative_is_compat(ctx); - if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) { + const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx); + if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) { SRV_WRN("%s", "speculative decoding not supported by this context\n"); } - if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) { + if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { SRV_WRN("%s", "speculative decoding will use checkpoints\n"); - params_base.speculative.use_checkpoints = true; } // initialize slots @@ -893,11 +890,13 @@ private: slot.ctx = ctx; slot.n_ctx = n_ctx_slot; + slot.ctx_seq_rm_type = ctx_seq_rm_type; + slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) { + if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) { slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx)); if (slot.spec) { @@ -2588,15 +2587,11 @@ private: // make a checkpoint of the parts of the memory that cannot be rolled back. // checkpoints are created only if: + // - the model does not support partial sequence removal // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); + (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full)); bool has_mtmd = false; @@ -2965,8 +2960,6 @@ private: // verify and try to accept the draft { - const auto & params_spec = slot.task->params.speculative; - common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); @@ -2979,13 +2972,14 @@ private: // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (params_spec.use_checkpoints) { + if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { // partial acceptance is not supported by the context -> truncate the draft and restore the state slot.spec_draft = std::move(accepted); - auto & ckpt = slot.spec_ckpt; + const auto & ckpt = slot.spec_ckpt; - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", + ckpt.pos_min, ckpt.pos_max, ckpt.size()); const size_t n = llama_state_seq_set_data_ext(slot.ctx, ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != ckpt.size()) { From 81df3f7cfaa6f99de14e792b38d5771bf427383e Mon Sep 17 00:00:00 2001 From: SamareshSingh <97642706+ssam18@users.noreply.github.com> Date: Mon, 20 Apr 2026 02:32:46 -0500 Subject: [PATCH 24/24] fix: GLM-DSA crash in llama-tokenize when using vocab_only (#22102) * llama: fix crash in print_info for GLM-DSA when vocab_only is set * addressed code review comments * cont : simplify --------- Co-authored-by: Georgi Gerganov --- src/llama-model.cpp | 186 ++++++++++++++++++++++---------------------- 1 file changed, 93 insertions(+), 93 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4ded484dd..5f543e762 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8180,114 +8180,114 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; - for (auto label : classifier_labels) { + for (const auto & label : classifier_labels) { LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } - } - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_QWEN35 || - arch == LLM_ARCH_QWEN35MOE || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } } vocab.print_info();