diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp index c6e1f86c9..958c9cacf 100644 --- a/common/reasoning-budget.cpp +++ b/common/reasoning-budget.cpp @@ -171,22 +171,12 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) { ctx->force_pos = 0; } -// forward declaration for use in clone static struct llama_sampler * common_reasoning_budget_init_state( const struct llama_vocab * vocab, const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, int32_t budget, common_reasoning_budget_state initial_state); -static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { - const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; - return common_reasoning_budget_init_state( - ctx->vocab, - ctx->start_matcher.tokens, - ctx->end_matcher.tokens, - ctx->forced_tokens, - ctx->budget, - ctx->state); -} +static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl); static void common_reasoning_budget_free(struct llama_sampler * smpl) { delete (common_reasoning_budget_ctx *) smpl->ctx; @@ -205,6 +195,15 @@ static struct llama_sampler_i common_reasoning_budget_i = { /* .backend_set_input = */ nullptr, }; +static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; + + return llama_sampler_init( + /* .iface = */ &common_reasoning_budget_i, + /* .ctx = */ new common_reasoning_budget_ctx(*ctx) + ); +} + static struct llama_sampler * common_reasoning_budget_init_state( const struct llama_vocab * vocab, const std::vector & start_tokens, diff --git a/tests/test-reasoning-budget.cpp b/tests/test-reasoning-budget.cpp index f7a601789..10d0ac21d 100644 --- a/tests/test-reasoning-budget.cpp +++ b/tests/test-reasoning-budget.cpp @@ -124,6 +124,66 @@ static void test_reasoning_budget( (void)sequence; } +static llama_token get_forced_token(struct llama_sampler * sampler, llama_token max_token) { + std::vector cur; + const size_t n_vocab = (size_t) max_token + 1; + for (size_t i = 0; i < n_vocab; i++) { + cur.emplace_back(llama_token_data{(llama_token) i, logf((float) (i + 1)), 0.0f}); + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + llama_sampler_apply(sampler, &cur_p); + + size_t finite_count = 0; + llama_token finite_token = LLAMA_TOKEN_NULL; + for (size_t i = 0; i < cur.size(); i++) { + if (std::isfinite(cur[i].logit)) { + finite_count++; + finite_token = cur[i].id; + } + } + + GGML_ASSERT(finite_count == 1 && "sampler is not forcing exactly one token"); + return finite_token; +} + +static void test_reasoning_budget_clone_mid_counting() { + const std::vector start = {100}; + const std::vector end = {101}; + const std::vector forced = {102, 101}; + + auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 2, REASONING_BUDGET_IDLE); + + llama_sampler_accept(sampler, 100); // COUNTING, remaining=2 + llama_sampler_accept(sampler, 50); // COUNTING, remaining=1 + + auto * clone = llama_sampler_clone(sampler); + llama_sampler_accept(clone, 51); // should exhaust the cloned remaining budget + + GGML_ASSERT(get_forced_token(clone, 102) == 102 && "cloned counting state lost remaining budget"); + + llama_sampler_free(clone); + llama_sampler_free(sampler); +} + +static void test_reasoning_budget_clone_mid_forcing() { + const std::vector start = {100}; + const std::vector end = {101}; + const std::vector forced = {102, 101}; + + auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING); + + GGML_ASSERT(get_forced_token(sampler, 102) == 102); + llama_sampler_accept(sampler, 102); // advance to the second forced token + + auto * clone = llama_sampler_clone(sampler); + + GGML_ASSERT(get_forced_token(clone, 102) == 101 && "cloned forcing state lost force position"); + + llama_sampler_free(clone); + llama_sampler_free(sampler); +} + // UTF-8 boundary detection unit test // Tests common_utf8_is_complete() from reasoning-budget.h static void test_utf8_boundary_detection() { @@ -250,7 +310,10 @@ int main(void) { 7); // forcing continues through i=7 } - printf("OK (6 tests passed)\n"); + test_reasoning_budget_clone_mid_counting(); + test_reasoning_budget_clone_mid_forcing(); + + printf("OK (8 tests passed)\n"); printf("Testing UTF-8 boundary detection... "); test_utf8_boundary_detection();