pyepo.func.jax.surrogate

Surrogate loss functions

Attributes

Classes

SPOPlus

SPO+ loss: a convex surrogate for the SPO regret of a linear-objective LP.

perturbationGradient

Perturbation Gradient (PG): zeroth-order surrogate of the objective-value loss.

Module Contents

class pyepo.func.jax.surrogate.SPOPlus(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

SPO+ loss: a convex surrogate for the SPO regret of a linear-objective LP.

SPO+ upper-bounds the SPO regret with a convex function of the predicted cost vector and provides an informative subgradient (via Danskin’s theorem) for end-to-end training. It is a strong default for predict-then-optimize when the true optimal solutions \(\mathbf{w}^*(\mathbf{c})\) are available as supervision.

Reference: Elmachtoub & Grigas (2022) https://doi.org/10.1287/mnsc.2020.3922
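The loss and its subgradient can be illustrated on a toy problem. The following is a minimal JAX sketch, not PyEPO's implementation. It assumes a minimization LP whose feasible region \(S\) is the probability simplex, so the hypothetical oracle solve_simplex reduces to an argmin. It evaluates \(\ell_{\text{SPO+}}(\hat{\mathbf{c}}, \mathbf{c}) = \max_{\mathbf{w} \in S}\{(\mathbf{c} - 2\hat{\mathbf{c}})^\top \mathbf{w}\} + 2\hat{\mathbf{c}}^\top \mathbf{w}^*(\mathbf{c}) - z^*(\mathbf{c})\). The names spo_plus_loss and batched_spo_plus are also hypothetical, and the batch mean only mimics reduction='mean'.

```python
import jax
import jax.numpy as jnp


def solve_simplex(cost):
    """Hypothetical toy oracle: argmin_w cost^T w over the probability simplex.

    A linear objective over the simplex is minimized at the vertex e_k
    with k = argmin_k cost_k, returned here as a one-hot vector.
    """
    return jax.nn.one_hot(jnp.argmin(cost), cost.shape[-1])


def spo_plus_loss(pred_cost, true_cost):
    """SPO+ loss for one instance of a minimization LP (sketch)."""
    true_sol = solve_simplex(true_cost)       # w*(c)
    true_obj = jnp.dot(true_cost, true_sol)   # z*(c)
    # Solve with the shifted cost 2*c_hat - c. stop_gradient treats this
    # solution as a constant, so jax.grad returns the Danskin subgradient
    # 2 * (w*(c) - w*(2*c_hat - c)).
    shifted = 2.0 * pred_cost - true_cost
    w_tilde = jax.lax.stop_gradient(solve_simplex(shifted))
    # max_w (c - 2*c_hat)^T w equals -(2*c_hat - c)^T w_tilde
    return -jnp.dot(shifted, w_tilde) + 2.0 * jnp.dot(pred_cost, true_sol) - true_obj


def batched_spo_plus(pred_costs, true_costs):
    """Mean SPO+ loss over a batch (mimics reduction='mean')."""
    return jnp.mean(jax.vmap(spo_plus_loss)(pred_costs, true_costs))


c = jnp.array([1.0, 2.0, 3.0])
c_hat = jnp.array([3.0, 2.0, 1.0])
print(spo_plus_loss(c_hat, c))            # loss: 6.0
print(jax.grad(spo_plus_loss)(c_hat, c))  # subgradient: [2, 0, -2]
```

Here the prediction selects the third vertex, whose true cost is 3 instead of the optimum 1. That is a regret of 2, which the SPO+ value of 6 upper-bounds. A gradient-descent step lowers the first predicted cost and raises the third, moving the prediction toward the true optimum.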

forward(pred_cost, true_cost, true_sol, true_obj)

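Forward pass: compute the SPO+ loss from the predicted costs and the true costs, optimal solutions, and optimal objective values.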

class pyepo.func.jax.surrogate.perturbationGradient(optmodel, sigma=0.1, two_sides=False, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

Perturbation Gradient (PG): zeroth-order surrogate of the objective-value loss.

Approximates the directional derivative of the optimal objective value along the true cost with a finite difference. This yields an informative gradient even though the solver's solution map is piecewise constant. two_sides=False (the default) uses a backward difference; two_sides=True uses a central difference. Only the true cost is needed, not the true optimal solution.

Reference: Gupta & Huang (2024) https://arxiv.org/abs/2402.03256
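A minimal JAX sketch of the finite-difference idea follows, again on the toy simplex problem rather than PyEPO's implementation. It assumes the backward form \((V(\hat{\mathbf{c}}) - V(\hat{\mathbf{c}} - \sigma\mathbf{c}))/\sigma\) and the central form \((V(\hat{\mathbf{c}} + \sigma\mathbf{c}) - V(\hat{\mathbf{c}} - \sigma\mathbf{c}))/(2\sigma)\), where \(V(\mathbf{c}) = \min_{\mathbf{w}} \mathbf{c}^\top\mathbf{w}\) is the optimal value. Because \(V\) is piecewise linear, the backward difference equals the decision cost \(\mathbf{c}^\top\mathbf{w}^*(\hat{\mathbf{c}})\) whenever \(\mathbf{w}^*(\hat{\mathbf{c}})\) stays optimal at \(\hat{\mathbf{c}} - \sigma\mathbf{c}\). The names solve_simplex, opt_value, and pg_loss are hypothetical. Its sigma and two_sides arguments only mirror the constructor arguments above.

```python
import jax
import jax.numpy as jnp


def solve_simplex(cost):
    """Hypothetical toy oracle: one-hot argmin of cost^T w over the simplex."""
    return jax.nn.one_hot(jnp.argmin(cost), cost.shape[-1])


def opt_value(cost):
    """V(c) = min_w c^T w. Holding the solution fixed with stop_gradient
    makes jax.grad return w*(c), the gradient of V wherever it exists."""
    w = jax.lax.stop_gradient(solve_simplex(cost))
    return jnp.dot(cost, w)


def pg_loss(pred_cost, true_cost, sigma=0.1, two_sides=False):
    """Finite-difference PG surrogate for one instance (sketch)."""
    lower = opt_value(pred_cost - sigma * true_cost)
    if two_sides:
        # Central difference: (V(c_hat + sigma*c) - V(c_hat - sigma*c)) / (2*sigma)
        upper = opt_value(pred_cost + sigma * true_cost)
        return (upper - lower) / (2.0 * sigma)
    # Backward difference: (V(c_hat) - V(c_hat - sigma*c)) / sigma
    return (opt_value(pred_cost) - lower) / sigma


c = jnp.array([1.0, 2.0, 3.0])
c_hat = jnp.array([1.2, 1.0, 3.0])
print(pg_loss(c_hat, c, sigma=0.5))                            # backward: 2.0
print(jax.grad(pg_loss)(c_hat, c, sigma=0.5, two_sides=True))  # central gradient: [1, -1, 0]
```

In this example the backward difference reproduces the decision cost 2.0 of the predicted solution, but its gradient is zero because both of its solves return the same vertex. The central difference also solves at \(\hat{\mathbf{c}} + \sigma\mathbf{c}\), which selects the true optimum, so its gradient is nonzero. A descent step then lowers the first predicted cost and raises the second.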

sigma
two_sides = False
forward(pred_cost, true_cost)

Forward pass: compute the PG loss from the predicted and true costs.

pyepo.func.jax.surrogate.smartPredictThenOptimizePlus = SPOPlus
pyepo.func.jax.surrogate.PG = perturbationGradient