diff --git a/common/common.cpp b/common/common.cpp index c241b54f..472205b9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1016,7 +1016,6 @@ static bool assign_layers_to_device( auto device = (diff > 0) ? std::max_element(mem_budget.begin(), mem_budget.end()) : std::min_element(mem_budget.begin(), mem_budget.end()); w[std::distance(mem_budget.begin(), device)] += diff; - // initialize n_m to w_m (if there is GPU), assume all layers can run on GPU for (uint32_t m = 0; m < n_world; ++m) { if (dev_info_set[m].gpu_support.metal || dev_info_set[m].gpu_support.cuda) { @@ -1540,13 +1539,24 @@ static bool assign_layers_to_device( float total_mem_budget = std::accumulate(mem_budget.begin(), mem_budget.end(), 0.0f); for (uint32_t m = 0; m < n_world; ++m) { w[m] = std::round(mem_budget[m] / total_mem_budget * n_layer); - n[m] = 0; + } + // no 0 is allowed in w, it must be at least 1 + for (uint32_t m = 0; m < n_world; ++m) { + if (w[m] == 0) { + w[m] = 1; + // find the maximum and decrease it by 1 + auto max_it = std::max_element(w.begin(), w.end()); + if (max_it != w.end() && *max_it > 1) { + *max_it -= 1; + } + } } // adjust w[m] to ensure L mod W = 0 int diff = n_layer - std::accumulate(w.begin(), w.end(), 0); auto device = (diff > 0) ? std::max_element(mem_budget.begin(), mem_budget.end()) : std::min_element(mem_budget.begin(), mem_budget.end()); w[std::distance(mem_budget.begin(), device)] += diff; + std::copy(w.begin(), w.end(), n_layer_window); std::vector vec_z_gpu(n_world, 0.0f); @@ -1554,8 +1564,7 @@ static bool assign_layers_to_device( for (uint32_t m = 0; m < n_world; ++m) { const device_info & dev = dev_info_set[m]; - bool use_gpu = dev.gpu_support.metal || dev.gpu_support.cuda; - llama_model_compute_buf_size(&c_cpu[m], &c_gpu[m], model, cparams, get_backend_type(dev.gpu_support), m, dev_info_set[0].model_bytes, true); + llama_model_compute_buf_size(&c_cpu[m], &c_gpu[m], model, cparams, get_backend_type(dev.gpu_support), m, dev_info_set[0].model_bytes, false, true); if (dev.gpu_support.cuda || dev.gpu_support.metal) { int64_t required_mem = w[m] * b_prime;