Skip to content

Commit

Permalink
Implement callbacks integration for gurobi solvers
Browse files Browse the repository at this point in the history
- replace method to implement (for all milp solvers) retrieve_ith_solution() by
  retrieve_current_solution() that take a callable
  get_var_value_for_current_solution(). Then
  - retrieve_ith_solution is retrieve_current_solution with
    get_var_value_for_ith_solution
  - during a gurobi callback, it will be called with model.cbGetSolution
- update milp solvers to reflect that
- create a GurobiCallback class that
   - populates on the fly a result_storage by using
     retrieve_current_solution()
   - calls user-defined d-o callbacks
   - terminate solve if user-defined callbacks decide to early stop
   - catch potential exceptions inside the callback as they are ignored
     by gurobi else (the solve continues even though the error is
     displayed to stderr)
- re-raise in solve() exceptions found in callbacks (useful for
  debugging but also necessary for optuna pruning, based on raising a
  TrialPruned exception)
- add tests using d-o callbacks with gurobi for coloring solver
- the gurobi pickup-vrp solver still does not use callbacks as it is
  overriding GurobiMilpSolver.solve() to perform a loop.
  • Loading branch information
nhuet committed Mar 21, 2024
1 parent d4a1394 commit b7a78d4
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 80 deletions.
10 changes: 7 additions & 3 deletions discrete_optimization/coloring/solvers/coloring_lp_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import logging
import sys
from typing import Any, Dict, Hashable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Hashable, Optional, Tuple, Union

import mip
import networkx as nx
Expand Down Expand Up @@ -99,13 +99,17 @@ def __init__(
self.sense_optim = self.params_objective_function.sense_function
self.start_solution: Optional[ColoringSolution] = None

def retrieve_ith_solution(self, i: int) -> ColoringSolution:
def retrieve_current_solution(
self,
get_var_value_for_current_solution: Callable[[Any], float],
get_obj_value_for_current_solution: Callable[[], float],
) -> ColoringSolution:
colors = [0] * self.number_of_nodes
for (
variable_decision_key,
variable_decision_value,
) in self.variable_decision["colors_var"].items():
value = self.get_var_value_for_ith_solution(variable_decision_value, i)
value = get_var_value_for_current_solution(variable_decision_value)
if value >= 0.5:
node = variable_decision_key[0]
color = variable_decision_key[1]
Expand Down
10 changes: 7 additions & 3 deletions discrete_optimization/facility/solvers/facility_lp_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import mip
import numpy as np
Expand Down Expand Up @@ -136,14 +136,18 @@ def __init__(
}
self.description_constraint: Dict[str, Dict[str, str]] = {}

def retrieve_ith_solution(self, i: int) -> FacilitySolution:
def retrieve_current_solution(
self,
get_var_value_for_current_solution: Callable[[Any], float],
get_obj_value_for_current_solution: Callable[[], float],
) -> FacilitySolution:
facility_for_customer = [0] * self.problem.customer_count
for (
variable_decision_key,
variable_decision_value,
) in self.variable_decision["x"].items():
if not isinstance(variable_decision_value, int):
value = self.get_var_value_for_ith_solution(variable_decision_value, i)
value = get_var_value_for_current_solution(variable_decision_value)
else:
value = variable_decision_value
if value >= 0.5:
Expand Down
136 changes: 127 additions & 9 deletions discrete_optimization/generic_tools/lp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
import logging
from abc import abstractmethod
from enum import Enum
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import mip

from discrete_optimization.generic_tools.callbacks.callback import (
Callback,
CallbackList,
)
from discrete_optimization.generic_tools.do_problem import Solution
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.exceptions import SolveEarlyStop
from discrete_optimization.generic_tools.result_storage.multiobj_utils import (
TupleFitness,
)
Expand All @@ -24,6 +29,7 @@
gurobi_available = False
else:
gurobi_available = True
GRB = gurobipy.GRB

try:
import docplex
Expand Down Expand Up @@ -116,13 +122,55 @@ def retrieve_solutions(
)

@abstractmethod
def retrieve_ith_solution(self, i: int) -> Solution:
"""Retrieve i-th solution from internal milp model."""
def retrieve_current_solution(
self,
get_var_value_for_current_solution: Callable[[Any], float],
get_obj_value_for_current_solution: Callable[[], float],
) -> Solution:
"""Retrieve current solution from internal gurobi solution.
This converts internal gurobi solution into a discrete-optimization Solution.
This method can be called after the solve in `retrieve_solutions()`
or during solve within a gurobi/pymilp/cplex callback. The difference will be the
`get_var_value_for_current_solution` and `get_obj_value_for_current_solution` callables passed.
Args:
get_var_value_for_current_solution: function extracting the value of the given variable for the current solution
will be different when inside a callback or after the solve is finished
get_obj_value_for_current_solution: function extracting the value of the objective for the current solution.
Returns:
the converted solution at d-o format
"""
...

def retrieve_ith_solution(self, i: int) -> Solution:
"""Retrieve i-th solution from internal milp model.
Args:
i:
Returns:
"""
get_var_value_for_current_solution = (
lambda var: self.get_var_value_for_ith_solution(var=var, i=i)
)
get_obj_value_for_current_solution = (
lambda: self.get_obj_value_for_ith_solution(i=i)
)
return self.retrieve_current_solution(
get_var_value_for_current_solution=get_var_value_for_current_solution,
get_obj_value_for_current_solution=get_obj_value_for_current_solution,
)

@abstractmethod
def solve(
self, parameters_milp: Optional[ParametersMilp] = None, **kwargs: Any
self,
callbacks: Optional[List[Callback]] = None,
parameters_milp: Optional[ParametersMilp] = None,
**kwargs: Any,
) -> ResultStorage:
...

Expand Down Expand Up @@ -219,14 +267,46 @@ class GurobiMilpSolver(MilpSolver):
"""Milp solver wrapping a solver from gurobi library."""

model: Optional["gurobipy.Model"] = None
early_stopping_exception: Optional[Exception] = None

def solve(
self, parameters_milp: Optional[ParametersMilp] = None, **kwargs: Any
self,
callbacks: Optional[List[Callback]] = None,
parameters_milp: Optional[ParametersMilp] = None,
**kwargs: Any,
) -> ResultStorage:
self.early_stopping_exception = None
callbacks_list = CallbackList(callbacks=callbacks)
if parameters_milp is None:
parameters_milp = ParametersMilp.default()
self.optimize_model(parameters_milp=parameters_milp, **kwargs)
return self.retrieve_solutions(parameters_milp=parameters_milp)

# callback: solve start
callbacks_list.on_solve_start(solver=self)

if self.model is None:
self.init_model(**kwargs)
if self.model is None:
raise RuntimeError(
"self.model must not be None after self.init_model()."
)
self.prepare_model(parameters_milp=parameters_milp, **kwargs)

# wrap user callback in a gurobi callback
gurobi_callback = GurobiCallback(do_solver=self, callback=callbacks_list)
self.model.optimize(gurobi_callback)
# raise potential exception found during callback (useful for optuna pruning, and debugging)
if self.early_stopping_exception:
if isinstance(self.early_stopping_exception, SolveEarlyStop):
logger.info(self.early_stopping_exception)
else:
raise self.early_stopping_exception
# get result storage
res = gurobi_callback.res

# callback: solve end
callbacks_list.on_solve_end(res=res, solver=self)

return res

def prepare_model(
self, parameters_milp: Optional[ParametersMilp] = None, **kwargs: Any
Expand Down Expand Up @@ -254,6 +334,7 @@ def optimize_model(
"""Optimize the Gurobi Model.
The solutions are yet to be retrieved via `self.retrieve_solutions()`.
No callbacks are passed to the internal solver, and no result_storage is created
"""
if self.model is None:
Expand Down Expand Up @@ -296,8 +377,45 @@ def nb_solutions(self) -> int:
return self.model.SolCount


def gurobi_callback(model, where):
...
class GurobiCallback:
def __init__(self, do_solver: GurobiMilpSolver, callback: Callback):
self.do_solver = do_solver
self.callback = callback
self.res = ResultStorage(
[],
mode_optim=self.do_solver.params_objective_function.sense_function,
limit_store=False,
)
self.nb_solutions = 0

def __call__(self, model, where) -> None:
if where == GRB.Callback.MIPSOL:
try:
# retrieve and store new solution
sol = self.do_solver.retrieve_current_solution(
get_var_value_for_current_solution=model.cbGetSolution,
get_obj_value_for_current_solution=lambda: model.cbGet(
GRB.Callback.MIPSOL_OBJ
),
)
fit = self.do_solver.aggreg_from_sol(sol)
self.res.add_solution(solution=sol, fitness=fit)
self.nb_solutions += 1
# end of step callback: stopping?
stopping = self.callback.on_step_end(
step=self.nb_solutions, res=self.res, solver=self.do_solver
)
except Exception as e:
# catch exceptions because gurobi ignore them and do not stop solving
self.do_solver.early_stopping_exception = e
stopping = True
else:
if stopping:
self.do_solver.early_stopping_exception = SolveEarlyStop(
f"{self.do_solver.__class__.__name__}.solve() stopped by user callback."
)
if stopping:
model.terminate()


class CplexMilpSolver(MilpSolver):
Expand Down
10 changes: 7 additions & 3 deletions discrete_optimization/knapsack/solvers/lp_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import mip
from mip import BINARY, MAXIMIZE, xsum
Expand Down Expand Up @@ -57,15 +57,19 @@ def __init__(
self.description_variable_description: Dict[str, Dict[str, Any]] = {}
self.description_constraint: Dict[str, Dict[str, str]] = {}

def retrieve_ith_solution(self, i: int) -> KnapsackSolution:
def retrieve_current_solution(
self,
get_var_value_for_current_solution: Callable[[Any], float],
get_obj_value_for_current_solution: Callable[[], float],
) -> KnapsackSolution:
weight = 0.0
value_kp = 0.0
xs = {}
for (
variable_decision_key,
variable_decision_value,
) in self.variable_decision["x"].items():
value = self.get_var_value_for_ith_solution(variable_decision_value, i)
value = get_var_value_for_current_solution(variable_decision_value)
if value <= 0.1:
xs[variable_decision_key] = 0
continue
Expand Down
Loading

0 comments on commit b7a78d4

Please sign in to comment.