deduct compute buf from available memory

Lizonghang 2024-11-29 11:15:54 +04:00
parent 329d084061
commit 0f73d12247
3 changed files with 53 additions and 4 deletions


@@ -20808,6 +20808,44 @@ static void count_n_params(struct model_params * n_params, enum ggml_type dtype,
}
}
uint64_t llama_model_compute_buf_size(const struct llama_model * model, const struct llama_context_params cparams, bool compress_memory) {
    const llama_hparams hparams = model->hparams;

    // input tensors
    const uint64_t n_inp_toks = cparams.n_ubatch;
    const uint64_t n_inp_embd = hparams.n_embd * cparams.n_ubatch;

    // activations (see figures/memory-allocation-map-for-activations.png for detailed allocation)
    const uint64_t n_bak_embd    = hparams.n_embd * cparams.n_ubatch;
    const uint64_t n_inp_pos     = cparams.n_ubatch;
    const uint64_t n_kq_mask     = cparams.n_ctx * cparams.n_ubatch;
    const uint64_t n_inp_out_ids = cparams.n_ubatch;
    const uint64_t n_norm        = hparams.n_embd * cparams.n_ubatch;
    const uint64_t n_qcur        = hparams.n_embd * cparams.n_ubatch * 2;
    const uint64_t n_kq          = cparams.n_ctx * cparams.n_ubatch * hparams.n_head();

    // outputs
    const uint64_t n_out_embd = hparams.n_embd * cparams.n_ubatch;
    const uint64_t n_output   = hparams.n_vocab * cparams.n_ubatch;

    // compute buffer size for input, each layer, and output
    // const uint64_t n_buf_inp = (n_inp_toks + n_inp_embd) * ggml_type_size(GGML_TYPE_F32); // do not consider memory compression
    const uint64_t n_buf_inp = (n_inp_toks + n_inp_embd) * ggml_type_size(GGML_TYPE_F32) / 2; // consider compressed memory with ratio 2:1
    const uint64_t n_buf_act = (n_bak_embd + n_inp_pos + n_kq_mask +
                                n_inp_out_ids + n_norm + n_qcur + n_kq
                               ) * ggml_type_size(GGML_TYPE_F32);
    // const uint64_t n_buf_out = (n_out_embd + n_output) * ggml_type_size(GGML_TYPE_F32); // do not consider memory compression
    const uint64_t n_buf_out = (n_out_embd + n_output) * ggml_type_size(GGML_TYPE_F32) / 2; // consider compressed memory with ratio 2:1
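
    // Worked example (not part of this commit; the hyperparameters below are
    // hypothetical, roughly 7B-class: n_embd = 4096, n_head = 32, n_vocab = 32000,
    // with n_ctx = 1024, n_ubatch = 512, and ggml_type_size(GGML_TYPE_F32) = 4):
    //   n_buf_inp = (512 + 2,097,152) * 4 / 2                 ~  4.0 MiB
    //   n_buf_act = (2,097,152 + 512 + 524,288 + 512 +
    //                2,097,152 + 4,194,304 + 16,777,216) * 4  ~ 98.0 MiB
    //   n_buf_out = (2,097,152 + 16,384,000) * 4 / 2          ~ 35.3 MiB
    //   so a rank-0 device reserves roughly 137 MiB of compute buffer.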
    uint64_t n_buf_total = 0;
    if (cparams.rank == 0) {
        n_buf_total = n_buf_inp + n_buf_act + n_buf_out;
    } else {
        n_buf_total = n_buf_act;
    }
    return n_buf_total;
}
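
// Not part of this commit: a minimal usage sketch of how the estimate above
// might be deducted from a device's free memory when planning layer placement,
// per the commit title. `available_mem_bytes` (free RAM/VRAM reported by the
// device) and `bytes_per_layer` (weight bytes of one transformer layer) are
// assumed inputs, not values defined in this diff.
static uint32_t sketch_n_layers_that_fit(const struct llama_model * model,
                                         const struct llama_context_params cparams,
                                         uint64_t available_mem_bytes,
                                         uint64_t bytes_per_layer,
                                         uint32_t n_layer_total) {
    // reserve room for the compute buffer before counting layer weights
    const uint64_t n_buf = llama_model_compute_buf_size(model, cparams, /*compress_memory=*/true);
    if (available_mem_bytes <= n_buf || bytes_per_layer == 0) {
        return 0; // the activations alone do not fit on this device
    }
    const uint64_t left    = available_mem_bytes - n_buf; // deduct compute buf from available memory
    const uint64_t n_layer = left / bytes_per_layer;      // layers whose weights still fit
    return n_layer < n_layer_total ? (uint32_t) n_layer : n_layer_total;
}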

uint64_t llama_model_kvcache_size(const struct llama_model * model, const struct llama_context_params cparams) {
    const llama_hparams hparams = model->hparams;

    uint64_t ne_k = static_cast<uint64_t>(hparams.n_embd_k_gqa()) * cparams.n_ctx * ggml_type_size(cparams.type_k);