fix: enable rolling back set assignment when all devices are assigned to M4 but no feasible solutions

This commit is contained in:
Lizonghang 2025-06-04 15:11:29 +04:00
parent 6439090920
commit 27756ee182

View file

@ -1094,8 +1094,8 @@ static bool assign_layers_to_device(
};
(void)print_matrix;
std::vector<double> final_solution;
int final_k = -1;
std::vector<double> final_solution, rollback_solution;
int final_k = -1, rollback_k = -1;
// iterative optimization to find a valid set assignment (M1, M2, M3, M4)
while (true) {
@ -1367,19 +1367,48 @@ static bool assign_layers_to_device(
// get the solution
const HighsModelStatus& model_status = highs.getModelStatus();
if (model_status != HighsModelStatus::kOptimal) continue;
if (model_status != HighsModelStatus::kOptimal) {
bool is_all_in_M4 = true;
for (uint32_t m = 0; m < n_world; ++m) {
if (!in_set(m, M4)) {
is_all_in_M4 = false;
break;
}
}
if (!is_all_in_M4) continue;
}
// record the best solution
const HighsSolution& solution = highs.getSolution();
double objective_value = highs.getInfo().objective_function_value;
if (objective_value < best_objective) {
best_objective = objective_value;
best_k = k;
best_solution = solution.col_value;
}
LOG_INF("k = %2d, obj = %7.1f, solution: %s | best_k = %2d, best_obj = %7.1f, best_solution: %s\n",
k, objective_value, vec_to_str(solution.col_value).c_str(), best_k, best_objective, vec_to_str(best_solution).c_str());
if (solution.value_valid) {
if (objective_value < best_objective) {
best_objective = objective_value;
best_k = k;
best_solution = solution.col_value;
}
LOG_INF("k = %2d, obj = %7.1f, solution: %s | best_k = %2d, best_obj = %7.1f, best_solution: %s\n",
k, objective_value, vec_to_str(solution.col_value).c_str(), best_k, best_objective, vec_to_str(best_solution).c_str());
}
}
if (best_solution.empty()) {
LOG_INF("No feasible solution found for this set assignment, rolling back to previous sets.\n");
final_solution = rollback_solution;
final_k = rollback_k;
// update w[m] and n[m]
GGML_ASSERT(final_solution.size() == n_world * 2 && "Invalid solution\n");
std::copy(final_solution.begin(), final_solution.begin() + n_world, w.begin());
std::copy(final_solution.begin() + n_world, final_solution.end(), n.begin());
break;
} else {
rollback_solution = best_solution;
rollback_k = best_k;
}
// check the solution