From b7a78d4ecc9799eba8ee00c74188787a839eb5e7 Mon Sep 17 00:00:00 2001 From: Nolwen Date: Thu, 21 Mar 2024 14:24:31 +0100 Subject: [PATCH] Implement callbacks integration for gurobi solvers - replace method to implement (for all milp solvers) retrieve_ith_solution() by retrieve_current_solution() that take a callable get_var_value_for_current_solution(). Then - retrieve_ith_solution is retrieve_current_solution with get_var_value_for_ith_solution - during a gurobi callback, it will be called with model.cbGetSolution - update milp solvers to reflect that - create a GurobiCallback class that - populates on the fly a result_storage by using retrieve_current_solution() - calls user-defined d-o callbacks - terminate solve if user-defined callbacks decide to early stop - catch potential exceptions inside the callback as they are ignored by gurobi else (the solve continues even though the error is displayed to stderr) - re-raise in solve() exceptions found in callbacks (useful for debugging but also necessary for optuna pruning, based on raising a TrialPruned exception) - add tests using d-o callbacks with gurobi for coloring solver - the gurobi pickup-vrp solver still does not use callbacks as it is overriding GurobiMilpSolver.solve() to perform a loop. --- .../coloring/solvers/coloring_lp_solvers.py | 10 +- .../facility/solvers/facility_lp_solver.py | 10 +- .../generic_tools/lp_tools.py | 136 ++++++++++++++++-- .../knapsack/solvers/lp_solvers.py | 10 +- .../pickup_vrp/solver/lp_solver.py | 72 ++++++---- .../pickup_vrp/solver/lp_solver_pymip.py | 58 ++++++-- .../rcpsp/solver/rcpsp_lp_solver.py | 18 ++- .../rcpsp/solver/rcpsp_lp_solver_gantt.py | 23 +-- .../rcpsp_multiskill/solvers/lp_model.py | 24 ++-- tests/coloring/test_coloring.py | 60 ++++++++ 10 files changed, 341 insertions(+), 80 deletions(-) diff --git a/discrete_optimization/coloring/solvers/coloring_lp_solvers.py b/discrete_optimization/coloring/solvers/coloring_lp_solvers.py index 1b89c913..2a90124c 100644 --- a/discrete_optimization/coloring/solvers/coloring_lp_solvers.py +++ b/discrete_optimization/coloring/solvers/coloring_lp_solvers.py @@ -6,7 +6,7 @@ import logging import sys -from typing import Any, Dict, Hashable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Hashable, Optional, Tuple, Union import mip import networkx as nx @@ -99,13 +99,17 @@ def __init__( self.sense_optim = self.params_objective_function.sense_function self.start_solution: Optional[ColoringSolution] = None - def retrieve_ith_solution(self, i: int) -> ColoringSolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> ColoringSolution: colors = [0] * self.number_of_nodes for ( variable_decision_key, variable_decision_value, ) in self.variable_decision["colors_var"].items(): - value = self.get_var_value_for_ith_solution(variable_decision_value, i) + value = get_var_value_for_current_solution(variable_decision_value) if value >= 0.5: node = variable_decision_key[0] color = variable_decision_key[1] diff --git a/discrete_optimization/facility/solvers/facility_lp_solver.py b/discrete_optimization/facility/solvers/facility_lp_solver.py index 644558d3..60abf2b4 100644 --- a/discrete_optimization/facility/solvers/facility_lp_solver.py +++ b/discrete_optimization/facility/solvers/facility_lp_solver.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import mip import numpy as np @@ -136,14 +136,18 @@ def __init__( } self.description_constraint: Dict[str, Dict[str, str]] = {} - def retrieve_ith_solution(self, i: int) -> FacilitySolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> FacilitySolution: facility_for_customer = [0] * self.problem.customer_count for ( variable_decision_key, variable_decision_value, ) in self.variable_decision["x"].items(): if not isinstance(variable_decision_value, int): - value = self.get_var_value_for_ith_solution(variable_decision_value, i) + value = get_var_value_for_current_solution(variable_decision_value) else: value = variable_decision_value if value >= 0.5: diff --git a/discrete_optimization/generic_tools/lp_tools.py b/discrete_optimization/generic_tools/lp_tools.py index 3bdacd05..89a2dc8b 100644 --- a/discrete_optimization/generic_tools/lp_tools.py +++ b/discrete_optimization/generic_tools/lp_tools.py @@ -5,12 +5,17 @@ import logging from abc import abstractmethod from enum import Enum -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import mip +from discrete_optimization.generic_tools.callbacks.callback import ( + Callback, + CallbackList, +) from discrete_optimization.generic_tools.do_problem import Solution from discrete_optimization.generic_tools.do_solver import SolverDO +from discrete_optimization.generic_tools.exceptions import SolveEarlyStop from discrete_optimization.generic_tools.result_storage.multiobj_utils import ( TupleFitness, ) @@ -24,6 +29,7 @@ gurobi_available = False else: gurobi_available = True + GRB = gurobipy.GRB try: import docplex @@ -116,13 +122,55 @@ def retrieve_solutions( ) @abstractmethod - def retrieve_ith_solution(self, i: int) -> Solution: - """Retrieve i-th solution from internal milp model.""" + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> Solution: + """Retrieve current solution from internal gurobi solution. + + This converts internal gurobi solution into a discrete-optimization Solution. + This method can be called after the solve in `retrieve_solutions()` + or during solve within a gurobi/pymilp/cplex callback. The difference will be the + `get_var_value_for_current_solution` and `get_obj_value_for_current_solution` callables passed. + + Args: + get_var_value_for_current_solution: function extracting the value of the given variable for the current solution + will be different when inside a callback or after the solve is finished + get_obj_value_for_current_solution: function extracting the value of the objective for the current solution. + + Returns: + the converted solution at d-o format + + """ ... + def retrieve_ith_solution(self, i: int) -> Solution: + """Retrieve i-th solution from internal milp model. + + Args: + i: + + Returns: + + """ + get_var_value_for_current_solution = ( + lambda var: self.get_var_value_for_ith_solution(var=var, i=i) + ) + get_obj_value_for_current_solution = ( + lambda: self.get_obj_value_for_ith_solution(i=i) + ) + return self.retrieve_current_solution( + get_var_value_for_current_solution=get_var_value_for_current_solution, + get_obj_value_for_current_solution=get_obj_value_for_current_solution, + ) + @abstractmethod def solve( - self, parameters_milp: Optional[ParametersMilp] = None, **kwargs: Any + self, + callbacks: Optional[List[Callback]] = None, + parameters_milp: Optional[ParametersMilp] = None, + **kwargs: Any, ) -> ResultStorage: ... @@ -219,14 +267,46 @@ class GurobiMilpSolver(MilpSolver): """Milp solver wrapping a solver from gurobi library.""" model: Optional["gurobipy.Model"] = None + early_stopping_exception: Optional[Exception] = None def solve( - self, parameters_milp: Optional[ParametersMilp] = None, **kwargs: Any + self, + callbacks: Optional[List[Callback]] = None, + parameters_milp: Optional[ParametersMilp] = None, + **kwargs: Any, ) -> ResultStorage: + self.early_stopping_exception = None + callbacks_list = CallbackList(callbacks=callbacks) if parameters_milp is None: parameters_milp = ParametersMilp.default() - self.optimize_model(parameters_milp=parameters_milp, **kwargs) - return self.retrieve_solutions(parameters_milp=parameters_milp) + + # callback: solve start + callbacks_list.on_solve_start(solver=self) + + if self.model is None: + self.init_model(**kwargs) + if self.model is None: + raise RuntimeError( + "self.model must not be None after self.init_model()." + ) + self.prepare_model(parameters_milp=parameters_milp, **kwargs) + + # wrap user callback in a gurobi callback + gurobi_callback = GurobiCallback(do_solver=self, callback=callbacks_list) + self.model.optimize(gurobi_callback) + # raise potential exception found during callback (useful for optuna pruning, and debugging) + if self.early_stopping_exception: + if isinstance(self.early_stopping_exception, SolveEarlyStop): + logger.info(self.early_stopping_exception) + else: + raise self.early_stopping_exception + # get result storage + res = gurobi_callback.res + + # callback: solve end + callbacks_list.on_solve_end(res=res, solver=self) + + return res def prepare_model( self, parameters_milp: Optional[ParametersMilp] = None, **kwargs: Any @@ -254,6 +334,7 @@ def optimize_model( """Optimize the Gurobi Model. The solutions are yet to be retrieved via `self.retrieve_solutions()`. + No callbacks are passed to the internal solver, and no result_storage is created """ if self.model is None: @@ -296,8 +377,45 @@ def nb_solutions(self) -> int: return self.model.SolCount -def gurobi_callback(model, where): - ... +class GurobiCallback: + def __init__(self, do_solver: GurobiMilpSolver, callback: Callback): + self.do_solver = do_solver + self.callback = callback + self.res = ResultStorage( + [], + mode_optim=self.do_solver.params_objective_function.sense_function, + limit_store=False, + ) + self.nb_solutions = 0 + + def __call__(self, model, where) -> None: + if where == GRB.Callback.MIPSOL: + try: + # retrieve and store new solution + sol = self.do_solver.retrieve_current_solution( + get_var_value_for_current_solution=model.cbGetSolution, + get_obj_value_for_current_solution=lambda: model.cbGet( + GRB.Callback.MIPSOL_OBJ + ), + ) + fit = self.do_solver.aggreg_from_sol(sol) + self.res.add_solution(solution=sol, fitness=fit) + self.nb_solutions += 1 + # end of step callback: stopping? + stopping = self.callback.on_step_end( + step=self.nb_solutions, res=self.res, solver=self.do_solver + ) + except Exception as e: + # catch exceptions because gurobi ignore them and do not stop solving + self.do_solver.early_stopping_exception = e + stopping = True + else: + if stopping: + self.do_solver.early_stopping_exception = SolveEarlyStop( + f"{self.do_solver.__class__.__name__}.solve() stopped by user callback." + ) + if stopping: + model.terminate() class CplexMilpSolver(MilpSolver): diff --git a/discrete_optimization/knapsack/solvers/lp_solvers.py b/discrete_optimization/knapsack/solvers/lp_solvers.py index 1c59fa8f..4dd0407a 100644 --- a/discrete_optimization/knapsack/solvers/lp_solvers.py +++ b/discrete_optimization/knapsack/solvers/lp_solvers.py @@ -3,7 +3,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import mip from mip import BINARY, MAXIMIZE, xsum @@ -57,7 +57,11 @@ def __init__( self.description_variable_description: Dict[str, Dict[str, Any]] = {} self.description_constraint: Dict[str, Dict[str, str]] = {} - def retrieve_ith_solution(self, i: int) -> KnapsackSolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> KnapsackSolution: weight = 0.0 value_kp = 0.0 xs = {} @@ -65,7 +69,7 @@ def retrieve_ith_solution(self, i: int) -> KnapsackSolution: variable_decision_key, variable_decision_value, ) in self.variable_decision["x"].items(): - value = self.get_var_value_for_ith_solution(variable_decision_value, i) + value = get_var_value_for_current_solution(variable_decision_value) if value <= 0.1: xs[variable_decision_key] = 0 continue diff --git a/discrete_optimization/pickup_vrp/solver/lp_solver.py b/discrete_optimization/pickup_vrp/solver/lp_solver.py index dfa2628a..cf845b7f 100644 --- a/discrete_optimization/pickup_vrp/solver/lp_solver.py +++ b/discrete_optimization/pickup_vrp/solver/lp_solver.py @@ -105,18 +105,21 @@ def convert_temporaryresult_to_gpdpsolution( ) -def retrieve_ith_solution( - i: int, model: "grb.Model", variable_decisions: Dict[str, Any] +def retrieve_current_solution( + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + variable_decisions: Dict[str, Any], ) -> Tuple[Dict[str, Dict[Hashable, Any]], float]: results: Dict[str, Dict[Hashable, Any]] = {} xsolution: Dict[int, Dict[Edge, int]] = { v: {} for v in variable_decisions["variables_edges"] } - model.params.SolutionNumber = i - obj = model.getAttr("PoolObjVal") + obj = get_obj_value_for_current_solution() for vehicle in variable_decisions["variables_edges"]: for edge in variable_decisions["variables_edges"][vehicle]: - value = variable_decisions["variables_edges"][vehicle][edge].getAttr("Xn") + value = get_var_value_for_current_solution( + variable_decisions["variables_edges"][vehicle][edge] + ) if value <= 0.1: continue xsolution[vehicle][edge] = 1 @@ -132,31 +135,23 @@ def retrieve_ith_solution( if isinstance(variable_decisions[key][key_2][key_3], dict): results[key][key_2][key_3] = {} for key_4 in variable_decisions[key][key_2][key_3]: - value = variable_decisions[key][key_2][key_3].getAttr("Xn") + value = get_var_value_for_current_solution( + variable_decisions[key][key_2][key_3] + ) results[key][key_2][key_3][key_4] = value else: - value = variable_decisions[key][key_2][key_3].getAttr("Xn") + value = get_var_value_for_current_solution( + variable_decisions[key][key_2][key_3] + ) results[key][key_2][key_3] = value else: - value = variable_decisions[key][key_2].getAttr("Xn") + value = get_var_value_for_current_solution( + variable_decisions[key][key_2] + ) results[key][key_2] = value return results, obj -def retrieve_solutions( - model: "grb.Model", variable_decisions: Dict[str, Any] -) -> List[Tuple[Dict[str, Dict[Hashable, Any]], float]]: - nSolutions = model.SolCount - range_solutions = range(nSolutions) - list_results = [] - for s in range_solutions: - results, obj = retrieve_ith_solution( - s, model=model, variable_decisions=variable_decisions - ) - list_results += [(results, obj)] - return list_results - - class LinearFlowSolver(GurobiMilpSolver, SolverPickupVrp): problem: GPDP @@ -808,8 +803,26 @@ def init_model(self, **kwargs: Any) -> None: self.tsp_version = one_visit_per_node def retrieve_ith_temporaryresult(self, i: int) -> TemporaryResult: - res, obj = retrieve_ith_solution( - i=i, model=self.model, variable_decisions=self.variable_decisions + get_var_value_for_current_solution = ( + lambda var: self.get_var_value_for_ith_solution(var=var, i=i) + ) + get_obj_value_for_current_solution = ( + lambda: self.get_obj_value_for_ith_solution(i=i) + ) + return self.retrieve_current_temporaryresult( + get_var_value_for_current_solution=get_var_value_for_current_solution, + get_obj_value_for_current_solution=get_obj_value_for_current_solution, + ) + + def retrieve_current_temporaryresult( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> TemporaryResult: + res, obj = retrieve_current_solution( + get_var_value_for_current_solution=get_var_value_for_current_solution, + get_obj_value_for_current_solution=get_obj_value_for_current_solution, + variable_decisions=self.variable_decisions, ) if self.problem.graph is None: raise RuntimeError( @@ -828,14 +841,21 @@ def retrieve_solutions( kwargs["limit_store"] = False return super().retrieve_solutions(parameters_milp=parameters_milp, **kwargs) - def retrieve_ith_solution(self, i: int) -> GPDPSolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> GPDPSolution: """ Not used here as GurobiMilpSolver.solve() is overriden """ - temporaryresult = self.retrieve_ith_temporaryresult(i=i) + temporaryresult = self.retrieve_current_temporaryresult( + get_var_value_for_current_solution=get_var_value_for_current_solution, + get_obj_value_for_current_solution=get_obj_value_for_current_solution, + ) return convert_temporaryresult_to_gpdpsolution( temporaryresult=temporaryresult, problem=self.problem ) diff --git a/discrete_optimization/pickup_vrp/solver/lp_solver_pymip.py b/discrete_optimization/pickup_vrp/solver/lp_solver_pymip.py index 75893e91..19748aba 100644 --- a/discrete_optimization/pickup_vrp/solver/lp_solver_pymip.py +++ b/discrete_optimization/pickup_vrp/solver/lp_solver_pymip.py @@ -6,6 +6,7 @@ import random from typing import ( Any, + Callable, Dict, Hashable, Iterable, @@ -53,17 +54,21 @@ def __repr__(self) -> str: return self.message -def retrieve_ith_solution( - i: int, model: mip.Model, variable_decisions: Dict[str, Any] +def retrieve_current_solution( + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + variable_decisions: Dict[str, Any], ) -> Tuple[Dict[str, Dict[Hashable, Any]], float]: results: Dict[str, Dict[Hashable, Any]] = {} xsolution: Dict[int, Dict[Edge, int]] = { v: {} for v in variable_decisions["variables_edges"] } - obj = float(model.objective_values[i]) + obj = float(get_obj_value_for_current_solution()) for vehicle in variable_decisions["variables_edges"]: for edge in variable_decisions["variables_edges"][vehicle]: - value = variable_decisions["variables_edges"][vehicle][edge].xi(i) + value = get_var_value_for_current_solution( + variable_decisions["variables_edges"][vehicle][edge] + ) if value <= 0.1: continue xsolution[vehicle][edge] = 1 @@ -79,13 +84,19 @@ def retrieve_ith_solution( if isinstance(variable_decisions[key][key_2][key_3], dict): results[key][key_2][key_3] = {} for key_4 in variable_decisions[key][key_2][key_3]: - value = variable_decisions[key][key_2][key_3].xi(i) + value = get_var_value_for_current_solution( + variable_decisions[key][key_2][key_3] + ) results[key][key_2][key_3][key_4] = value else: - value = variable_decisions[key][key_2][key_3].xi(i) + value = get_var_value_for_current_solution( + variable_decisions[key][key_2][key_3] + ) results[key][key_2][key_3] = value else: - value = variable_decisions[key][key_2].xi(i) + value = get_var_value_for_current_solution( + variable_decisions[key][key_2] + ) results[key][key_2] = value return results, obj @@ -700,21 +711,46 @@ def retrieve_solutions( kwargs["limit_store"] = False return super().retrieve_solutions(parameters_milp=parameters_milp, **kwargs) - def retrieve_ith_solution(self, i: int) -> GPDPSolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> Solution: """ Not used here as GurobiMilpSolver.solve() is overriden """ - temporaryresult = self.retrieve_ith_temporaryresult(i=i) + temporaryresult = self.retrieve_current_temporaryresult( + get_var_value_for_current_solution=get_var_value_for_current_solution, + get_obj_value_for_current_solution=get_obj_value_for_current_solution, + ) return convert_temporaryresult_to_gpdpsolution( temporaryresult=temporaryresult, problem=self.problem ) def retrieve_ith_temporaryresult(self, i: int) -> TemporaryResult: - res, obj = retrieve_ith_solution( - i=i, model=self.model, variable_decisions=self.variable_decisions + get_var_value_for_current_solution = ( + lambda var: self.get_var_value_for_ith_solution(var=var, i=i) + ) + get_obj_value_for_current_solution = ( + lambda: self.get_obj_value_for_ith_solution(i=i) + ) + return self.retrieve_current_temporaryresult( + get_var_value_for_current_solution=get_var_value_for_current_solution, + get_obj_value_for_current_solution=get_obj_value_for_current_solution, + ) + + def retrieve_current_temporaryresult( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> TemporaryResult: + res, obj = retrieve_current_solution( + get_var_value_for_current_solution=get_var_value_for_current_solution, + get_obj_value_for_current_solution=get_obj_value_for_current_solution, + variable_decisions=self.variable_decisions, ) if self.problem.graph is None: raise RuntimeError( diff --git a/discrete_optimization/rcpsp/solver/rcpsp_lp_solver.py b/discrete_optimization/rcpsp/solver/rcpsp_lp_solver.py index 7c131685..a5ab7a39 100644 --- a/discrete_optimization/rcpsp/solver/rcpsp_lp_solver.py +++ b/discrete_optimization/rcpsp/solver/rcpsp_lp_solver.py @@ -4,7 +4,7 @@ import logging from itertools import product -from typing import Any, Dict, Hashable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union from mip import BINARY, INTEGER, MINIMIZE, Model, Var, xsum @@ -258,10 +258,14 @@ def init_model(self, **args): ] self.constraints_partial_solutions = constraints - def retrieve_ith_solution(self, i: int) -> RCPSPSolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> RCPSPSolution: rcpsp_schedule = {} for (task_index, time) in product(self.index_task, self.index_time): - value = self.get_var_value_for_ith_solution(self.x[task_index][time], i) + value = get_var_value_for_current_solution(self.x[task_index][time]) if value >= 0.5: task = self.problem.tasks_list[task_index] rcpsp_schedule[task] = { @@ -292,11 +296,15 @@ def __init__( self.variable_decision = {} self.constraints_dict = {"lns": []} - def retrieve_ith_solution(self, i: int) -> RCPSPSolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> RCPSPSolution: rcpsp_schedule = {} modes: Dict[Hashable, Union[str, int]] = {} for (task, mode, t), x in self.x.items(): - value = self.get_var_value_for_ith_solution(x, i) + value = get_var_value_for_current_solution(x) if value >= 0.5: rcpsp_schedule[task] = { "start_time": t, diff --git a/discrete_optimization/rcpsp/solver/rcpsp_lp_solver_gantt.py b/discrete_optimization/rcpsp/solver/rcpsp_lp_solver_gantt.py index 2b05e3ef..270068e6 100644 --- a/discrete_optimization/rcpsp/solver/rcpsp_lp_solver_gantt.py +++ b/discrete_optimization/rcpsp/solver/rcpsp_lp_solver_gantt.py @@ -4,7 +4,7 @@ import logging import random -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import networkx as nx from mip import BINARY, MINIMIZE, Model, xsum @@ -12,6 +12,7 @@ from discrete_optimization.generic_tools.do_problem import ( ModeOptim, ParamsObjectiveFunction, + Solution, ) from discrete_optimization.generic_tools.lp_tools import ( GurobiMilpSolver, @@ -126,14 +127,16 @@ def __init__( self.params_objective_function.sense_function = self.sense_optim self.constraint_additionnal = {} - def retrieve_ith_solution( - self, i: int + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], ) -> Tuple[Dict[Any, Dict[Any, Dict[Any, Any]]], float]: - objective = self.get_obj_value_for_ith_solution(i) + objective = get_obj_value_for_current_solution() resource_id_usage = { k: { individual: { - task: self.get_var_value_for_ith_solution(resource_usage, i) + task: get_var_value_for_current_solution(resource_usage) for task, resource_usage in self.ressource_id_usage[k][ individual ].items() @@ -398,14 +401,16 @@ def adding_constraint( ] self.model.update() - def retrieve_ith_solution( - self, i: int + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], ) -> Tuple[Dict[Any, Dict[Any, Dict[Any, Any]]], float]: - objective = self.get_pool_obj_value_for_ith_solution(i) + objective = get_obj_value_for_current_solution() resource_id_usage = { k: { individual: { - task: self.get_var_value_for_ith_solution(resource_usage, i) + task: get_var_value_for_current_solution(resource_usage) for task, resource_usage in self.ressource_id_usage[k][ individual ].items() diff --git a/discrete_optimization/rcpsp_multiskill/solvers/lp_model.py b/discrete_optimization/rcpsp_multiskill/solvers/lp_model.py index d72cdae3..eee0b167 100644 --- a/discrete_optimization/rcpsp_multiskill/solvers/lp_model.py +++ b/discrete_optimization/rcpsp_multiskill/solvers/lp_model.py @@ -5,7 +5,7 @@ import logging import time from itertools import product -from typing import Optional +from typing import Any, Callable, Optional from mip import BINARY, INTEGER, MINIMIZE, Model, xsum @@ -277,7 +277,11 @@ def init_model(self, **args): ) self.model.objective = self.start_times_task[max(self.start_times_task)] - def retrieve_ith_solution(self, i: int) -> MS_RCPSPSolution: + def retrieve_current_solution( + self, + get_var_value_for_current_solution: Callable[[Any], float], + get_obj_value_for_current_solution: Callable[[], float], + ) -> MS_RCPSPSolution: rcpsp_schedule = {} modes = {} results = {} @@ -286,7 +290,7 @@ def retrieve_ith_solution(self, i: int) -> MS_RCPSPSolution: for task in self.start_times: for mode in self.start_times[task]: for t, start_time in self.start_times[task][mode].items(): - value = self.get_var_value_for_ith_solution(start_time, i) + value = get_var_value_for_current_solution(start_time) results[(task, mode, t)] = value if value >= 0.5: rcpsp_schedule[task] = { @@ -297,8 +301,8 @@ def retrieve_ith_solution(self, i: int) -> MS_RCPSPSolution: } modes[task] = mode for t in self.employee_usage: - employee_usage[t] = self.get_var_value_for_ith_solution( - self.employee_usage[t], i + employee_usage[t] = get_var_value_for_current_solution( + self.employee_usage[t] ) if employee_usage[t] >= 0.5: if t[1] not in employee_usage_solution: @@ -312,20 +316,18 @@ def retrieve_ith_solution(self, i: int) -> MS_RCPSPSolution: modes_task = {} for t in self.modes: for m, mode in self.modes[t].items(): - modes[(t, m)] = self.get_var_value_for_ith_solution(mode, i) + modes[(t, m)] = get_var_value_for_current_solution(mode) if modes[(t, m)] >= 0.5: modes_task[t] = m durations = {} for t in self.durations: - durations[t] = self.get_var_value_for_ith_solution(self.durations[t], i) + durations[t] = get_var_value_for_current_solution(self.durations[t]) start_time = {} for t in self.start_times_task: - start_time[t] = self.get_var_value_for_ith_solution( - self.start_times_task[t], i - ) + start_time[t] = get_var_value_for_current_solution(self.start_times_task[t]) end_time = {} for t in self.start_times_task: - end_time[t] = self.get_var_value_for_ith_solution(self.end_times_task[t], i) + end_time[t] = get_var_value_for_current_solution(self.end_times_task[t]) logger.debug(f"Size schedule : {len(rcpsp_schedule.keys())}") logger.debug( ( diff --git a/tests/coloring/test_coloring.py b/tests/coloring/test_coloring.py index 4422a654..80ccf756 100644 --- a/tests/coloring/test_coloring.py +++ b/tests/coloring/test_coloring.py @@ -31,6 +31,11 @@ GreedyColoring, NXGreedyColoringMethod, ) +from discrete_optimization.generic_tools.callbacks.callback import Callback +from discrete_optimization.generic_tools.callbacks.early_stoppers import ( + NbIterationStopper, +) +from discrete_optimization.generic_tools.callbacks.loggers import NbIterationTracker from discrete_optimization.generic_tools.cp_tools import ParametersCP from discrete_optimization.generic_tools.do_problem import ( ObjectiveHandling, @@ -271,5 +276,60 @@ def test_color_lp_gurobi(): assert len(result_store.list_solution_fits) > 1 +@pytest.mark.skipif(not gurobi_available, reason="You need Gurobi to test this solver.") +def test_color_lp_gurobi_cb_log(): + file = [f for f in get_data_available() if "gc_70_1" in f][0] + color_problem = parse_file(file) + solver = ColoringLP( + color_problem, + params_objective_function=get_default_objective_setup(color_problem), + ) + tracker = NbIterationTracker() + callbacks = [tracker] + result_store = solver.solve( + parameters_milp=ParametersMilp.default(), callbacks=callbacks + ) + solution = result_store.get_best_solution_fit()[0] + assert len(result_store.list_solution_fits) > 1 + # check tracker called at each solution found + assert tracker.nb_iteration == len(result_store.list_solution_fits) + + +@pytest.mark.skipif(not gurobi_available, reason="You need Gurobi to test this solver.") +def test_color_lp_gurobi_cb_stop(): + file = [f for f in get_data_available() if "gc_70_1" in f][0] + color_problem = parse_file(file) + solver = ColoringLP( + color_problem, + params_objective_function=get_default_objective_setup(color_problem), + ) + stopper = NbIterationStopper(nb_iteration_max=1) + callbacks = [stopper] + result_store = solver.solve( + parameters_milp=ParametersMilp.default(), callbacks=callbacks + ) + # check stop after 1st iteration + assert len(result_store.list_solution_fits) == 1 + + +class MyCallbackNok(Callback): + def on_step_end(self, step: int, res, solver): + raise RuntimeError("Explicit crash") + + +@pytest.mark.skipif(not gurobi_available, reason="You need Gurobi to test this solver.") +def test_color_lp_gurobi_cb_exception(): + file = [f for f in get_data_available() if "gc_70_1" in f][0] + color_problem = parse_file(file) + solver = ColoringLP( + color_problem, + params_objective_function=get_default_objective_setup(color_problem), + ) + with pytest.raises(RuntimeError, match="Explicit crash"): + solver.solve( + parameters_milp=ParametersMilp.default(), callbacks=[MyCallbackNok()] + ) + + if __name__ == "__main__": test_solvers()