diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 1b9c9b5ca..7f35a21b3 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -1593,7 +1593,12 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_array
     for (const auto & reject : rejects) {
         candidates->data[reject.index].logit = -INFINITY;
     }
-
+
+    auto first = candidates->data;
+    auto last = first + candidates->size;
+    last = std::remove_if(first, last,
+        [&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
+    candidates->size = last - first;
 }
 
 void sample_guidance(struct llama_context * ctx, struct llama_context * guidance_ctx, int n_vocab, float scale)
@@ -1643,15 +1648,30 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-    if (grammar != nullptr) {
-        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
-    }
-
     //dry always first as logits cannot be resorted
     sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
-
+    //prefilter to top 3k tokens for improved speed
+    bool use_grammar = grammar != nullptr;
+    size_t n_pre_cull = candidates_p.size;
+    sample_top_k(&candidates_p, 3000);
+
+    if (use_grammar) {
+
+        (debugmode == 1 && printf("\nGrammar sampling %zu candidates.\n", candidates_p.size));
+        sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+        (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
+
+        // if top_k 3000 doesn't contain a valid candidate for this grammar, try again pre-cull
+        if (candidates_p.size <= 0) {
+            candidates_p.size = n_pre_cull;
+            (debugmode == 1 && printf("\nRe-sampling grammar with %zu pre-cull tokens.\n", candidates_p.size));
+            sample_grammar(file_format, n_vocab, &candidates_p, grammar);
+            (debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
+            sample_top_k(&candidates_p, 3000);
+        }
+    }
 
     if (mirostat == 1 || mirostat == 2) {
@@ -1745,7 +1765,6 @@ static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
     const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        auto prev_stacks = grammar->stacks;
        llama_grammar_accept(grammar, *it);
     }
     grammar->partial_utf8 = decoded.second;
@@ -3941,6 +3960,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
 
         if (grammar != nullptr) {
+            (debugmode == 1 && printf("\nGrammar attempting to accept token...\n"));
             grammar_accept_token(file_format, n_vocab, grammar, id);
         }
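
Note on the first hunk: sample_grammar() previously only masked grammar-rejected tokens by forcing their logits to -INFINITY; the patch now also compacts the candidate array with the erase-remove idiom, so every sampler that runs afterwards iterates over fewer entries. The following is a minimal standalone sketch of that compaction, not koboldcpp's actual code: token_data and token_data_array are simplified stand-ins for llama.cpp's llama_token_data and llama_token_data_array, and compact_rejected() is a hypothetical helper.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct token_data       { int id; float logit; float p; };
struct token_data_array { token_data * data; std::size_t size; bool sorted; };

// Drop every candidate whose logit was forced to -INFINITY instead of
// leaving it in place; only the logical size shrinks, the storage stays.
static void compact_rejected(token_data_array * candidates)
{
    auto first = candidates->data;
    auto last  = first + candidates->size;
    last = std::remove_if(first, last,
        [](const token_data & tk) { return tk.logit == -INFINITY; });
    candidates->size = last - first;
}

int main()
{
    std::vector<token_data> toks = {
        {0, 1.5f, 0.0f}, {1, -INFINITY, 0.0f}, {2, 0.25f, 0.0f}, {3, -INFINITY, 0.0f},
    };
    token_data_array arr = { toks.data(), toks.size(), false };
    compact_rejected(&arr);
    std::printf("%zu candidates survive\n", arr.size); // prints: 2 candidates survive
    return 0;
}

Because std::remove_if only shifts survivors toward the front and shrinks the logical size, the underlying storage is untouched; the fallback in the second hunk depends on exactly that property.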
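
Note on the second hunk: grammar filtering moves behind sample_dry() (whose logits must not be resorted) and behind a top-k prefilter of 3000 tokens, with a fallback that restores the pre-cull size and refilters the full candidate set when the top 3000 contain no grammar-valid token. Since candidates_p.size is a size_t, the patch's size <= 0 test behaves as size == 0. The sketch below mirrors that control flow under stated assumptions: top_k_truncate() and grammar_filter() are hypothetical stand-ins for sample_top_k() and sample_grammar(), and, as with llama.cpp's top-k, truncation only reorders entries in place and shrinks the logical size.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct token_data       { int id; float logit; float p; };
struct token_data_array { token_data * data; std::size_t size; bool sorted; };

// Partially sort by descending logit and truncate the logical size to k.
// Entries beyond k remain valid in storage, merely unsorted; that is what
// makes the restore-and-retry fallback below possible.
static void top_k_truncate(token_data_array * c, std::size_t k)
{
    if (c->size <= k) return;
    std::partial_sort(c->data, c->data + k, c->data + c->size,
        [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
    c->size = k;
}

// Stand-in for grammar filtering: here, only token ids below 1000 are valid.
static void grammar_filter(token_data_array * c)
{
    auto last = std::remove_if(c->data, c->data + c->size,
        [](const token_data & tk) { return tk.id >= 1000; });
    c->size = last - c->data;
}

static void sample_with_grammar(token_data_array * c)
{
    const std::size_t n_pre_cull = c->size;
    top_k_truncate(c, 3000);   // prefilter for speed
    grammar_filter(c);
    if (c->size == 0) {        // top 3000 held no valid token: retry pre-cull
        c->size = n_pre_cull;
        grammar_filter(c);
        top_k_truncate(c, 3000);
    }
}

int main()
{
    // Logits grow with the id, so the top 3000 are all grammar-invalid
    // and the fallback path is exercised.
    std::vector<token_data> toks;
    for (int i = 0; i < 32000; ++i) toks.push_back({i, (float)i, 0.0f});
    token_data_array arr = { toks.data(), toks.size(), false };
    sample_with_grammar(&arr);
    std::printf("%zu candidates remain\n", arr.size); // prints: 1000 candidates remain
    return 0;
}

The fallback is safe because nothing was destroyed: when grammar_filter() rejects every candidate, std::remove_if moves no elements, and the entries past the truncated size were never touched, so resetting size to n_pre_cull recovers the entire pre-cull set.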
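
Two smaller cleanups round out the patch: the third hunk deletes a per-iteration copy of grammar->stacks inside grammar_accept_token() that nothing in the loop reads (a needless deep copy on every decoded code point), and the new debug prints use a parenthesized short-circuit expression rather than an if statement, so printf only runs when debugmode is 1. A standalone illustration of that logging idiom follows; debugmode here is a local stand-in for koboldcpp's global flag.

#include <cstddef>
#include <cstdio>

static int debugmode = 1; // stand-in for koboldcpp's global debug flag

int main()
{
    std::size_t n_candidates = 3000;
    // The right-hand printf is evaluated only when the comparison is true;
    // the parentheses make the whole && expression a plain statement,
    // equivalent to: if (debugmode == 1) { printf(...); }
    (debugmode == 1 && printf("\nGrammar sampling %zu candidates.\n", n_candidates));
    return 0;
}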