reasoning-budget: clone should do a deep-copy (#23095)

This commit is contained in:
Aman Gupta 2026-05-15 17:59:07 +08:00 committed by GitHub
parent d528444580
commit ac33f032ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 74 additions and 12 deletions

View file

@ -171,22 +171,12 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
ctx->force_pos = 0;
}
// forward declaration for use in clone
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, common_reasoning_budget_state initial_state);
// Clone a reasoning-budget sampler by building a new one through
// common_reasoning_budget_init_state(), forwarding the source context's vocab,
// start/end matcher token sequences, forced tokens, budget and current state.
// NOTE(review): only these six values are forwarded to the new sampler; any other
// fields held in common_reasoning_budget_ctx are not passed along by this call and
// are presumably re-initialized by init_state — confirm whether that is intended.
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
// the sampler's opaque ctx pointer holds this sampler's private context
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init_state(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
ctx->forced_tokens,
ctx->budget,
ctx->state);
}
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl);
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
delete (common_reasoning_budget_ctx *) smpl->ctx;
@ -205,6 +195,15 @@ static struct llama_sampler_i common_reasoning_budget_i = {
/* .backend_set_input = */ nullptr,
};
// Clone a reasoning-budget sampler: the new sampler shares the same interface table
// and owns a copy-constructed duplicate of the source context.
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
    const auto * src_ctx = (const common_reasoning_budget_ctx *) smpl->ctx;

    // copy-construct the whole context so every field of the source is carried over
    auto * cloned_ctx = new common_reasoning_budget_ctx(*src_ctx);

    return llama_sampler_init(&common_reasoning_budget_i, cloned_ctx);
}
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,

View file

@ -124,6 +124,66 @@ static void test_reasoning_budget(
(void)sequence;
}
// Test helper: applies `sampler` to a synthetic candidate list containing token ids
// 0..max_token (inclusive) and returns the one token whose logit is still finite
// after the sampler has run.
//
//   sampler   - sampler under test; it is applied once to the candidate list
//   max_token - highest token id to include; must be non-negative
//
// Aborts via GGML_ASSERT unless exactly one candidate keeps a finite logit.
static llama_token get_forced_token(struct llama_sampler * sampler, llama_token max_token) {
    // a negative id would wrap around to an enormous candidate count below
    GGML_ASSERT(max_token >= 0 && "max_token must be non-negative");

    const size_t n_vocab = (size_t) max_token + 1;

    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab); // size is known up front: avoid repeated reallocation
    for (size_t i = 0; i < n_vocab; i++) {
        // logf(i + 1) starts every candidate off with a finite logit
        cur.push_back(llama_token_data{(llama_token) i, logf((float) (i + 1)), 0.0f});
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
    llama_sampler_apply(sampler, &cur_p);

    // count the candidates that still have a finite logit, remembering the last one seen
    size_t      finite_count = 0;
    llama_token finite_token = LLAMA_TOKEN_NULL;
    for (const llama_token_data & cand : cur) {
        if (std::isfinite(cand.logit)) {
            finite_count++;
            finite_token = cand.id;
        }
    }

    GGML_ASSERT(finite_count == 1 && "sampler is not forcing exactly one token");
    return finite_token;
}
static void test_reasoning_budget_clone_mid_counting() {
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 2, REASONING_BUDGET_IDLE);
llama_sampler_accept(sampler, 100); // COUNTING, remaining=2
llama_sampler_accept(sampler, 50); // COUNTING, remaining=1
auto * clone = llama_sampler_clone(sampler);
llama_sampler_accept(clone, 51); // should exhaust the cloned remaining budget
GGML_ASSERT(get_forced_token(clone, 102) == 102 && "cloned counting state lost remaining budget");
llama_sampler_free(clone);
llama_sampler_free(sampler);
}
static void test_reasoning_budget_clone_mid_forcing() {
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING);
GGML_ASSERT(get_forced_token(sampler, 102) == 102);
llama_sampler_accept(sampler, 102); // advance to the second forced token
auto * clone = llama_sampler_clone(sampler);
GGML_ASSERT(get_forced_token(clone, 102) == 101 && "cloned forcing state lost force position");
llama_sampler_free(clone);
llama_sampler_free(sampler);
}
// UTF-8 boundary detection unit test
// Tests common_utf8_is_complete() from reasoning-budget.h
static void test_utf8_boundary_detection() {
@ -250,7 +310,10 @@ int main(void) {
7); // forcing continues through i=7
}
printf("OK (6 tests passed)\n");
test_reasoning_budget_clone_mid_counting();
test_reasoning_budget_clone_mid_forcing();
printf("OK (8 tests passed)\n");
printf("Testing UTF-8 boundary detection... ");
test_utf8_boundary_detection();