From 098705a29e83d2352e9ff1a266afd8eccd6ebad2 Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 29 Apr 2026 11:39:56 -0700 Subject: [PATCH 01/11] CUDA: fuse SSM_CONV + ADD(bias) + SILU (#22478) --- ggml/src/ggml-cuda/ggml-cuda.cu | 35 ++++++++++++++++- ggml/src/ggml-cuda/ssm-conv.cu | 34 ++++++++++++++--- ggml/src/ggml-cuda/ssm-conv.cuh | 2 +- tests/test-backend-ops.cpp | 66 +++++++++++++++++++++++++++++++++ 4 files changed, 129 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fd8dd9171..0e6f74685 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3556,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { return false; @@ -3564,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD + && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * add = cgraph->nodes[node_idx+1]; + const ggml_tensor * silu = cgraph->nodes[node_idx+2]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } + + if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias. + const ggml_tensor * bias = (add->src[0] == ssm_conv) ? 
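        /* editor's note: ggml_add is commutative, so the graph may present the
           bias as either operand of the ADD node; the ternary below selects
           whichever src is not the SSM_CONV output. Conceptually the fused
           pattern being matched is y = silu(ssm_conv(x, w) + b), with b a
           contiguous 1-D tensor broadcast along ne[0] of the conv output,
           which the checks that follow enforce. */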
add->src[1] : add->src[0]; + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) { + return false; + } + + return true; + } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { const ggml_tensor * unary = cgraph->nodes[node_idx]; @@ -3966,8 +3994,13 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph return 1; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { - ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]); + ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]); return 1; } diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index b77cdc1c1..4841389fb 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,6 +3,7 @@ template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; @@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -97,6 +102,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + // Compute from shared memory for (int64_t i = 0; i < local_n_t; i++) { float sumf = 0.0f; @@ -104,12 +111,13 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, for (size_t j = 0; j < d_conv; j++) { sumf += smem[tid * n_cols + i + j] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? 
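            /* editor's note: same fused epilogue as the short-token kernel
               above, applied after accumulating from shared memory: the
               channel bias b (0 when no ADD node is fused) is folded into the
               accumulator, and the store below applies
                   y = silu(s) = s / (1 + exp(-s))
               only when apply_silu is set, so the unfused SSM_CONV path is
               unchanged. */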
ggml_cuda_op_silu_single(sumf) : sumf; } } template -static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, const int64_t n_s, cudaStream_t stream) { @@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); ssm_conv_long_token_f32<<>>( - src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; @@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_bias = bias_add_node != nullptr; const bool fuse_silu = silu_dst != nullptr; + // bias always comes with silu. + GGML_ASSERT(!fuse_bias || fuse_silu); + + // The bias (when fused) is the non-conv operand of the ADD node. + const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr; + // When fusing, write to silu_dst (the node downstream references). const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; @@ -160,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; + const float * bias_d = fuse_bias ? 
(const float *) bias->data : nullptr;
     float * dst_d = (float *) out->data;
 
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(out->type == GGML_TYPE_F32);
 
+    if (fuse_bias) {
+        GGML_ASSERT(bias->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_is_contiguous(bias));
+        GGML_ASSERT(ggml_nelements(bias) == nr);
+    }
+
     if (fuse_silu) {
-        ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+        ssm_conv_f32_cuda<true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
                                 out->nb[2], nc, nr, n_t, n_s, stream);
     } else {
-        ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+        ssm_conv_f32_cuda<false>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
                                  out->nb[2], nc, nr, n_t, n_s, stream);
     }
 }
diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh
index f96a1cd24..8514ca849 100644
--- a/ggml/src/ggml-cuda/ssm-conv.cuh
+++ b/ggml/src/ggml-cuda/ssm-conv.cuh
@@ -1,3 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 941c20ce1..be2a6ef56 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3579,6 +3579,49 @@ struct test_ssm_conv : public test_case {
     }
 };
 
+// GGML_OP_SSM_CONV + GGML_OP_ADD (channel-wise bias, optional) + GGML_OP_UNARY(SILU) (fused operation)
+struct test_ssm_conv_bias_silu : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const std::array<int64_t, 4> ne_b;
+    const bool fuse_bias;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "SSM_CONV_BIAS_SILU";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, ne_b, fuse_bias);
+    }
+
+    test_ssm_conv_bias_silu(ggml_type type, std::array<int64_t, 4> ne_a, std::array<int64_t, 4> ne_b,
+                            bool fuse_bias)
+        : type(type), ne_a(ne_a), ne_b(ne_b), fuse_bias(fuse_bias) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_set_name(a, "a");
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
+
+        if (fuse_bias) {
+            ggml_tensor * bias = ggml_new_tensor_1d(ctx, type, out->ne[0]);
+            ggml_set_name(bias, "bias");
+            out = ggml_add(ctx, out, bias);
+        }
+
+        out = ggml_silu(ctx, out);
+
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_SSM_SCAN
 struct test_ssm_scan : public test_case {
     const ggml_type type;
@@ -7977,6 +8020,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // fused ssm_conv + (optional) bias_add + silu. The bias-only graph (no silu) is intentionally
+    // not tested since there's no fusion for that pattern in ggml_cuda_can_fuse.
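    // editor's note: a minimal CPU reference for the fused op these cases check,
    // shown for clarity only (not part of the patch). It assumes a single
    // sequence, x laid out as one row of (n_t + d_conv - 1) floats per channel,
    // y laid out as [n_t][d_inner], and `expf` available via <cmath>.
    auto ref_conv_bias_silu = [](const float * x, const float * w, const float * b,
                                 int64_t d_conv, int64_t d_inner, int64_t n_t, float * y) {
        for (int64_t c = 0; c < d_inner; c++) {
            for (int64_t t = 0; t < n_t; t++) {
                float s = b ? b[c] : 0.0f;                     // optional channel-wise bias
                for (int64_t j = 0; j < d_conv; j++) {         // causal conv window
                    s += x[c * (n_t + d_conv - 1) + t + j] * w[c * d_conv + j];
                }
                y[t * d_inner + c] = s / (1.0f + expf(-s));    // SiLU
            }
        }
    };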
+ for (int64_t d_conv : {3, 4, 9}) { + for (int64_t d_inner : {1024, 1536, 2048}) { + for (bool fuse_bias : {false, true}) { + // short token path (n_t <= 32) + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); + // long token path (n_t > 32) + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); + } + } + } + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2 test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1 @@ -8993,6 +9057,8 @@ static std::vector> make_test_cases_perf() { // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate + test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // prefill + test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // generate test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate From 41a63be28eb9b8fda13452101993412d241f16de Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 29 Apr 2026 11:51:21 -0700 Subject: [PATCH 02/11] hexagon: make vmem and buffer-size configurable (#22487) * hexagon: allow host to set max vmem size We use a sane default but it's helpful to allow for an override if needed. 
* hexagon: add support for measuring vmem space and move pinned mmaping management to host * hexagon: update vmem checks to use uint64 * hexagon: bump op buffers to 16 (matches max mmaps) * hexagon: bump default vmem to 3.2GB * hexagon: add support for autodetecting vmem space and some logging cleanup in that area * hexagon: fix whitespace warnings * Update scripts/snapdragon/adb/run-cli.sh Co-authored-by: Pascal * hex-adb: fix run-completion script --------- Co-authored-by: Pascal --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 238 ++++++++++++++--------- ggml/src/ggml-hexagon/htp/htp-ctx.h | 4 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 8 +- ggml/src/ggml-hexagon/htp/htp_iface.idl | 4 +- ggml/src/ggml-hexagon/htp/main.c | 27 ++- scripts/snapdragon/adb/run-cli.sh | 12 +- scripts/snapdragon/adb/run-completion.sh | 8 +- 7 files changed, 180 insertions(+), 121 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 0d9b5e289..9345da621 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -48,14 +48,16 @@ using intvec = std::vector; using uintvec = std::vector; using u32vec = std::vector; -static size_t opt_ndev = 1; -static size_t opt_nhvx = 0; // use all -static int opt_arch = 0; // autodetect -static int opt_etm = 0; -static int opt_verbose = 0; -static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) -static int opt_hostbuf = 1; // hostbuf ON by default -static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static int opt_arch = 0; // autodetect +static size_t opt_ndev = 1; +static size_t opt_nhvx = 0; // use all +static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings +static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size +static int opt_etm = 0; +static int opt_verbose = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) +static int opt_hostbuf = 1; // hostbuf ON by default // Default PMU events, if profiling with PMU (mode=2) is enabled // See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html @@ -66,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches + static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ @@ -110,7 +113,7 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct if (!opt_verbose) return; op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? 
"yes" : "no"); } @@ -118,8 +121,6 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; - op_desc desc(op); - char pmu_str[256] = ""; if (opt_profile > 1) { static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); @@ -127,6 +128,7 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } + op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); } @@ -191,33 +193,30 @@ struct ggml_hexagon_shared_buffer { bool mapped; bool pinned; - void mmap(bool pinned = false) { - int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD_DELAYED); + void mmap() { + fastrpc_map_flags flags = this->pinned ? FASTRPC_MAP_FD : FASTRPC_MAP_FD_DELAYED; + + int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, flags); if (err != 0) { GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), sess->domain_id, this->size, this->fd, (unsigned) err); throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)"); } - if (pinned) { - err = htp_iface_mmap(sess->handle, this->fd, this->size, pinned); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: %s buffer pinning failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), - sess->domain_id, this->size, this->fd, (unsigned) err); - throw std::runtime_error("ggml-hex: htp_iface_mmap failed (see log for details)"); - } - } - - this->mapped = true; - this->pinned = pinned; HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n", sess->c_name(), (void *) this->base, this->size, this->fd, pinned); + + this->mapped = true; } void unmap() { if (!this->mapped) return; - htp_iface_munmap(sess->handle, this->fd); + if (!this->pinned) { + // HTP might still hold a reference, tell it drop it + htp_iface_munmap(sess->handle, this->fd); + } + fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size); HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(), @@ -227,7 +226,7 @@ struct ggml_hexagon_shared_buffer { this->fd = -1; } - void alloc(size_t size, bool pinned = false) { + void alloc(size_t size) { if (this->base) return; this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size); @@ -245,8 +244,7 @@ struct ggml_hexagon_shared_buffer { HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(), (void *) this->base, this->size, this->fd, (int) pinned); - - mmap(pinned); + mmap(); } void free() { @@ -262,15 +260,14 @@ struct ggml_hexagon_shared_buffer { } ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) { - size += 4 * 1024; // extra page for padding - this->sess = sess; this->size = 0; this->base = nullptr; this->fd = -1; this->mapped = false; + this->pinned = pinned; - alloc(size, pinned); + alloc(size); } ~ggml_hexagon_shared_buffer() { @@ -1475,6 +1472,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { + size 
+= 4 * 1024; // guard page ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { @@ -1487,6 +1485,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { + size += 4 * 1024; // guard page ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { @@ -1505,7 +1504,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return 1UL * 1024 * 1024 * 1024; // 1GB per buffer + return opt_mbuf; // typically 1GB per buffer GGML_UNUSED(buffer_type); } @@ -1573,14 +1572,14 @@ struct ggml_hexagon_opbatch { d_map.clear(); } - ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) { + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size, size_t max_vmem) { this->sess = sess; n_bufs_max = HTP_OP_MAX_BUFS; n_ops_max = batch_size; n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; - b_vmem_max = HTP_OP_MAX_VMEM; + b_vmem_max = max_vmem; ops.resize(n_ops_max); @@ -1592,6 +1591,9 @@ struct ggml_hexagon_opbatch { t_map.reserve(n_tens_max); d_map.reserve(n_tens_max); + GGML_LOG_INFO("ggml-hex: %s op batching: n-bufs %u n-tensors %u n-ops %u vmem %zu\n", + sess->c_name(), n_bufs_max, n_tens_max, n_ops_max, b_vmem_max); + reset(); } @@ -1925,6 +1927,8 @@ void ggml_hexagon_session::flush_batch() { // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc + HEX_VERBOSE("ggml-hex: %s queue-opbatch: %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); + int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); if (err != 0) { GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); @@ -1944,6 +1948,35 @@ void ggml_hexagon_session::flush(bool all) { flush_pending(all); } +static size_t ggml_hexagon_measure_max_vmem(ggml_hexagon_session *sess) { + // Allocate a bunch pinned buffers till failure. + // This is kind of expensive but handy for figuring out exactly how much we can mmap on a specific device. + // Typically we're going to allocate all/most of these buffers anyway for the model weights. + + std::vector sbufs; + + const size_t MiB = 1024 * 1024; + const size_t GiB = MiB * 1024; + + size_t vmem = 0; + size_t step = 256u * MiB; + + try { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + + while (1) { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, step, true)); + vmem += step; + } + } catch (...) 
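    /* editor's note: the loop above intentionally allocates until it throws.
       The ggml_hexagon_shared_buffer constructor calls fastrpc_mmap, which
       fails once the DSP-side virtual address space is exhausted, so the empty
       catch below is the probe's expected exit path. The overall shape, as a
       sketch (alloc_pinned stands in for `new ggml_hexagon_shared_buffer(sess, step, true)`):

           size_t vmem = 0;
           try {
               for (;;) {
                   alloc_pinned(step);  // throws once fastrpc_mmap fails
                   vmem += step;
               }
           } catch (...) {
           }
           return vmem - step;          // back off one step for internal mappings
    */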
{ } + + for (auto b : sbufs) { delete b; } + + return vmem - step; // backoff to account for overhead from internal mappings +} + void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_session = false; this->valid_handle = false; @@ -1957,7 +1990,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->op_pending = 0; - GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); + GGML_LOG_DEBUG("ggml-hex: %s allocating new session\n", this->name.c_str()); domain * my_domain = get_domain(this->domain_id); if (my_domain == NULL) { @@ -2033,9 +2066,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; - GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(), - this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); - // Enable FastRPC QoS mode { struct remote_rpc_control_latency l; @@ -2047,6 +2077,9 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } } + GGML_LOG_INFO("ggml-hex: %s new session : session-id %d domain-id %d uri %s handle 0x%lx\n", this->c_name(), + this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024; const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024; @@ -2091,13 +2124,19 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } // Allocate buffers and state for op batching - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); - // Start processing op batch requests - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); + if (!opt_vmem) { + opt_vmem = ggml_hexagon_measure_max_vmem(this); + GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); + } + + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + + // Start dspqueue/opbatch processing + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); + GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; @@ -2108,17 +2147,17 @@ void ggml_hexagon_session::release() noexcept(true) { int err; - delete this->op_batch; - delete this->op_queue; - - // Stop the DSP-side service and close the queue if (this->valid_iface) { + // Stop dspqueue/opbatch processing err = htp_iface_stop(this->handle); if (err != 0) { GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err); } } + delete this->op_batch; + delete this->op_queue; + if (opt_etm) { err = htp_iface_etm(this->handle, 0); if (err != 0) { @@ -3380,21 +3419,6 @@ struct ggml_hexagon_registry { ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev); - if (!opt_arch) { - int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); - opt_arch = 73; - } - } - -#if defined(__ANDROID__) - if (opt_arch < 75) { - opt_ndev = 1; - GGML_LOG_WARN("ggml-hex: forcing ndev to 1 
for SoCs archs lower than v75.\n"); - } -#endif - GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); // Create devices / sessions @@ -3480,32 +3504,67 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, "please update hexagon_type to match ggml_type"); - const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); - const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); - const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); - const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); - const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); - const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); - const char * str_etm = getenv("GGML_HEXAGON_ETM"); - const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); - const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); - const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); - const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); + const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); + const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); + const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); + const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); + const char * str_etm = getenv("GGML_HEXAGON_ETM"); + const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); + const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); + const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + + // Init Arch first since it affects other defaults + if (!str_arch) { + int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); + opt_arch = 73; + } + } else { + if (str_arch[0] == 'v' || str_arch[0] == 'V') { + str_arch++; + } + opt_arch = strtoul(str_arch, NULL, 0); + } + + size_t MiB = 1024 * 1024; + + // Update vmem default + opt_vmem = opt_arch >= 75 ? HTP_OP_MAX_VMEM_DEFAULT : 3000 * MiB; auto RE_ICASE = std::regex_constants::icase; - opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; - opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; - opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; - opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_etm = str_etm ? atoi(str_etm) : 0; - opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; - opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? 
strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_profile = str_profile ? atoi(str_profile) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; + opt_vmem = str_vmem ? strtoul(str_vmem, NULL, 0) * MiB : opt_vmem; + + if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { + opt_ndev = GGML_HEXAGON_MAX_SESSIONS; + } + +#if defined(__ANDROID__) + if (opt_arch < 75) { + opt_ndev = 1; + GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); + } +#endif if (str_profile) { opt_pmu_evt = [&]() -> std::vector { @@ -3520,17 +3579,6 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { vec_to_str(opt_pmu_evt).c_str()); } - if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { - opt_ndev = GGML_HEXAGON_MAX_SESSIONS; - } - - if (str_arch) { - if (str_arch[0] == 'v') { - str_arch++; - } - opt_arch = strtoul(str_arch, NULL, 0); - } - reg->context = new ggml_hexagon_registry(reg); } diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index d704fedee..e9c563ca8 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -20,7 +20,7 @@ struct htp_mmap { uint64_t size; uint64_t base; uint32_t fd; - uint32_t pinned; + uint32_t reserved; }; // Scratchpad state @@ -77,6 +77,8 @@ struct htp_context { atomic_bool vtcm_valid; atomic_bool vtcm_needs_release; + uint64_t max_vmem; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 4397245c5..66a3150c1 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -90,15 +90,11 @@ enum htp_op_code { #define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS #define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS -#define HTP_OP_MAX_BUFS 8 +#define HTP_OP_MAX_BUFS 16 #define HTP_OP_MAX_REQS 256 #define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) -#if __HVX_ARCH__ < 75 -#define HTP_OP_MAX_VMEM (3167538380u) -#else -#define HTP_OP_MAX_VMEM (3221225472u) -#endif +#define HTP_OP_MAX_VMEM_DEFAULT (3355443200u) #define HTP_MMAP_MAX_VMEM (2147483648u) diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index dbcafd1d8..d696a5fba 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -11,9 +11,9 @@ struct htp_iface_pmu_conf { }; interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); AEEResult stop(); - AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); + AEEResult mmap(in uint32 fd, in uint32 size); AEEResult munmap(in uint32 fd); AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); AEEResult etm(in uint32 enable); diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index f58347304..49c1a15b3 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -210,7 +210,7 @@ AEEResult htp_iface_close(remote_handle64 handle) { return AEE_SUCCESS; } -AEEResult htp_iface_mmap(remote_handle64 
handle, uint32 fd, uint32 size, uint32 pinned) {
+AEEResult htp_iface_mmap(remote_handle64 handle, uint32_t fd, uint32_t size) {
     struct htp_context * ctx = (struct htp_context *) handle;
     if (!ctx) {
         return AEE_EBADPARM;
     }
@@ -220,7 +220,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32
     for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
         struct htp_mmap *m = &ctx->mmap[i];
         if (m->fd == fd) {
-            m->pinned = pinned;
             return AEE_SUCCESS;
         }
     }
@@ -229,7 +228,7 @@
     for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) {
         struct htp_mmap *m = &ctx->mmap[i];
         if (!m->size) {
-            FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned);
+            FARF(HIGH, "mmap : fd %u size %u", fd, size);
 #if __HVX_ARCH__ > 73
             void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0);
 #else
@@ -248,7 +247,6 @@
             m->base = (uint64_t) va;
             m->fd   = fd;
             m->size = size;
-            m->pinned = pinned;
 
             return AEE_SUCCESS;
         }
@@ -275,7 +273,6 @@ AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) {
             m->size = 0;
             m->base = NULL;
             m->fd = -1;
-            m->pinned = 0;
         }
     }
@@ -358,7 +355,7 @@ static void vtcm_free(struct htp_context * ctx) {
 static void htp_packet_callback(dspqueue_t queue, int error, void * context);
 static void htp_error_callback(dspqueue_t queue, int error, void * context);
 
-AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) {
+AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
     struct htp_context * ctx = (struct htp_context *) handle;
 
     if (!ctx) {
@@ -376,12 +373,12 @@
                           htp_error_callback, // Error callback; no errors expected on the DSP
                           (void *) ctx,       // Callback context
                           &ctx->queue);
-
     if (err) {
         FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err);
         return err;
     }
 
+    ctx->max_vmem = max_vmem;
     ctx->thread_id = qurt_thread_get_id();
     ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);
@@ -622,8 +619,8 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct
 static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) {
-    if (m->size && !m->pinned) {
-        FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
+    if (m->size) {
+        FARF(HIGH, "unmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size);
 #if __HVX_ARCH__ > 73
         HAP_munmap2((void *) m->base, m->size);
 #else
@@ -660,9 +657,8 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) {
         m->base = b->base = (uint64_t) va;
         m->fd   = b->fd;
         m->size = b->size;
-        m->pinned = 0;
 
-        FARF(HIGH, "mmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned);
+        FARF(HIGH, "mmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size);
         return;
     }
 }
@@ -672,8 +668,8 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin
     uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array)
     uint32_t b_reuse = 0; // buf reuse count
 
-    size_t m_vmem = 0; // mapped vmem
-    size_t e_vmem = 0; // extra vmem
+    uint64_t m_vmem = 0; // mapped vmem
+    uint64_t e_vmem = 0; // extra vmem
 
     // See what we can reuse
     for (uint32_t i=0; i < n_bufs; i++) {
@@ -687,9 +683,10 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin
     // See how much vmem we have mmaped right now
     for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { m_vmem += ctx->mmap[i].size; }
 
-    FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu n-bufs %u b-reuse %u", m_vmem, e_vmem, n_bufs, b_reuse);
+    FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu max-vmem %zu : n-bufs %u b-reuse %u",
+         (size_t) m_vmem, (size_t) e_vmem, (size_t) ctx->max_vmem, n_bufs, b_reuse);
 
-    if ((m_vmem + e_vmem) > HTP_OP_MAX_VMEM) {
+    if ((m_vmem + e_vmem) > ctx->max_vmem) {
         // Drop unused mappings
         for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) {
             bool used = m_reuse & (1 << i);
 Date: Wed, 29 Apr 2026 14:10:58 -0500
Subject: [PATCH 03/11] common : do not pass prompt tokens to reasoning budget
 sampler (#22488)

---
 common/reasoning-budget.cpp | 28 -----------------
 common/reasoning-budget.h   | 17 ++--------
 common/sampling.cpp         | 63 +++++++++++++++++++++----------------
 common/sampling.h           |  4 +--
 4 files changed, 40 insertions(+), 72 deletions(-)

diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp
index 74fce5367..c6e1f86c9 100644
--- a/common/reasoning-budget.cpp
+++ b/common/reasoning-budget.cpp
@@ -232,34 +232,6 @@ static struct llama_sampler * common_reasoning_budget_init_state(
     );
 }
 
-struct llama_sampler * common_reasoning_budget_init(
-    const struct llama_vocab * vocab,
-    const std::vector<llama_token> & start_tokens,
-    const std::vector<llama_token> & end_tokens,
-    const std::vector<llama_token> & forced_tokens,
-    int32_t budget,
-    const std::vector<llama_token> & prefill_tokens) {
-    // Determine initial state from prefill: COUNTING if the prefill begins with
-    // the start sequence but does not also contain the end sequence after it.
-    common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE;
-    if (!prefill_tokens.empty() && !start_tokens.empty() &&
-        prefill_tokens.size() >= start_tokens.size() &&
-        std::equal(start_tokens.begin(), start_tokens.end(), prefill_tokens.begin())) {
-        initial_state = REASONING_BUDGET_COUNTING;
-        // If the end sequence also follows the start in the prefill, reasoning
-        // was opened and immediately closed — stay IDLE.
-        if (!end_tokens.empty() &&
-            prefill_tokens.size() >= start_tokens.size() + end_tokens.size()) {
-            auto end_start = prefill_tokens.end() - (ptrdiff_t) end_tokens.size();
-            if (end_start >= prefill_tokens.begin() + (ptrdiff_t) start_tokens.size() &&
-                std::equal(end_tokens.begin(), end_tokens.end(), end_start)) {
-                initial_state = REASONING_BUDGET_IDLE;
-            }
-        }
-    }
-    return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
-}
-
 struct llama_sampler * common_reasoning_budget_init(
     const struct llama_vocab * vocab,
     const std::vector<llama_token> & start_tokens,
diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h
index ee1a30ed3..ef37f46ee 100644
--- a/common/reasoning-budget.h
+++ b/common/reasoning-budget.h
@@ -29,10 +29,7 @@ enum common_reasoning_budget_state {
 // end_tokens    - token sequence for natural deactivation
 // forced_tokens - token sequence forced when budget expires
 // budget        - max tokens allowed in the reasoning block
-// prefill_tokens - tokens already present in the prompt (generation prompt);
-//                  used to determine the initial state: COUNTING if they begin
-//                  with start_tokens (but don't also end with end_tokens),
-//                  IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
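// editor's note: with this change the sampler no longer inspects the prompt
// itself; callers feed the generation-prompt tokens through
// llama_sampler_accept() after construction (see the sampling.cpp hunk below),
// and the accept path drives the IDLE -> COUNTING transition. Roughly:
//
//     auto * rb = common_reasoning_budget_init(vocab, start, end, forced, budget);
//     for (llama_token t : prefill_tokens) {
//         llama_sampler_accept(rb, t);   // walks the start/end sequences
//     }
//
// (sketch only; `prefill_tokens` stands for the tokenized generation prompt)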
+// initial_state - initial state of the sampler (COUNTING with budget <= 0 is promoted to FORCING)
+//
 struct llama_sampler * common_reasoning_budget_init(
     const struct llama_vocab * vocab,
@@ -40,16 +37,6 @@ struct llama_sampler * common_reasoning_budget_init(
     const std::vector<llama_token> & end_tokens,
     const std::vector<llama_token> & forced_tokens,
     int32_t budget,
-    const std::vector<llama_token> & prefill_tokens = {});
-
-// Variant that takes an explicit initial state (used by tests and clone).
-// COUNTING with budget <= 0 is promoted to FORCING.
-struct llama_sampler * common_reasoning_budget_init(
-    const struct llama_vocab * vocab,
-    const std::vector<llama_token> & start_tokens,
-    const std::vector<llama_token> & end_tokens,
-    const std::vector<llama_token> & forced_tokens,
-    int32_t budget,
-    common_reasoning_budget_state initial_state);
+    common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);
 
 common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
diff --git a/common/sampling.cpp b/common/sampling.cpp
index b2e6d8e8d..d4a2fdcda 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -260,32 +260,35 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
+    // Compute prefill tokens from the generation prompt
+    std::vector<llama_token> prefill_tokens;
+    if (!params.generation_prompt.empty()) {
+        GGML_ASSERT(vocab != nullptr);
+        auto tokens = common_tokenize(vocab, params.generation_prompt, false, true);
+        for (size_t i = 0; i < tokens.size(); i++) {
+            std::string piece = common_token_to_piece(vocab, tokens[i], true);
+            if (i == 0 && std::isspace(piece[0]) && !std::isspace(params.generation_prompt[0])) {
+                // Some tokenizers will add a space before the first special token, need to exclude
+                continue;
+            }
+            LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str());
+            prefill_tokens.push_back(tokens[i]);
+        }
+    }
+
     // Feed generation prompt tokens to the grammar sampler so it advances past
     // tokens the template already placed in the prompt.
     // Only applies to output-format and tool-call grammars; user-supplied grammars must not be prefilled.
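    // editor's note: a concrete example of why this prefill matters. If the chat
    // template ends the prompt with, say, "<|assistant|><think>", the reasoning
    // budget sampler must start out COUNTING and a lazily "<think>"-triggered
    // grammar must already be past its trigger; feeding those tokens through
    // llama_sampler_accept() reproduces the state the samplers would have reached
    // had they generated the tokens themselves. (The "<think>" marker is
    // illustrative; the actual sequences come from params.reasoning_budget_start.)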
-    std::vector<llama_token> prefill_tokens;
-    if (!params.generation_prompt.empty() && common_grammar_needs_prefill(params.grammar)) {
-        GGML_ASSERT(vocab != nullptr);
-        prefill_tokens = common_tokenize(vocab, params.generation_prompt, false, true);
-        if (!prefill_tokens.empty()) {
-            std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
-            if (std::isspace(first_token[0]) && !std::isspace(params.generation_prompt[0])) {
-                // Some tokenizers will add a space before the first special token, need to remove
-                prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
-            }
-        }
-
-        if (grmr && !params.grammar_lazy) {
-            try {
-                for (const auto & token : prefill_tokens) {
-                    llama_sampler_accept(grmr, token);
-                    LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
-                }
-            } catch (std::exception &e) {
-                LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
-                    common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
-                throw e;
+    if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) {
+        try {
+            for (const auto & token : prefill_tokens) {
+                llama_sampler_accept(grmr, token);
+                LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token);
             }
+        } catch (std::exception &e) {
+            LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
+                common_grammar_value(params.grammar).c_str(), params.generation_prompt.c_str());
+            throw e;
         }
     }
 
@@ -296,8 +299,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         params.reasoning_budget_start,
         params.reasoning_budget_end,
         params.reasoning_budget_forced,
-        params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
-        prefill_tokens);
+        params.reasoning_budget_tokens < 0 ?
INT_MAX : params.reasoning_budget_tokens); + + for (const auto & token : prefill_tokens) { + llama_sampler_accept(rbudget, token); + LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token); + } } if (params.has_logit_bias()) { @@ -431,7 +438,7 @@ static bool grammar_should_apply(struct common_sampler * gsmpl) { return true; } -void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { +void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated) { if (!gsmpl) { return; } @@ -439,9 +446,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo const auto tm = gsmpl->tm(); // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept - accept_grammar = accept_grammar && grammar_should_apply(gsmpl); + const auto accept_grammar = is_generated && grammar_should_apply(gsmpl); - llama_sampler_accept(gsmpl->rbudget, token); + if (gsmpl->rbudget && is_generated) { + llama_sampler_accept(gsmpl->rbudget, token); + } if (gsmpl->grmr && accept_grammar) { llama_sampler_accept(gsmpl->grmr, token); diff --git a/common/sampling.h b/common/sampling.h index 5b57ad658..49506a00c 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -41,8 +41,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st void common_sampler_free(struct common_sampler * gsmpl); -// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar -void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar); +// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler +void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool is_generated); void common_sampler_reset (struct common_sampler * gsmpl); struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); From b42c7fa5b835a87be53a89b062571070f86b86fe Mon Sep 17 00:00:00 2001 From: Peter Sideris Date: Thu, 30 Apr 2026 08:18:25 +0300 Subject: [PATCH 04/11] spec : fix vocab compat checks in spec example (#22426) * port #22358 PR to examples/speculative/speculative.cpp * use vocab_[tgt,dft] instead of ctx_[tgt,dft] when logging on draft model / target model vocabulary mismatch Co-authored-by: Petros Sideris --- examples/speculative/speculative.cpp | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 6ed9c9143..f7fa5e306 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -110,13 +110,21 @@ int main(int argc, char ** argv) { return 1; } - if ( - llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || - llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || - llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || - llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft) - ) { - LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); + if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + (llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) { + LOG_ERR("%s: draft model bos tokens must match target model to use speculation. 
add: %d - %d, id: %d - %d)\n", + __func__, + llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft), + llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft)); + return 1; + } + + if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || + (llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) { + LOG_ERR("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n", + __func__, + llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft), + llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft)); return 1; } @@ -137,11 +145,12 @@ int main(int argc, char ** argv) { for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, - common_token_to_piece(ctx_tgt, i).c_str(), - common_token_to_piece(ctx_dft, i).c_str()); + common_token_to_piece(vocab_tgt, i).c_str(), + common_token_to_piece(vocab_dft, i).c_str()); return 1; } } From 80afa33aadcc4f71212b17e5e52904491c76b63e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Apr 2026 08:32:18 +0300 Subject: [PATCH 05/11] spec : fix draft model checkpoints (#22521) * spec : fix draft model checkpoints * cont : clean-up * cont : gate the ngram-mod reset warning behind verbose flag --- common/speculative.cpp | 99 +++++++++++++++------------------ tools/server/server-context.cpp | 22 ++++++-- 2 files changed, 62 insertions(+), 59 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index bda9993b1..bbf88fa6e 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -167,8 +167,6 @@ struct common_speculative_checkpoint { size_t size() const { return data.size(); } - - size_t ckpt_size = 0; }; struct common_speculative_state_draft : public common_speculative_state { @@ -176,7 +174,7 @@ struct common_speculative_state_draft : public common_speculative_state { llama_context * ctx_dft; bool use_ckpt = false; - struct common_speculative_checkpoint ckpt; + common_speculative_checkpoint ckpt; common_sampler * smpl; @@ -249,26 +247,16 @@ struct common_speculative_state_draft : public common_speculative_state { llama_batch_free(batch); } - void begin(const llama_tokens & prompt) override { - if (use_ckpt && ckpt.size() > 0) { - // delete checkpoint - LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n", - __func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024); - ckpt.pos_min = 0; - ckpt.pos_max = 0; - ckpt.n_tokens = 0; - ckpt.ckpt_size = 0; - ckpt.data.clear(); - } + void begin(const llama_tokens & /*prompt*/) override { } - size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) { + size_t create_checkpoint(int n_tokens_prompt) { int slot_id = 0; const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id); ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id); - ckpt.n_tokens = n_tokens_prompt - n_tokens_batch; + ckpt.n_tokens = 
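                          /* editor's note: the checkpoint now records the full
                             number of prompt tokens held by the draft context
                             at snapshot time (previously n_tokens_prompt -
                             n_tokens_batch); the draft-generation path later
                             compares this against reuse_n and drops the
                             checkpoint as outdated when ckpt.n_tokens > reuse_n
                             via `ckpt = {};`. */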
n_tokens_prompt; ckpt.data.resize(checkpoint_size); const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); @@ -281,13 +269,13 @@ struct common_speculative_state_draft : public common_speculative_state { return n; } - size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) { + size_t restore_checkpoint() { int slot_id = 0; LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max); const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != ckpt_size_part_expected) { - GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu", - __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n); + if (n != ckpt.size()) { + GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu", + __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size()); } llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1); @@ -346,13 +334,18 @@ struct common_speculative_state_draft : public common_speculative_state { const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); + if (use_ckpt && i_start > 0) { + LOG_WRN("%s: context shift is not supported with checkpoint-based contexts - skipping\n", __func__); + return; + } + // reuse as much as possible from the old draft context // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt for (int i = 0; i < (int) prompt_dft.size(); ++i) { int cur = 0; while (i_start + cur < (int) prompt_cur.size() && - i + cur < (int) prompt_dft.size() && - prompt_cur[i_start + cur] == prompt_dft[i + cur]) { + i + cur < (int) prompt_dft.size() && + prompt_cur[i_start + cur] == prompt_dft[i + cur]) { cur++; } @@ -360,21 +353,26 @@ struct common_speculative_state_draft : public common_speculative_state { reuse_i = i; reuse_n = cur; } + + if (use_ckpt) { + break; + } } LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n", __func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size()); - if (use_ckpt && ckpt.ckpt_size == 0 && reuse_n > 0) { - LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n", - __func__, reuse_i, reuse_n); + if (use_ckpt && ckpt.n_tokens > reuse_n) { + LOG_DBG("%s: checkpoint (n_tokens = %d) is outdated -> delete it\n", __func__, (int) ckpt.n_tokens); + reuse_i = 0; reuse_n = 0; + + ckpt = {}; } result.clear(); result.reserve(sparams.n_max); - bool needs_ckpt = use_ckpt && prompt_dft.size() > 0; if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) { llama_memory_clear(mem_dft, false); prompt_dft.clear(); @@ -393,50 +391,38 @@ struct common_speculative_state_draft : public common_speculative_state { return; } - bool do_restore = false; - if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) { - // This can happen after a partial acceptance (speculative decoding with checkpoints) - LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n", - __func__, prompt_dft.size(), prompt_cur.size()); - prompt_dft.resize(prompt_cur.size()); - do_restore = true; - } - if (reuse_i > 0) { + GGML_ASSERT(!use_ckpt); + bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); if (!is_removed) { LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i); + return; } 
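            // editor's note: seq_rm above dropped the first reuse_i cells; the
            // seq_add below then shifts the positions of the surviving range
            // [reuse_i, -1] by -reuse_i so the reused cache suffix lines up
            // with position 0 of the new draft prompt.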
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); } - if (reuse_n < (int) prompt_dft.size() || do_restore) { + if (reuse_n < (int) prompt_dft.size()) { if (use_ckpt) { - if (ckpt.n_tokens > (int64_t) prompt_dft.size()) { - LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n", - __func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size()); + if (ckpt.n_tokens > 0) { + LOG_DBG("%s: restoring checkpoint, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); + restore_checkpoint(); + reuse_n = ckpt.n_tokens; + prompt_dft.resize(reuse_n); } - draft_restore_checkpoint(ckpt.ckpt_size); - reuse_n = ckpt.n_tokens; - prompt_dft.resize(reuse_n); - needs_ckpt = false; } else { - bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + const bool is_removed = llama_memory_seq_rm(mem_dft, 0, reuse_n, -1); if (!is_removed) { - LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", - __func__, reuse_n, prompt_dft.size()); + LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n", __func__, reuse_n, prompt_dft.size()); + return; } prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); } } } - if (needs_ckpt) { - ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens); - } - // prepare a batch to evaluate any new tokens in the prompt common_batch_clear(batch); @@ -450,12 +436,17 @@ struct common_speculative_state_draft : public common_speculative_state { // we should rarely end-up here during normal decoding if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + LOG_DBG("%s: draft prompt batch: %d tokens\n", __func__, batch.n_tokens); int ret = llama_decode(ctx_dft, batch); if (ret != 0 && ret != 1) { LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n", __func__, ret, prompt_cur.size()); } + + if (use_ckpt) { + create_checkpoint(prompt_dft.size()); + } } const llama_pos n_past = prompt_dft.size(); @@ -784,17 +775,15 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { } void accept(uint16_t n_accepted) override { - if (verbose) { - LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); - } - // compute acceptance fraction if we have a recorded draft length if (n_draft_last > 0) { const double f_acc = (double)n_accepted / (double)n_draft_last; if (f_acc < 0.5) { n_low++; if (n_low >= 3) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + if (verbose) { + LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, n_low); + } mod.reset(); n_low = 0; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ee8366d28..2d3003f03 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -680,6 +680,7 @@ private: // slots / clients std::vector slots; + int trace = 0; int slots_debug = 0; int n_empty_consecutive = 0; @@ -918,12 +919,21 @@ private: slot.reset(); } + { + const char * LLAMA_TRACE = getenv("LLAMA_TRACE"); + trace = LLAMA_TRACE ? atoi(LLAMA_TRACE) : 0; + + if (trace) { + SRV_WRN("LLAMA_TRACE = %d\n", trace); + } + } + { const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? 
atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; if (slots_debug) { - SRV_WRN("slots debug = %d\n", slots_debug); + SRV_WRN("LLAMA_SERVER_SLOTS_DEBUG = %d\n", slots_debug); } } @@ -2974,13 +2984,15 @@ private: auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); slot.spec_i_batch.clear(); - SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size()); - GGML_ASSERT(accepted.size() >= 1); // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { if (use_ckpt) { + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); + } + // partial acceptance is not supported by the context -> truncate the draft and restore the state slot.spec_draft = std::move(accepted); @@ -3002,8 +3014,10 @@ private: continue; } + } - LOG_DBG("%s: partial acceptance: %zu < %zu\n", __func__, accepted.size(), slot.spec_draft.size()); + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); } common_speculative_accept(slot.spec.get(), accepted.size() - 1); From 45155597aa23243c5f6d10064bd9bca3eaddee16 Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Wed, 29 Apr 2026 22:58:32 -0700 Subject: [PATCH 06/11] add fast matmul iquants (#22504) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 19 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../wgsl-shaders/mul_mat_decls.tmpl | 423 ++++++++++++++++++ 3 files changed, 443 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index b7771ac23..5239164cd 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1806,6 +1806,25 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + variant += std::string("_") + src0_name; break; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f7fd73ae1..5e55a2a1e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1422,7 +1422,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: - use_fast = is_vec; + use_fast = true; break; default: break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 15b22c4f7..51cf08f19 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -740,3 +740,426 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_Q6_K + +#ifdef INIT_SRC0_SHMEM_IQ4_NL +const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += 
TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + + let pos = k_in_block % 16u; + let nib_shift = (k_in_block / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u); + let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_NL + +#ifdef INIT_SRC0_SHMEM_IQ4_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 136u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let d_scales_h = load_u32_at_src0(block_byte_base); + let d = bitcast>(d_scales_h).x; + let scales_h = d_scales_h >> 16u; + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu; + let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u; + let dl = d * f16(i32(ls_lo | ls_hi) - 32); + + let iqs = ib * 16u + (pos % 16u); + let nib_shift = (pos / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u); + let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_XS + +#ifdef INIT_SRC0_SHMEM_IQ1_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 50u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu; + let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + + let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) 
% 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_S + +#ifdef INIT_SRC0_SHMEM_IQ1_M +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 56u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let scales0 = load_u32_at_src0(block_byte_base + 48u); + let scales1 = load_u32_at_src0(block_byte_base + 52u); + let scale_packed = ((scales0 >> 12u) & 0xFu) | + ((scales0 >> 24u) & 0x00F0u) | + ((scales1 >> 4u) & 0x0F00u) | + ((scales1 >> 16u) & 0xF000u); + let d = f32(bitcast>(scale_packed).x); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let scales = select(scales0, scales1, ib >= 4u); + let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu; + let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u; + let dl = d * f32(2u * s_pair + 1u); + + let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u); + let qh = qh_word >> (16u * (ib % 2u)); + let qh_nib = (qh >> (4u * l)) & 0xFu; + + let qs_w = load_u32_at_src0(block_byte_base + ib * 4u); + let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u); + + let ig = idx * 8u; + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_M + +#ifdef INIT_SRC0_SHMEM_IQ2_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 66u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u); + let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u); + let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25; + + let ig = get_byte(aux0, l) * 8u; + let is = (aux1 >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XXS + +#ifdef INIT_SRC0_SHMEM_IQ2_XS +const 
BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 74u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u); + let s = get_byte(scales_word, (ib % 16u) / 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u); + let qs_val = qs_word & 0xFFFFu; + let ig = (qs_val & 511u) * 8u; + let is = qs_val >> 9u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XS + +#ifdef INIT_SRC0_SHMEM_IQ2_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 82u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let l = (k_in_block % 32u) / 8u; + let j = k_in_block % 8u; + + let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u); + let s = get_byte(scales_word, ib % 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u); + let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u; + let ig = (get_byte(qs_word, l) | qh_b) * 8u; + + let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_S + +#ifdef INIT_SRC0_SHMEM_IQ3_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 98u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / 
TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib_pair = k_in_block / 32u; + let in_pair = k_in_block % 32u; + let l = in_pair / 8u; + let in_l = in_pair % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let ib = ib_pair * 2u; + let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u; + let sc_sign = load_u32_at_src0(sc_sign_off); + let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5; + let is = (sc_sign >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu; + let ig_byte = get_byte(ig_word, k2); + let g = get_byte(iq3xxs_grid[ig_byte], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_XXS + +#ifdef INIT_SRC0_SHMEM_IQ3_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 110u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 64u; + let rest = k_in_block % 64u; + let k = rest / 32u; + let in_k = rest % 32u; + let l = in_k / 8u; + let in_l = in_k % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let scales_word = load_u32_at_src0(block_byte_base + 106u); + let s = get_byte(scales_word, ib); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u); + let dl = d * (1.0 + 2.0 * f32(s_nib)); + + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u); + let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu; + let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u); + let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u); + let ig = select(ig_lo, ig_hi, k2 != 0u); + + let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq3s_grid[ig], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_S From 27aef3dd91e7cde049e7c242dbf6c8fe86574d01 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Apr 2026 09:20:26 +0300 Subject: [PATCH 07/11] scripts : add wc2wt.sh - create worktree from current HEAD (#22513) * scripts : add wc2wt.sh - create worktree from current HEAD Add a script to create a git worktree on a new branch from the current HEAD. 
Similar to pr2wt.sh but for local development branches instead of PRs. Usage: ./scripts/wc2wt.sh gg/new-feature ./scripts/wc2wt.sh gg/new-feature "bash -l" Assisted-by: llama.cpp:local pi * cont : no need to try to delete the branch --- scripts/wc2wt.sh | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100755 scripts/wc2wt.sh diff --git a/scripts/wc2wt.sh b/scripts/wc2wt.sh new file mode 100755 index 000000000..157881b45 --- /dev/null +++ b/scripts/wc2wt.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash + +# initialize a new worktree from a branch name: +# +# - creates a new branch from current HEAD +# - creates a new worktree in a parent folder, suffixed with the branch name +# +# sample usage: +# ./scripts/wc2wt.sh gg/new-feature-foo-bar +# ./scripts/wc2wt.sh gg/new-feature-foo-bar opencode +# ./scripts/wc2wt.sh gg/new-feature-foo-bar "cmake -B build && cmake --build build" +# ./scripts/wc2wt.sh gg/new-feature-foo-bar "bash -l" + +function usage() { + echo "usage: $0 [cmd]" + exit 1 +} + +# check we are in the right directory +if [[ ! -f "scripts/wc2wt.sh" ]]; then + echo "error: this script must be run from the root of the repository" + exit 1 +fi + +if [[ $# -lt 1 || $# -gt 2 ]]; then + usage +fi + +BRANCH=$1 + +if [[ -z "$BRANCH" ]]; then + echo "error: branch name must not be empty" + exit 1 +fi + +dir=$(basename $(pwd)) +# sanitize branch name for directory name (replace / with -) +dir_suffix=$(echo "$BRANCH" | tr '/' '-') + +git worktree add -b "$BRANCH" "../$dir-$dir_suffix" HEAD + +og_path=$(pwd) +wt_path=$(cd "../$dir-$dir_suffix" && pwd) + +echo "git worktree created in $wt_path" + +cd "$wt_path" + +# pi agent setup in the worktree +if [[ -f "$og_path/.pi/SYSTEM.md" && ! -f ".pi/SYSTEM.md" ]]; then + mkdir -p .pi + ln -sfn "$og_path/.pi/SYSTEM.md" .pi/SYSTEM.md +fi + +if [[ $# -eq 2 ]]; then + echo "executing: $2" + eval "$2" +fi From e82aaf258786bc9a1d018c082697f1a15007f23f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 30 Apr 2026 13:04:50 +0200 Subject: [PATCH 08/11] CUDA: fix tile FA kernel on Pascal (#22541) --- ggml/src/ggml-cuda/fattn-tile.cuh | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 928b856f9..585f2c228 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) @@ -130,7 +130,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) @@ -1124,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr size_t nbytes_shared = 0; #ifdef 
GGML_USE_HIP - if constexpr (DV <= 128) { + if constexpr (DKQ <= 128) { if (Q->ne[1] > 32/ncols2) { constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; @@ -1138,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm #endif // GGML_USE_HIP #ifndef GGML_USE_HIP - if constexpr (DV <= 256) + if constexpr (DKQ <= 256) #endif // GGML_USE_HIP { if (Q->ne[1] > 16/ncols2) { @@ -1220,11 +1220,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DKQ == 320) { // Mistral Small 4 + if constexpr (DKQ == 320) { + // This branch is only used for Mistral Small 4 which has a GQA ratio of 32. + // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM. + // On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older). + // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block. +#ifdef GGML_USE_HIP if (use_gqa_opt && gqa_ratio % 32 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; } +#else + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } +#endif // GGML_USE_HIP GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); } From 5f0ab726f798daa5bd6da7404df2deb247017a9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Thu, 30 Apr 2026 15:04:39 +0200 Subject: [PATCH 09/11] vendor : update cpp-httplib to 0.43.2 (#22548) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- scripts/sync_vendor.py | 2 +- vendor/cpp-httplib/httplib.cpp | 68 +++++++++++++++++----------------- vendor/cpp-httplib/httplib.h | 4 +- 3 files changed, 36 insertions(+), 38 deletions(-) diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index ff1dd0753..d2a8a50b5 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import os import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.43.1" +HTTPLIB_VERSION = "refs/tags/v0.43.2" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 95bf0eb1b..66c0b6ebd 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1464,8 +1464,9 @@ bool mmap::open(const char *path) { auto wpath = u8string_to_wstring(path); if (wpath.empty()) { return false; } - hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, - OPEN_EXISTING, NULL); + hFile_ = + ::CreateFile2(wpath.c_str(), GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE, OPEN_EXISTING, NULL); if (hFile_ == INVALID_HANDLE_VALUE) { return false; } @@ -2052,56 +2053,50 @@ int getaddrinfo_with_timeout(const char *node, const char *service, return 0; #elif defined(_GNU_SOURCE) && defined(__GLIBC__) && \ (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 2)) - // Linux implementation using getaddrinfo_a for asynchronous DNS resolution - struct gaicb request; + // #2431: gai_cancel() is non-blocking and may return EAI_NOTCANCELED while + // the resolver worker still references the stack-local gaicb. 
The cancel + // path therefore waits (gai_suspend with no timeout) for the worker to + // actually finish before letting the stack frame go. The trade-off is that + // a wedged DNS server can hold this thread for the system resolver timeout + // (~30s by default) past the caller's connection timeout. + struct gaicb request {}; struct gaicb *requests[1] = {&request}; - struct sigevent sevp; - struct timespec timeout; + struct sigevent sevp {}; + struct timespec timeout { + timeout_sec, 0 + }; - // Initialize the request structure - memset(&request, 0, sizeof(request)); request.ar_name = node; request.ar_service = service; request.ar_request = hints; - - // Set up timeout - timeout.tv_sec = timeout_sec; - timeout.tv_nsec = 0; - - // Initialize sigevent structure (not used, but required) - memset(&sevp, 0, sizeof(sevp)); sevp.sigev_notify = SIGEV_NONE; - // Start asynchronous resolution - int start_result = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp); - if (start_result != 0) { return start_result; } + int rc = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp); + if (rc != 0) { return rc; } - // Wait for completion with timeout - int wait_result = - gai_suspend((const struct gaicb *const *)requests, 1, &timeout); + auto cleanup = scope_exit([&] { + if (request.ar_result) { freeaddrinfo(request.ar_result); } + }); + + int wait_result = gai_suspend(requests, 1, &timeout); if (wait_result == 0 || wait_result == EAI_ALLDONE) { - // Completed successfully, get the result int gai_result = gai_error(&request); if (gai_result == 0) { *res = request.ar_result; + request.ar_result = nullptr; return 0; - } else { - // Clean up on error - if (request.ar_result) { freeaddrinfo(request.ar_result); } - return gai_result; } - } else if (wait_result == EAI_AGAIN) { - // Timeout occurred, cancel the request - gai_cancel(&request); - return EAI_AGAIN; - } else { - // Other error occurred - gai_cancel(&request); - return wait_result; + return gai_result; } + + gai_cancel(&request); + while (gai_error(&request) == EAI_INPROGRESS) { + gai_suspend(requests, 1, nullptr); + } + return wait_result; #else - // Fallback implementation using thread-based timeout for other Unix systems + // Fallback implementation using thread-based timeout for other Unix systems. struct GetAddrInfoState { ~GetAddrInfoState() { @@ -14142,6 +14137,9 @@ ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { err.code = impl::map_mbedtls_error(ret, err.sys_errno); err.backend_code = static_cast(-ret); impl::mbedtls_last_error() = ret; + // mbedTLS signals a clean close_notify via a negative error code rather + // than 0; surface it as a clean EOF the way OpenSSL/wolfSSL do. 
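  // The check below turns that case into a 0 return, so callers can treat the
  // TLS session like a plain socket: 0 is an orderly shutdown, -1 plus err is a
  // real failure. A minimal read-loop sketch, assuming only the read() signature
  // declared in this file and that TlsError is default-constructible:
  //
  //   char buf[4096];
  //   TlsError err;
  //   for (;;) {
  //       const ssize_t n = read(session, buf, sizeof(buf), err);
  //       if (n == 0) { break; }         // peer sent close_notify: clean EOF
  //       if (n < 0)  { /* inspect err.code / err.backend_code */ break; }
  //       // ... consume n bytes ...
  //   }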
+ if (err.code == ErrorCode::PeerClosed) { return 0; } return -1; } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 8581d1695..7e530961b 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.43.1" -#define CPPHTTPLIB_VERSION_NUM "0x002b01" +#define CPPHTTPLIB_VERSION "0.43.2" +#define CPPHTTPLIB_VERSION_NUM "0x002b02" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 From 6118c043b186cc3727b6dcb91daa897b3254c457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Thu, 30 Apr 2026 15:15:54 +0200 Subject: [PATCH 10/11] ci : bump ty to 0.0.33 (#22535) * bump ty to 0.0.33 * update typings --- .github/workflows/python-type-check.yml | 2 +- convert_hf_to_gguf.py | 2 +- scripts/jinja/jinja-tester.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml index dc7aebe24..2d3fa163d 100644 --- a/.github/workflows/python-type-check.yml +++ b/.github/workflows/python-type-check.yml @@ -31,7 +31,7 @@ jobs: uses: actions/setup-python@v6 with: python-version: "3.11" - pip-install: -r requirements/requirements-all.txt ty==0.0.26 + pip-install: -r requirements/requirements-all.txt ty==0.0.33 # - name: Type-check with Pyright # uses: jakebailey/pyright-action@v2 # with: diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 90c2b7094..5287c4df9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6658,7 +6658,7 @@ class BertModel(TextModel): tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size # ty: ignore[invalid-assignment] + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size if isinstance(tokenizer, SentencePieceProcessor): for token_id in range(tokenizer.vocab_size()): diff --git a/scripts/jinja/jinja-tester.py b/scripts/jinja/jinja-tester.py index 4f79b8da3..a83f02541 100755 --- a/scripts/jinja/jinja-tester.py +++ b/scripts/jinja/jinja-tester.py @@ -20,6 +20,7 @@ from PySide6.QtCore import Qt, QRect, QSize from jinja2 import TemplateSyntaxError from jinja2.sandbox import ImmutableSandboxedEnvironment from datetime import datetime +from typing import Callable def format_template_content(template_content): @@ -395,7 +396,7 @@ class JinjaTester(QMainWindow): ensure_ascii=ensure_ascii, ) ) - env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) # ty: ignore[invalid-assignment] + env.globals["strftime_now"]: Callable[[str], str] = lambda format: datetime.now().strftime(format) env.globals["raise_exception"] = raise_exception # ty: ignore[invalid-assignment] try: template = env.from_string(template_str) From c20c44514a06a3fd70e3d2e8b830812047360e5b Mon Sep 17 00:00:00 2001 From: Ben Guidarelli Date: Thu, 30 Apr 2026 10:32:32 -0400 Subject: [PATCH 11/11] spec: fix argument typo (#22552) --- common/arg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 943d0766f..c21598e76 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3499,7 +3499,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, 
LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_N_MIN")); add_opt(common_arg( - {"--spec--draft-p-split", "--draft-p-split"}, "P", + {"--spec-draft-p-split", "--draft-p-split"}, "P", string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.draft.p_split), [](common_params & params, const std::string & value) { params.speculative.draft.p_split = std::stof(value);
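The corrected spelling only changes the flag's advertised primary name; "--draft-p-split" kept working all along because every spelling passed to common_arg is registered against the same handler, which is why the typo could go unnoticed. A minimal sketch of that alias-indexing scheme, with a hypothetical arg_def type standing in for common_arg (illustration only, not llama.cpp's actual parser):

    #include <functional>
    #include <map>
    #include <string>
    #include <vector>

    struct arg_def {
        std::vector<std::string> names;                    // primary spelling first, then aliases
        std::function<void(const std::string &)> on_value; // invoked with the flag's argument
    };

    // Map every spelling onto the one shared definition, so "--draft-p-split"
    // and "--spec-draft-p-split" reach the same handler.
    // Note: defs must outlive the returned index, which stores raw pointers.
    static std::map<std::string, const arg_def *> build_index(const std::vector<arg_def> & defs) {
        std::map<std::string, const arg_def *> idx;
        for (const auto & def : defs) {
            for (const auto & name : def.names) {
                idx[name] = &def;
            }
        }
        return idx;
    }

Under such a scheme the old "--spec--draft-p-split" spelling still parsed, it was just the wrong name to advertise in --help, which matches the one-line nature of the fix.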