diff --git a/common/arg.cpp b/common/arg.cpp
index 2f999dcc..a338f613 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -627,12 +627,19 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
     add_opt(llama_arg(
-        {"--draft"}, "N",
-        format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft),
+        {"--draft-max", "--draft", "--draft-n"}, "N",
+        format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
         [](gpt_params & params, int value) {
-            params.n_draft = value;
+            params.speculative.n_max = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
+    add_opt(llama_arg(
+        {"--draft-min", "--draft-n-min"}, "N",
+        format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min),
+        [](gpt_params & params, int value) {
+            params.speculative.n_min = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"-ps", "--p-split"}, "N",
         format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
@@ -640,6 +647,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.p_split = std::stof(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(llama_arg(
+        {"--draft-p-min"}, "P",
+        format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
+        [](gpt_params & params, const std::string & value) {
+            params.speculative.p_min = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"-lcs", "--lookup-cache-static"}, "FNAME",
         "path to static lookup cache to use for lookup decoding (not updated by generation)",
diff --git a/common/common.h b/common/common.h
index 044dfdf5..b454f799 100644
--- a/common/common.h
+++ b/common/common.h
@@ -177,7 +177,6 @@ struct gpt_params {
     int32_t n_batch     = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch    =  512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep      =    0; // number of tokens to keep from initial prompt
-    int32_t n_draft     =    5; // number of tokens to draft during speculative decoding
     int32_t n_chunks    =   -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel  =    1; // number of parallel sequences to decode
     int32_t n_sequences =    1; // number of sequences to decode
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 0a6c4701..ad610ac8 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -169,7 +169,7 @@ int main(int argc, char ** argv) {
     const auto t_enc_end = ggml_time_us();
 
     // how many tokens to draft each time
-    int n_draft = params.n_draft;
+    int n_draft = params.speculative.n_max;
 
     int n_predict = 0;
     int n_drafted = 0;
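
Note (commentary, not part of the patch): the hunks above reference a `speculative` sub-struct on `gpt_params` (`params.speculative.n_max`, `.n_min`, `.p_min`) whose definition is not shown in this diff. A minimal sketch of what that struct might look like, with field names taken from the hunks and defaults that are illustrative assumptions only:

    // Hypothetical sketch of the speculative-decoding parameter struct implied
    // by the hunks above; the struct and member names mirror the diff, but the
    // default values are assumptions, not taken from this patch.
    struct gpt_params_speculative {
        int32_t n_max = 16;   // max tokens to draft per speculative step (--draft-max); assumed default
        int32_t n_min =  5;   // min number of draft tokens worth verifying (--draft-min); assumed default
        float   p_min = 0.9f; // min draft probability to keep drafting greedily (--draft-p-min); assumed default
    };

    struct gpt_params {
        // ... existing fields ...
        gpt_params_speculative speculative; // assumed member, matching params.speculative.* in the hunks
    };

With the new flags also registered for LLAMA_EXAMPLE_SERVER, an invocation could look like the following (model paths are placeholders, and this assumes the existing `-md/--model-draft` flag for the draft model):

    ./llama-server -m target.gguf -md draft.gguf --draft-max 16 --draft-min 4 --draft-p-min 0.8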