add metal mem limit

Lizonghang 2025-01-23 16:08:52 +04:00
parent 33429ec4e1
commit 78a544d716
5 changed files with 102 additions and 67 deletions

View file

@@ -731,10 +731,10 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
}
).set_env("LLAMA_ARG_UNLOAD"));
add_opt(llama_arg(
{"-cm", "--cuda-mem"}, "N",
format("maximum cuda memory to use (default: %d)", params.cuda_mem),
{"-gm", "--gpu-mem"}, "N",
format("maximum GPU memory to use (default: %d)", params.gpu_mem),
[](gpt_params & params, int value) {
params.cuda_mem = value; // in GiB
params.gpu_mem = value; // in GiB
}
).set_env("LLAMA_ARG_CUDA_MEM"));
#ifdef GGML_USE_METAL

View file

@@ -16,6 +16,7 @@
#include <codecvt>
#include <cstdarg>
#include <cstring>
#include <csignal>
#include <ctime>
#include <fstream>
#include <iostream>
@@ -76,6 +77,16 @@ using json = nlohmann::ordered_json;
constexpr int GIGABYTE = 1024 * 1024 * 1024;
struct HiGHSException {
int signal;
const char * message;
};
[[noreturn]] static void highs_handler(int signal) {
HiGHSException e{signal, "HiGHS terminated due to signal"};
throw e;
}
//
// CPU utils
//
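The handler above turns a fatal signal raised inside the HiGHS solver into a C++ exception that the caller can catch (see the try/catch added around highs.run() further down). Below is a minimal standalone sketch of that pattern, for illustration only and not part of this commit; note that unwinding out of a signal handler is not guaranteed by the C++ standard, although it typically works on common GCC/Clang setups when the signal is raised synchronously on the same thread.

    #include <csignal>
    #include <cstdio>

    struct SignalException { int signal; };

    [[noreturn]] static void to_exception(int signal) {
        throw SignalException{signal};   // convert the signal into an exception
    }

    int main() {
        std::signal(SIGABRT, to_exception);   // install the handler, as done before highs.run()
        try {
            std::raise(SIGABRT);              // stand-in for a solver call that aborts
        } catch (const SignalException & e) {
            std::fprintf(stderr, "recovered from signal %d\n", e.signal);
        }
        return 0;
    }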
@@ -846,7 +857,7 @@ static void assign_device(
// model-specific constants
const int n_embd_k_gqa = llama_model_n_embd_k_gqa(model);
const int n_embd_v_gqa = llama_model_n_embd_v_gqa(model);
const int n_kv = 16;
const int n_kv = cparams.n_ctx;
const int64_t b = dev_info_set[0].model_bytes.nb_layer;
const int64_t bi = dev_info_set[0].model_bytes.nb_input;
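Sizing the KV cache with the real context length instead of the old placeholder n_kv = 16 matters because the cache grows linearly with it. A rough illustrative estimate follows (fp16 cache and example dimensions assumed here, not values taken from this commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_ctx        = 4096;   // cparams.n_ctx (assumed example value)
        const int64_t n_embd_k_gqa = 1024;   // example model constant
        const int64_t n_embd_v_gqa = 1024;   // example model constant
        const int64_t elt_bytes    = 2;      // fp16 K/V entries
        const int64_t per_layer    = n_ctx * (n_embd_k_gqa + n_embd_v_gqa) * elt_bytes;
        // prints 16.0 MiB per layer; with n_kv = 16 the same estimate would be 256x smaller
        std::printf("KV cache per layer: %.1f MiB\n", per_layer / (1024.0 * 1024.0));
        return 0;
    }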
@@ -1104,9 +1115,9 @@ static void assign_device(
}
// -------------------------------------------------------------
// Construct vectors vz, vz_cuda
// Construct vectors vz, vz_gpu
// -------------------------------------------------------------
// z and z_cuda are used to express memory constraints:
// z and z_gpu are used to express memory constraints:
// for z:
// - M1: (d_m^{avail} - b_cio) / (L*b')
// - M2: (d_m^{total} - b_cio - c_gpu) / (L*b')
@@ -1115,11 +1126,11 @@ static void assign_device(
// or - (d_m^{total} - b_cio - c_gpu) / (L*b') on macOS with Metal,
// or - (d_m^{avail}+d_m^{swapout} - b_cio) / (L*b') on Linux or Android
//
// for z_cuda:
// for z_gpu:
// - M1: (d_{m,cuda}^{avail} - c_gpu) / (L*b'),
// d_{m,cuda}^{avail} is non-zero only if the device supports CUDA
std::vector<float> vec_z(n_world, 0.0f), vec_z_cuda(n_world, 0.0f);
std::vector<int> dev_cuda(n_world, 0);
std::vector<float> vec_z(n_world, 0.0f), vec_z_gpu(n_world, 0.0f);
std::vector<int> dev_gpu(n_world, 0);
for (uint32_t m = 0; m < n_world; ++m) {
const device_info &dev = dev_info_set[m];
@@ -1148,16 +1159,20 @@ static void assign_device(
}
}
if (dev.gpu_support.cuda) {
vec_z_cuda[m] = (double)(dev.gpu_props.memory_free * GIGABYTE - c_gpu[m]) / (double)(n_layer * b_prime);
dev_cuda[m] = 1;
if (dev.gpu_support.cuda || dev.gpu_support.metal) {
float reserved_mem = 0.3f; // reserved shared memory to avoid potential OOM, set to 300 MiB by default
vec_z_gpu[m] = (double)((dev.gpu_props.memory_free - reserved_mem) * GIGABYTE - c_gpu[m]) / (double)(n_layer * b_prime);
if (dev.gpu_support.metal && m == 0 && cparams.keep_inp_out_in_metal) {
vec_z_gpu[m] -= (double)(bi + bo) / (double)(n_layer * b_prime);
}
dev_gpu[m] = 1;
} else {
vec_z_cuda[m] = -(double)c_gpu[m] / (double)(n_layer * b_prime);
vec_z_gpu[m] = -(double)c_gpu[m] / (double)(n_layer * b_prime);
}
}
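In the notation of the comments above, the loop computes each device's GPU memory headroom expressed in model layers. Restated for reference (r is the reserved headroom, 0.3 GiB by default; the last term is subtracted only on the head device when input/output weights are kept in Metal; devices without CUDA or Metal simply get z_gpu[m] = -c_gpu[m] / (L*b')):

    z_{\mathrm{gpu}}[m] \;=\; \frac{\bigl(d^{\mathrm{avail}}_{m,\mathrm{gpu}} - r\bigr)\cdot 2^{30} \;-\; c_{\mathrm{gpu}}[m]}{L \cdot b'}
        \;-\; \underbrace{\frac{b_i + b_o}{L \cdot b'}}_{\text{Metal head device only}}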
// count the number of cuda devices
int num_dev_cuda = std::accumulate(dev_cuda.begin(), dev_cuda.end(), 0);
int num_dev_gpu = std::accumulate(dev_gpu.begin(), dev_gpu.end(), 0);
// -------------------------------------------------------------
// Build and solve the optimization model
@@ -1175,7 +1190,7 @@ static void assign_device(
// define the number of decision variables and constraints
model.lp_.num_col_ = n_world * 2; // number of decision variables
model.lp_.num_row_ = 1 + 2 * n_world + num_dev_cuda; // number of constraints
model.lp_.num_row_ = 1 + 2 * n_world + num_dev_gpu; // number of constraints
// define the objective: k * sum(a[m] * w[m] + b[m] * n[m]) + kappa + k * sum(c[m])
model.lp_.sense_ = ObjSense::kMinimize;
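Written out, the objective in the comment above is the following (an interpretation of the surrounding code: the sum runs over the M = n_world devices, w_m and n_m are the two per-device decision variables with n_m the GPU-resident share, and a_m, b_m, c_m, k and kappa are cost coefficients assembled earlier):

    \min_{w,\,n}\; k \sum_{m=1}^{M} \bigl( a_m w_m + b_m n_m \bigr) \;+\; \kappa \;+\; k \sum_{m=1}^{M} c_m

The kappa and c_m terms do not depend on the decision variables, so they shift the objective value without changing the optimal layer assignment.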
@@ -1216,10 +1231,10 @@ static void assign_device(
}
constraint_idx += n_world;
// constraint bound 4: CUDA memory constraint for CUDA devices
// constraint bound 4: CUDA/shared memory constraint for CUDA/Metal devices
for (uint32_t m = 0; m < n_world; ++m) {
if (dev_cuda[m]) {
model.lp_.row_upper_[constraint_idx] = W * vec_z_cuda[m];
if (dev_gpu[m]) {
model.lp_.row_upper_[constraint_idx] = W * vec_z_gpu[m];
constraint_idx++;
}
}
@@ -1265,9 +1280,9 @@ static void assign_device(
}
constraint_idx += n_world;
// constraint coefficients 4: CUDA memory constraint for CUDA devices
// constraint coefficients 4: CUDA/shared memory constraint for CUDA/Metal devices
for (uint32_t m = 0; m < n_world; ++m) {
if (dev_cuda[m]) {
if (dev_gpu[m]) {
A[constraint_idx][m] = 0.0; // coefficient for w[m]
A[constraint_idx][m + n_world] = 1.0; // coefficient for n[m]
constraint_idx++;
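Together with the row bound set in the previous hunk, each CUDA/Metal device m contributes one memory row of the form below (w_m enters with coefficient 0, n_m with coefficient 1, and the right-hand side is the upper bound W * z_gpu[m] assigned above):

    n_m \;\le\; W \cdot z_{\mathrm{gpu}}[m] \qquad \text{for every device } m \text{ with CUDA or Metal support}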
@@ -1304,8 +1319,14 @@ static void assign_device(
GGML_ASSERT(return_status == HighsStatus::kOk && "Failed to pass model\n");
// run the solver
return_status = highs.run();
GGML_ASSERT(return_status == HighsStatus::kOk && "Failed to run the solver\n");
try {
std::signal(SIGABRT, highs_handler);
return_status = highs.run();
GGML_ASSERT(return_status == HighsStatus::kOk && "Failed to run the solver\n");
} catch (const HiGHSException &e) {
LOG_INF("Failed to run the solver when k = %d: unknown exception\n", k);
continue;
}
// get the solution
const HighsModelStatus& model_status = highs.getModelStatus();
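For reference, here is a minimal self-contained HiGHS program that follows the same build/pass/run/inspect flow used here. The toy objective and single constraint are invented for illustration; only documented HiGHS C++ API calls are used:

    #include "Highs.h"
    #include <cstdio>

    int main() {
        // minimize x0 + x1  subject to  x0 + x1 >= 1,  0 <= x0, x1 <= 10
        HighsModel model;
        model.lp_.num_col_   = 2;
        model.lp_.num_row_   = 1;
        model.lp_.sense_     = ObjSense::kMinimize;
        model.lp_.col_cost_  = {1.0, 1.0};
        model.lp_.col_lower_ = {0.0, 0.0};
        model.lp_.col_upper_ = {10.0, 10.0};
        model.lp_.row_lower_ = {1.0};
        model.lp_.row_upper_ = {kHighsInf};
        model.lp_.a_matrix_.format_ = MatrixFormat::kRowwise;
        model.lp_.a_matrix_.start_  = {0, 2};
        model.lp_.a_matrix_.index_  = {0, 1};
        model.lp_.a_matrix_.value_  = {1.0, 1.0};

        Highs highs;
        if (highs.passModel(model) != HighsStatus::kOk) return 1;
        if (highs.run() != HighsStatus::kOk) return 1;

        const HighsSolution & sol = highs.getSolution();
        std::printf("status: %s, x0 = %g, x1 = %g\n",
                    highs.modelStatusToString(highs.getModelStatus()).c_str(),
                    sol.col_value[0], sol.col_value[1]);
        return 0;
    }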
@@ -1419,7 +1440,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
// get device profile
LOG_INF("Start profiling this device, this may take some seconds ...\n");
dev_info.rank = params.rank;
llama_profile_device(&dev_info, model, ml, params.cuda_mem, params.n_predict, params.n_ctx, params.cpuparams.n_threads, params.flash_attn);
llama_profile_device(&dev_info, model, ml, params.gpu_mem, params.n_predict, params.n_ctx, params.cpuparams.n_threads, params.flash_attn);
}
// create llama context
@@ -1647,10 +1668,11 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
cparams.n_world = params.n_world;
cparams.rank = params.rank;
cparams.unload = params.unload;
cparams.n_gpu_layers = params.n_gpu_layers;
cparams.n_world = params.n_world;
cparams.rank = params.rank;
cparams.unload = params.unload;
cparams.keep_inp_out_in_metal = params.keep_inp_out_in_metal;
cparams.n_gpu_layers = params.n_gpu_layers;
std::copy(std::begin(params.n_layer_window), std::end(params.n_layer_window), cparams.n_layer_window);
if (cparams.master_ip != nullptr) {

View file

@@ -149,7 +149,7 @@ struct gpt_params {
std::string next_node_ip = "localhost"; // ip address of my next node
bool unload = false; // unload layer weights after use or not
bool keep_inp_out_in_metal = false; // whether to keep input/output weight in metal, not by default
int32_t cuda_mem = 999.0; // cuda memory to use, in GiB
int32_t gpu_mem = 999.0; // gpu memory to use, in GiB
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 0; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)

View file

@@ -323,6 +323,7 @@ extern "C" {
uint32_t n_layer_window[32];// number of layers to process in each compute
uint32_t n_gpu_layers; // number of layers to process on GPU
bool unload; // whether to unload layer weights after use
bool keep_inp_out_in_metal; // whether to keep input/output weight in metal
char * master_ip; // ip address of the master node
char * next_node_ip; // ip address of the next node
uint32_t n_ctx; // text context, 0 = from model

View file

@@ -3574,7 +3574,7 @@ void llama_profile_device(
device_info * dev_info,
struct llama_model * model,
llama_model_loader * ml,
int cuda_mem,
int gpu_mem,
int n_predict,
int n_ctx,
int n_threads,
@@ -3622,9 +3622,9 @@ void llama_profile_device(
dev_info->gpu_props.description = gpu_props.description;
dev_info->gpu_props.memory_free = round(gpu_props.memory_free / (double)(1 << 30) * 100) / 100;
#ifdef GGML_USE_CUDA
// CUDA memory limitation
dev_info->gpu_props.memory_free = std::min((float)cuda_mem, dev_info->gpu_props.memory_free);
#if defined(GGML_USE_CUDA) || defined(GGML_USE_METAL)
// GPU memory limitation
dev_info->gpu_props.memory_free = std::min((float)gpu_mem, dev_info->gpu_props.memory_free);
#endif
dev_info->gpu_props.memory_total = round(gpu_props.memory_total / (double)(1 << 30) * 100) / 100;
@@ -5159,15 +5159,17 @@ struct llama_model_loader {
static const int TENSOR_NOT_REQUIRED = 1;
static const int TENSOR_DUPLICATED = 2;
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0) {
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0, bool set_needed = false) {
const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
if (cur == NULL) {
return NULL;
}
auto * weight = get_weight(ggml_get_name(cur));
weight->set_as_needed(); // this tensor is needed for this device
if (set_needed) {
auto * weight = get_weight(ggml_get_name(cur));
weight->set_as_needed(); // this tensor is needed for this device
}
return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
}
@@ -5253,10 +5255,18 @@ struct llama_model_loader {
const auto & mapping = mappings.at(idx);
*addr = mapping->addr;
auto merge_tensor_range = [&](ggml_context * context) {
auto merge_tensor_range = [&](ggml_context * context, bool keep_only_inp_out) {
for (ggml_tensor * tensor = ggml_get_first_tensor(context); tensor; tensor = ggml_get_next_tensor(context, tensor)) {
try {
const llama_tensor_weight* weight = get_weight(ggml_get_name(tensor));
const char * tname = ggml_get_name(tensor);
if (keep_only_inp_out && !(
strcmp(tname, "token_embd.weight") == 0 ||
strcmp(tname, "output_norm.weight") == 0 ||
strcmp(tname, "output.weight") == 0)) {
continue;
}
const llama_tensor_weight* weight = get_weight(tname);
if (!weight || weight->idx != idx) continue;
size_t first = weight->offs;
@@ -5286,10 +5296,10 @@ struct llama_model_loader {
}
};
merge_tensor_range(ctx);
merge_tensor_range(ctx, false);
if (cpu_ctx != ctx && cpu_ctx != nullptr) {
merge_tensor_range(cpu_ctx);
merge_tensor_range(cpu_ctx, true);
}
}
@@ -7264,7 +7274,8 @@ static void llm_load_llama_tensors(
uint32_t n_world,
uint32_t my_rank,
const uint32_t * n_layer_window,
bool * use_mmap_buffer) {
bool * use_mmap_buffer,
bool set_needed) {
const auto tn = LLM_TN(model.arch);
ggml_context * ctx_input = nullptr;
@@ -7295,13 +7306,13 @@ static void llm_load_llama_tensors(
if (my_rank == 0) {
// token embedding
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0, set_needed);
// output
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0, set_needed);
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED, set_needed);
}
}
@@ -7316,37 +7327,37 @@ static void llm_load_llama_tensors(
auto & layer = model.layers[local_i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0, set_needed);
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0, set_needed);
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0, set_needed);
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0, set_needed);
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0, set_needed);
// optional bias tensors
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0, set_needed);
layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0), set_needed);
if (n_expert == 0) {
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0, set_needed);
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0, set_needed);
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0, set_needed);
// optional MLP bias
layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
} else {
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0, set_needed);
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED, set_needed);
if (layer.ffn_gate_exps) {
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0, set_needed);
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0, set_needed);
} else {
// merge split expert into a single tensor for compatibility with older models
// requires disabling mmap
@@ -7515,7 +7526,7 @@ static bool llm_load_tensors_impl(
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
llm_load_llama_tensors(ml, model, ctx_map, n_world, my_rank, n_layer_window, &use_mmap_buffer);
llm_load_llama_tensors(ml, model, ctx_map, n_world, my_rank, n_layer_window, &use_mmap_buffer, true);
break;
case LLM_ARCH_MINICPM3:
{
@@ -19791,6 +19802,7 @@ struct llama_context_params llama_context_default_params() {
/*.n_layer_window =*/ {32},
/*.n_gpu_layers =*/ 0,
/*.unload =*/ false,
/*.keep_inp_out_in_metal =*/ false,
/*.master_ip =*/ nullptr,
/*.next_node_ip =*/ nullptr,
/*.n_ctx =*/ 512,
@@ -21207,7 +21219,7 @@ void llama_model_n_flops(
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
llm_load_llama_tensors(*ml, *model, ctx_map, 1, 0, n_layer_window, &use_mmap_buffer);
llm_load_llama_tensors(*ml, *model, ctx_map, 1, 0, n_layer_window, &use_mmap_buffer, false);
break;
default:
throw std::runtime_error("unsupported architecture\n");