diff --git a/common/common.cpp b/common/common.cpp index 4af4f98a..02829102 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1259,7 +1259,7 @@ static bool assign_layers_to_device( // constraint bound 4: CUDA/shared memory constraint for CUDA/Metal devices for (uint32_t m = 0; m < n_world; ++m) { double upper_bound = W * vec_z_gpu[m]; - model.lp_.row_upper_[constraint_idx] = (upper_bound > 0) ? std::max(upper_bound, 1.0) : upper_bound; + model.lp_.row_upper_[constraint_idx] = std::max(upper_bound, 0.0); constraint_idx++; } @@ -1366,21 +1366,26 @@ static bool assign_layers_to_device( } // check the solution - bool has_free_gpu_memory = false, has_gpu_overload = false; + bool has_free_gpu_memory = false, has_gpu_overload = false, has_cpu_overload = false; for (uint32_t m = 0; m < n_world; ++m) { - if (!dev_gpu[m]) continue; + // if (!dev_gpu[m]) continue; uint32_t w_m = best_solution[m], n_m = best_solution[m + n_world]; - if (n_m < static_cast(std::round(W * vec_z_gpu[m]))) { - // if there is still free GPU memory - has_free_gpu_memory = true; - } else if (w_m > n_m) { - // if the GPU is overloaded - has_gpu_overload = true; + if (dev_gpu[m]) { + if (n_m < static_cast(std::round(W * vec_z_gpu[m]))) { + // if there is still free GPU memory + has_free_gpu_memory = true; + } else if (w_m > n_m) { + // if the GPU is overloaded + has_gpu_overload = true; + } + } else if (!in_set(m, M4)) { + // if the CPU is overloaded + has_cpu_overload = true; } } - if (has_free_gpu_memory && has_gpu_overload) { + if (has_free_gpu_memory && (has_gpu_overload || has_cpu_overload)) { int worst_device = -1; float worst_speed = std::numeric_limits::max();