mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-16 19:59:16 +00:00
reasoning-budget: clone should do a deep-copy (#23095)
This commit is contained in:
parent
d528444580
commit
ac33f032ac
2 changed files with 74 additions and 12 deletions
|
|
@ -171,22 +171,12 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
|||
ctx->force_pos = 0;
|
||||
}
|
||||
|
||||
// forward declaration for use in clone
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget, common_reasoning_budget_state initial_state);
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
return common_reasoning_budget_init_state(
|
||||
ctx->vocab,
|
||||
ctx->start_matcher.tokens,
|
||||
ctx->end_matcher.tokens,
|
||||
ctx->forced_tokens,
|
||||
ctx->budget,
|
||||
ctx->state);
|
||||
}
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl);
|
||||
|
||||
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
|
||||
delete (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
|
@ -205,6 +195,15 @@ static struct llama_sampler_i common_reasoning_budget_i = {
|
|||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &common_reasoning_budget_i,
|
||||
/* .ctx = */ new common_reasoning_budget_ctx(*ctx)
|
||||
);
|
||||
}
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
|
|
|
|||
|
|
@ -124,6 +124,66 @@ static void test_reasoning_budget(
|
|||
(void)sequence;
|
||||
}
|
||||
|
||||
static llama_token get_forced_token(struct llama_sampler * sampler, llama_token max_token) {
|
||||
std::vector<llama_token_data> cur;
|
||||
const size_t n_vocab = (size_t) max_token + 1;
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
cur.emplace_back(llama_token_data{(llama_token) i, logf((float) (i + 1)), 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
|
||||
size_t finite_count = 0;
|
||||
llama_token finite_token = LLAMA_TOKEN_NULL;
|
||||
for (size_t i = 0; i < cur.size(); i++) {
|
||||
if (std::isfinite(cur[i].logit)) {
|
||||
finite_count++;
|
||||
finite_token = cur[i].id;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(finite_count == 1 && "sampler is not forcing exactly one token");
|
||||
return finite_token;
|
||||
}
|
||||
|
||||
static void test_reasoning_budget_clone_mid_counting() {
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
|
||||
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 2, REASONING_BUDGET_IDLE);
|
||||
|
||||
llama_sampler_accept(sampler, 100); // COUNTING, remaining=2
|
||||
llama_sampler_accept(sampler, 50); // COUNTING, remaining=1
|
||||
|
||||
auto * clone = llama_sampler_clone(sampler);
|
||||
llama_sampler_accept(clone, 51); // should exhaust the cloned remaining budget
|
||||
|
||||
GGML_ASSERT(get_forced_token(clone, 102) == 102 && "cloned counting state lost remaining budget");
|
||||
|
||||
llama_sampler_free(clone);
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
static void test_reasoning_budget_clone_mid_forcing() {
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
|
||||
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING);
|
||||
|
||||
GGML_ASSERT(get_forced_token(sampler, 102) == 102);
|
||||
llama_sampler_accept(sampler, 102); // advance to the second forced token
|
||||
|
||||
auto * clone = llama_sampler_clone(sampler);
|
||||
|
||||
GGML_ASSERT(get_forced_token(clone, 102) == 101 && "cloned forcing state lost force position");
|
||||
|
||||
llama_sampler_free(clone);
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
// UTF-8 boundary detection unit test
|
||||
// Tests common_utf8_is_complete() from reasoning-budget.h
|
||||
static void test_utf8_boundary_detection() {
|
||||
|
|
@ -250,7 +310,10 @@ int main(void) {
|
|||
7); // forcing continues through i=7
|
||||
}
|
||||
|
||||
printf("OK (6 tests passed)\n");
|
||||
test_reasoning_budget_clone_mid_counting();
|
||||
test_reasoning_budget_clone_mid_forcing();
|
||||
|
||||
printf("OK (8 tests passed)\n");
|
||||
|
||||
printf("Testing UTF-8 boundary detection... ");
|
||||
test_utf8_boundary_detection();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue