From 27756ee182e5cb7231a610ca2f06d8d3bdbb602a Mon Sep 17 00:00:00 2001 From: Lizonghang <870644199@qq.com> Date: Wed, 4 Jun 2025 15:11:29 +0400 Subject: [PATCH] fix: enable rolling back set assignment when all devices are assigned to M4 but no feasible solutions --- common/common.cpp | 49 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 0072996c..dff98506 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1094,8 +1094,8 @@ static bool assign_layers_to_device( }; (void)print_matrix; - std::vector final_solution; - int final_k = -1; + std::vector final_solution, rollback_solution; + int final_k = -1, rollback_k = -1; // iterative optimization to find a valid set assignment (M1, M2, M3, M4) while (true) { @@ -1367,19 +1367,48 @@ static bool assign_layers_to_device( // get the solution const HighsModelStatus& model_status = highs.getModelStatus(); - if (model_status != HighsModelStatus::kOptimal) continue; + + if (model_status != HighsModelStatus::kOptimal) { + bool is_all_in_M4 = true; + for (uint32_t m = 0; m < n_world; ++m) { + if (!in_set(m, M4)) { + is_all_in_M4 = false; + break; + } + } + if (!is_all_in_M4) continue; + } // record the best solution const HighsSolution& solution = highs.getSolution(); double objective_value = highs.getInfo().objective_function_value; - if (objective_value < best_objective) { - best_objective = objective_value; - best_k = k; - best_solution = solution.col_value; - } - LOG_INF("k = %2d, obj = %7.1f, solution: %s | best_k = %2d, best_obj = %7.1f, best_solution: %s\n", - k, objective_value, vec_to_str(solution.col_value).c_str(), best_k, best_objective, vec_to_str(best_solution).c_str()); + if (solution.value_valid) { + if (objective_value < best_objective) { + best_objective = objective_value; + best_k = k; + best_solution = solution.col_value; + } + LOG_INF("k = %2d, obj = %7.1f, solution: %s | best_k = %2d, best_obj = %7.1f, best_solution: %s\n", + k, objective_value, vec_to_str(solution.col_value).c_str(), best_k, best_objective, vec_to_str(best_solution).c_str()); + } + } + + if (best_solution.empty()) { + LOG_INF("No feasible solution found for this set assignment, rolling back to previous sets.\n"); + + final_solution = rollback_solution; + final_k = rollback_k; + + // update w[m] and n[m] + GGML_ASSERT(final_solution.size() == n_world * 2 && "Invalid solution\n"); + std::copy(final_solution.begin(), final_solution.begin() + n_world, w.begin()); + std::copy(final_solution.begin() + n_world, final_solution.end(), n.begin()); + + break; + } else { + rollback_solution = best_solution; + rollback_k = best_k; } // check the solution