llama : refactor sampling v2 (#9294)

- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
Georgi Gerganov 2024-09-07 15:16:19 +03:00 committed by GitHub
parent 947538acb8
commit df270ef745
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 3497 additions and 2914 deletions

View file

@ -1,6 +1,5 @@
#include "llama-impl.h"
#include "llama-vocab.h"
#include "llama-grammar.h"
#include "llama-sampling.h"
#include "unicode.h"
@ -3179,7 +3178,6 @@ struct llama_sbatch {
struct llama_context {
llama_context(const llama_model & model)
: model(model)
, sampling(llama_n_vocab(&model))
, t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {}
@ -3196,7 +3194,6 @@ struct llama_context {
const struct llama_model & model;
struct llama_cparams cparams;
struct llama_sampling sampling;
struct llama_sbatch sbatch;
struct llama_kv_cache kv_self;
struct llama_control_vector cvec;
@ -3217,16 +3214,16 @@ struct llama_context {
bool has_evaluated_once = false;
int64_t t_start_us;
int64_t t_load_us;
int64_t t_p_eval_us = 0;
int64_t t_eval_us = 0;
mutable int64_t t_start_us;
mutable int64_t t_load_us;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
int64_t t_compute_start_us = 0;
int64_t n_queued_tokens = 0;
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
int32_t n_eval = 0; // number of eval calls
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_t buf_output = nullptr;
@ -6251,6 +6248,7 @@ static void llm_load_vocab(
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
vocab.n_vocab = n_vocab;
vocab.id_to_token.resize(n_vocab);
for (uint32_t i = 0; i < n_vocab; i++) {
@ -17892,7 +17890,6 @@ struct llama_model_params llama_model_default_params() {
struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 2048,
/*.n_ubatch =*/ 512,
@ -17925,6 +17922,14 @@ struct llama_context_params llama_context_default_params() {
return result;
}
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
struct llama_sampler_chain_params result = {
/*.no_perf =*/ true,
};
return result;
}
struct llama_model_quantize_params llama_model_quantize_default_params() {
struct llama_model_quantize_params result = {
/*.nthread =*/ 0,
@ -18178,10 +18183,6 @@ struct llama_context * llama_new_context_with_model(
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
}
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
@ -18192,10 +18193,10 @@ struct llama_context * llama_new_context_with_model(
ctx->abort_callback = params.abort_callback;
ctx->abort_callback_data = params.abort_callback_data;
ctx->sampling.rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
ctx->logits_all = params.logits_all;
// build worst-case graph for encoder if a model contains encoder
ctx->is_encoding = llama_model_has_encoder(model);
ctx->is_encoding = llama_model_has_encoder(model);
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
@ -18473,14 +18474,6 @@ void llama_free(struct llama_context * ctx) {
delete ctx;
}
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
return &ctx->model;
}
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
return &ctx->model.vocab;
}
uint32_t llama_n_ctx(const struct llama_context * ctx) {
return ctx->cparams.n_ctx;
}
@ -18501,6 +18494,30 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.type;
}
int32_t llama_n_vocab(const struct llama_model * model) {
return model->hparams.n_vocab;
}
int32_t llama_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train;
}
int32_t llama_n_embd(const struct llama_model * model) {
return model->hparams.n_embd;
}
int32_t llama_n_layer(const struct llama_model * model) {
return model->hparams.n_layer;
}
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
return &ctx->model;
}
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
return ctx->cparams.pooling_type;
}
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
switch (model->arch) {
// these models do not use RoPE
@ -18564,26 +18581,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
return LLAMA_ROPE_TYPE_NONE;
}
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
return ctx->cparams.pooling_type;
}
int32_t llama_n_vocab(const struct llama_model * model) {
return model->hparams.n_vocab;
}
int32_t llama_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train;
}
int32_t llama_n_embd(const struct llama_model * model) {
return model->hparams.n_embd;
}
int32_t llama_n_layer(const struct llama_model * model) {
return model->hparams.n_layer;
}
float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train;
}
@ -19000,14 +18997,14 @@ struct llama_data_write {
// TODO: add more model-specific info which should prevent loading the session file if not identical
}
void write_rng(const std::mt19937 & rng) {
std::ostringstream rng_ss;
rng_ss << rng;
//void write_rng(const std::mt19937 & rng) {
// std::ostringstream rng_ss;
// rng_ss << rng;
const std::string & rng_str = rng_ss.str();
// const std::string & rng_str = rng_ss.str();
write_string(rng_str);
}
// write_string(rng_str);
//}
void write_output_ids(struct llama_context * ctx) {
llama_output_reorder(ctx);
@ -19227,17 +19224,17 @@ struct llama_data_read {
// TODO: add more info which needs to be identical but which is not verified otherwise
}
void read_rng(std::mt19937 & rng) {
std::string rng_str;
read_string(rng_str);
//void read_rng(std::mt19937 & rng) {
// std::string rng_str;
// read_string(rng_str);
std::istringstream rng_ss(rng_str);
rng_ss >> rng;
// std::istringstream rng_ss(rng_str);
// rng_ss >> rng;
if (rng_ss.fail()) {
throw std::runtime_error("failed to load RNG state");
}
}
// if (rng_ss.fail()) {
// throw std::runtime_error("failed to load RNG state");
// }
//}
void read_output_ids(struct llama_context * ctx) {
std::vector<int32_t> output_pos;
@ -19667,8 +19664,6 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
data_ctx.write_model_info(ctx);
data_ctx.write_rng(ctx->sampling.rng);
// copy outputs
data_ctx.write_output_ids(ctx);
data_ctx.write_logits(ctx);
@ -19706,9 +19701,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
data_ctx.read_model_info(ctx);
// set rng
data_ctx.read_rng(ctx->sampling.rng);
// set outputs
data_ctx.read_output_ids(ctx);
data_ctx.read_logits(ctx);
@ -20111,8 +20103,9 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
#endif
#else
return nullptr;
#endif
}
}
@ -20160,8 +20153,9 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
#endif
#else
return nullptr;
#endif
}
}
@ -20594,124 +20588,18 @@ int32_t llama_chat_apply_template(
return res;
}
//
// grammar
//
struct llama_grammar * llama_grammar_init(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
return llama_grammar_init_impl(rules, n_rules, start_rule_index);
}
void llama_grammar_free(struct llama_grammar * grammar) {
llama_grammar_free_impl(grammar);
}
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
return llama_grammar_copy_impl(grammar);
}
void llama_grammar_sample(
const struct llama_grammar * grammar,
const struct llama_context * ctx,
llama_token_data_array * candidates) {
llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
}
void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar) {
llama_grammar_sample(grammar, ctx, candidates);
}
void llama_grammar_accept_token(
struct llama_grammar * grammar,
struct llama_context * ctx,
llama_token token) {
llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
}
//
// sampling
//
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
llama_set_rng_seed_impl(&ctx->sampling, seed);
// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
}
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
}
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
}
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
}
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
}
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present) {
llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
}
void llama_sample_apply_guidance(
struct llama_context * ctx,
float * logits,
float * logits_guidance,
float scale) {
llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
}
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
}
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
}
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
}
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
}
//
// model split
//
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
@ -20737,45 +20625,6 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
return 0;
}
struct llama_timings llama_get_timings(struct llama_context * ctx) {
struct llama_timings result = {
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
/*.t_end_ms =*/ 1.00 * ggml_time_ms(),
/*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
/*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
/*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
/*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
/*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
/*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
/*.n_eval =*/ std::max(1, ctx->n_eval),
};
return result;
}
void llama_print_timings(struct llama_context * ctx) {
const llama_timings timings = llama_get_timings(ctx);
LLAMA_LOG_INFO("\n");
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms);
LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
}
void llama_reset_timings(struct llama_context * ctx) {
ctx->t_start_us = ggml_time_us();
ctx->t_eval_us = ctx->n_eval = 0;
ctx->t_p_eval_us = ctx->n_p_eval = 0;
ctx->sampling.reset_timings();
}
const char * llama_print_system_info(void) {
static std::string s;
@ -20804,7 +20653,68 @@ const char * llama_print_system_info(void) {
return s.c_str();
}
void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
void llama_perf_print(const void * ctx, enum llama_perf_type type) {
switch (type) {
case LLAMA_PERF_TYPE_CONTEXT:
{
const auto * p = (const struct llama_context *) ctx;
const double t_start_ms = 1e-3 * p->t_start_us;
const double t_end_ms = 1.00 * ggml_time_ms();
const double t_load_ms = 1e-3 * p->t_load_us;
const double t_p_eval_ms = 1e-3 * p->t_p_eval_us;
const double t_eval_ms = 1e-3 * p->t_eval_us;
const int32_t n_p_eval = std::max(0, p->n_p_eval);
const int32_t n_eval = std::max(1, p->n_eval);
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, t_load_ms);
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, t_p_eval_ms, n_p_eval, t_p_eval_ms / n_p_eval, 1e3 / t_p_eval_ms * n_p_eval);
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, t_eval_ms, n_eval, t_eval_ms / n_eval, 1e3 / t_eval_ms * n_eval);
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - t_start_ms), (n_p_eval + n_eval));
} break;
case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
{
const auto * smpl = (const struct llama_sampler *) ctx;
const auto * p = (const struct llama_sampler_chain *) smpl->ctx;
const double t_sampler_ms = 1e-3 * p->t_sample_us;
const int32_t n_sampler = std::max(0, p->n_sample);
LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, t_sampler_ms, n_sampler, t_sampler_ms / n_sampler, 1e3 / t_sampler_ms * n_sampler);
} break;
default:
GGML_ABORT("invalid perf type");
}
}
void llama_perf_reset(void * ctx, enum llama_perf_type type) {
switch (type) {
case LLAMA_PERF_TYPE_CONTEXT:
{
auto * p = (struct llama_context *) ctx;
p->t_start_us = ggml_time_us();
p->t_eval_us = p->n_eval = 0;
p->t_p_eval_us = p->n_p_eval = 0;
} break;
case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
{
auto * smpl = (struct llama_sampler *) ctx;
auto * p = (struct llama_sampler_chain *) smpl->ctx;
p->t_sample_us = p->n_sample = 0;
} break;
default:
GGML_ABORT("invalid perf type");
}
}
void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
fprintf(stream, "\n");
fprintf(stream, "###########\n");
fprintf(stream, "# Timings #\n");
@ -20815,21 +20725,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
1.0e-3 * ctx->t_eval_us / ctx->n_eval);
fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample);
fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
1.0e6 * ctx->n_eval / ctx->t_eval_us);
fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
}
// For internal test use