Skip to content

Commit

Permalink
Added new modeling variants; added an Optuna report callback without trial pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
g-poveda committed Mar 7, 2024
1 parent d5575d0 commit 2166d68
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
19 changes: 19 additions & 0 deletions discrete_optimization/coloring/solvers/coloring_cpsat_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class ColoringCPSatSolver(OrtoolsCPSatSolver, SolverColoringWithStartingSolution
hyperparameters = [
EnumHyperparameter(name="modeling", enum=ModelingCPSat),
CategoricalHyperparameter(name="warmstart", choices=[True, False]),
CategoricalHyperparameter(name="value_sequence_chain", choices=[True, False]),
CategoricalHyperparameter(name="used_variable", choices=[True, False]),
CategoricalHyperparameter(name="symmetry_on_used", choices=[True, False]),
] + SolverColoringWithStartingSolution.hyperparameters

def __init__(
Expand Down Expand Up @@ -112,6 +115,8 @@ def init_model_binary(self, nb_colors: int, **kwargs):

def init_model_integer(self, nb_colors: int, **kwargs):
used_variable = kwargs.get("used_variable", False)
value_sequence_chain = kwargs.get("value_sequence_chain", False)
symmetry_on_used = kwargs.get("symmetry_on_used", True)
cp_model = CpModel()
variables = [
cp_model.NewIntVar(0, nb_colors - 1, name=f"c_{i}")
Expand All @@ -128,6 +133,17 @@ def init_model_integer(self, nb_colors: int, **kwargs):
variables[ind]
== self.problem.constraints_coloring.color_constraint[node]
)
if value_sequence_chain:
vars = [variables[i] for i in self.problem.index_subset_nodes]
sliding_max = [
cp_model.NewIntVar(0, min(i, nb_colors), name=f"m_{i}")
for i in range(len(vars))
]
cp_model.Add(vars[0] == sliding_max[0])
self.variables["sliding_max"] = sliding_max
for k in range(1, len(vars)):
cp_model.AddMaxEquality(sliding_max[k], [sliding_max[k - 1], vars[k]])
cp_model.Add(sliding_max[k] <= sliding_max[k - 1] + 1)
used = [cp_model.NewBoolVar(name=f"used_{c}") for c in range(nb_colors)]
if used_variable:

Expand All @@ -147,6 +163,9 @@ def add_indicator(vars, value, presence_value, model):
else:
vars = variables
add_indicator(vars, j, used[j], cp_model)
if symmetry_on_used:
for j in range(nb_colors - 1):
cp_model.Add(used[j] >= used[j + 1])
cp_model.Minimize(sum(used))
else:
nbc = cp_model.NewIntVar(0, nb_colors, name="nbcolors")
Expand Down
42 changes: 42 additions & 0 deletions discrete_optimization/generic_tools/callbacks/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,48 @@
logger.warning("You should install optuna to use callbacks for optuna.")


class OptunaReportSingleFitCallback(Callback):
    """Report intermediary objective values to an Optuna trial.

    Mostly useful for visualisation purposes. Adapted to single-objective
    optimization (``res.best_fit`` is a float).

    Args:
        trial:
            A :class:`optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        optuna_report_nb_steps: report intermediate result every `optuna_report_nb_steps` steps
            when the number of iterations is high, setting this to 1 could slow too much run of a single trial
    """

    def __init__(
        self, trial: optuna.trial.Trial, optuna_report_nb_steps: int = 1, **kwargs
    ) -> None:
        self.trial = trial
        self.report_nb_steps = optuna_report_nb_steps

    def on_step_end(
        self, step: int, res: ResultStorage, solver: SolverDO
    ) -> Optional[bool]:
        """Called at the end of an optimization step.

        Args:
            step: index of the step
            res: current result storage
            solver: solver using the callback

        Returns:
            If `True`, the optimization process is stopped, else it goes on.
        """
        # Only report every `report_nb_steps` steps to keep the overhead low.
        if step % self.report_nb_steps != 0:
            return None

        # Push the current best fitness and step index to Optuna's trial.
        best_fit = res.best_fit
        self.trial.report(float(best_fit), step=step)


class OptunaPruningSingleFitCallback(Callback):
"""Callback to prune unpromising trials during Optuna hyperparameters tuning.
Expand Down
10 changes: 8 additions & 2 deletions examples/coloring/coloring_cpspat_solver_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@

def run_cpsat_coloring():
logging.basicConfig(level=logging.INFO)
file = [f for f in get_data_available() if "gc_100_7" in f][0]
file = [f for f in get_data_available() if "gc_250_3" in f][0]
color_problem = parse_file(file)
solver = ColoringCPSatSolver(color_problem, params_objective_function=None)
solver.init_model(modeling=ModelingCPSat.BINARY, warmstart=True)
solver.init_model(
modeling=ModelingCPSat.INTEGER,
warmstart=False,
value_sequence_chain=False,
used_variable=True,
symmetry_on_used=True,
)
p = ParametersCP.default_cpsat()
p.time_limit = 100
logging.info("Starting solve")
Expand Down
8 changes: 5 additions & 3 deletions examples/coloring/optuna_full_example_coloring.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from discrete_optimization.generic_tools.callbacks.optuna import (
OptunaPruningSingleFitCallback,
OptunaReportSingleFitCallback,
)
from discrete_optimization.generic_tools.cp_tools import ParametersCP
from discrete_optimization.generic_tools.do_problem import ModeOptim
Expand All @@ -40,7 +41,7 @@
seed = 42
optuna_nb_trials = 150

study_name = f"coloring_cpsat-auto-250---"
study_name = f"coloring_cpsat-auto-250-7"
storage_path = "./optuna-journal.log" # NFS path for distributed optimization

# Solvers to test and their associated kwargs
Expand All @@ -61,7 +62,7 @@
}

# problem definition
file = [f for f in get_data_available() if "gc_250_5" in f][0]
file = [f for f in get_data_available() if "gc_250_7" in f][0]
problem = parse_file(file)

# sense of optimization
Expand Down Expand Up @@ -123,7 +124,8 @@ def objective(trial: Trial):
# solve
sol, fit = solver.solve(
callbacks=[
OptunaPruningSingleFitCallback(trial=trial, **kwargs),
OptunaReportSingleFitCallback(trial=trial, **kwargs)
# OptunaPruningSingleFitCallback(trial=trial, **kwargs),
],
**kwargs,
).get_best_solution_fit()
Expand Down

0 comments on commit 2166d68

Please sign in to comment.