pyepo.func.jax.surrogate
Surrogate loss functions
Attributes
smartPredictThenOptimizePlus
PG
Classes
SPOPlus | SPO+ loss: a convex surrogate for the SPO regret of a linear-objective LP.
perturbationGradient | Perturbation Gradient (PG): zeroth-order surrogate of the objective-value loss.
Module Contents
- class pyepo.func.jax.surrogate.SPOPlus(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
SPO+ loss: a convex surrogate for the SPO regret of a linear-objective LP.
SPO+ upper-bounds the SPO regret with a convex function of the predicted cost vector and provides an informative subgradient (via Danskin’s theorem) for end-to-end training. It is a strong default for predict-then-optimize when the true optimal solutions \(\mathbf{w}^*(\mathbf{c})\) are available as supervision.
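For reference, the convex upper bound and the Danskin subgradient can be written out as below. This restates the standard form from the paper cited below in this page's notation, assuming a minimization problem over a feasible region \(S\); the symbols \(\hat{\mathbf{c}}\) (predicted cost), \(S\) and \(z^*(\mathbf{c}) = \mathbf{c}^\top \mathbf{w}^*(\mathbf{c})\) are introduced here for illustration and are not taken from this module's code.

\[
\ell_{\mathrm{SPO+}}(\hat{\mathbf{c}}, \mathbf{c}) = \max_{\mathbf{w} \in S} \left\{ (\mathbf{c} - 2\hat{\mathbf{c}})^\top \mathbf{w} \right\} + 2\hat{\mathbf{c}}^\top \mathbf{w}^*(\mathbf{c}) - z^*(\mathbf{c}),
\qquad
2\left(\mathbf{w}^*(\mathbf{c}) - \mathbf{w}^*(2\hat{\mathbf{c}} - \mathbf{c})\right) \in \partial_{\hat{\mathbf{c}}}\, \ell_{\mathrm{SPO+}}(\hat{\mathbf{c}}, \mathbf{c})
\]

The loss is zero at \(\hat{\mathbf{c}} = \mathbf{c}\), and once \(\mathbf{w}^*(\mathbf{c})\) is supplied as supervision, evaluating the loss and its subgradient needs only one extra solver call, at the cost vector \(2\hat{\mathbf{c}} - \mathbf{c}\).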
Reference: Elmachtoub & Grigas (2022) https://doi.org/10.1287/mnsc.2020.3922
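To make the upper-bound claim concrete, here is a minimal sketch in plain Python, assuming a minimization problem whose feasible set is small enough to enumerate as a list of vertices. The helpers dot, solve, spo_regret and spo_plus are hypothetical illustrations, not part of the pyepo API; solve is a brute-force stand-in for the real solver behind optmodel.

```python
from typing import List, Sequence

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def solve(cost: Sequence[float], vertices: Sequence[Vector]) -> Vector:
    """Hypothetical toy oracle: w*(cost) = argmin of cost^T w over the vertices.

    Stands in for a real LP/MIP solver; ties go to the first vertex listed.
    """
    return list(min(vertices, key=lambda w: dot(cost, w)))


def spo_regret(pred, true, vertices) -> float:
    """SPO regret: true cost of w*(pred) minus the true optimal value z*(c)."""
    return dot(true, solve(pred, vertices)) - dot(true, solve(true, vertices))


def spo_plus(pred, true, vertices):
    """SPO+ loss and one subgradient with respect to pred (minimization)."""
    w_true = solve(true, vertices)                      # w*(c)
    z_true = dot(true, w_true)                          # z*(c)
    shifted = [2 * p - t for p, t in zip(pred, true)]   # 2*c_hat - c
    w_shift = solve(shifted, vertices)                  # w*(2*c_hat - c)
    # max_w (c - 2*c_hat)^T w equals -min_w (2*c_hat - c)^T w
    loss = -dot(shifted, w_shift) + 2 * dot(pred, w_true) - z_true
    # Danskin subgradient: 2 * (w*(c) - w*(2*c_hat - c))
    grad = [2 * (a - b) for a, b in zip(w_true, w_shift)]
    return loss, grad


if __name__ == "__main__":
    units = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]  # feasible set: pick exactly one item
    pred, true = [3.0, 2.0, 1.0], [1.0, 2.0, 3.0]
    print(spo_plus(pred, true, units))    # (6.0, [2, 0, -2])
    print(spo_regret(pred, true, units))  # 2.0
```

In this toy case the SPO+ value (6.0) sits above the regret (2.0), as the description says, and a gradient step pred - lr * grad lowers the predicted cost of item 1 (the truly cheapest) while raising that of item 3.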
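- forward(pred_cost, true_cost, true_sol, true_obj)
Forward pass: compute the SPO+ loss of the predicted costs pred_cost, given the true costs true_cost, the true optimal solutions true_sol and the true optimal objective values true_obj.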
- class pyepo.func.jax.surrogate.perturbationGradient(optmodel, sigma=0.1, two_sides=False, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
Perturbation Gradient (PG): zeroth-order surrogate of the objective-value loss.
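Approximates, with a finite difference, the directional derivative of the optimal objective value at the predicted cost, taken along the true cost. Where it exists, this derivative equals the true cost of the decision made under the predicted cost. The finite difference gives an informative gradient through the solver layer, whose solution map is piecewise constant in the predicted cost (so its exact gradient is zero almost everywhere).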
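As a sketch of the idea in the paper cited below, assuming a minimization problem over a feasible region \(S\) and writing \(z^*(\hat{\mathbf{c}}) = \min_{\mathbf{w} \in S} \hat{\mathbf{c}}^\top \mathbf{w}\) for the optimal value under a predicted cost \(\hat{\mathbf{c}}\) (notation introduced here for illustration, with the step \(\sigma\) taken to correspond to the sigma parameter), the backward and central differences read:

\[
\ell_{\mathrm{back}} = \frac{z^*(\hat{\mathbf{c}}) - z^*(\hat{\mathbf{c}} - \sigma\mathbf{c})}{\sigma},
\qquad
\ell_{\mathrm{central}} = \frac{z^*(\hat{\mathbf{c}} + \sigma\mathbf{c}) - z^*(\hat{\mathbf{c}} - \sigma\mathbf{c})}{2\sigma},
\]
\[
\nabla_{\hat{\mathbf{c}}}\, \ell_{\mathrm{back}} = \frac{\mathbf{w}^*(\hat{\mathbf{c}}) - \mathbf{w}^*(\hat{\mathbf{c}} - \sigma\mathbf{c})}{\sigma},
\qquad
\nabla_{\hat{\mathbf{c}}}\, \ell_{\mathrm{central}} = \frac{\mathbf{w}^*(\hat{\mathbf{c}} + \sigma\mathbf{c}) - \mathbf{w}^*(\hat{\mathbf{c}} - \sigma\mathbf{c})}{2\sigma}.
\]

Both quotients approximate \(\frac{d}{dh} z^*(\hat{\mathbf{c}} + h\mathbf{c})\big|_{h=0} = \mathbf{c}^\top \mathbf{w}^*(\hat{\mathbf{c}})\). The gradients use \(\nabla z^*(\hat{\mathbf{c}}) = \mathbf{w}^*(\hat{\mathbf{c}})\), valid wherever the minimizer is unique, and they are non-zero only when the perturbation changes the optimal solution. This is not a transcription of this module's implementation.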
two_sides selects backward differencing (default, two_sides=False) or central differencing (two_sides=True). PG needs only the true cost, not the true optimal solution.
Reference: Gupta & Huang (2024) https://arxiv.org/abs/2402.03256
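A minimal plain-Python sketch of the backward and central variants follows, assuming a minimization problem with an enumerable vertex set. The helpers dot, solve and pg_loss are hypothetical and not part of the pyepo API; the sigma and two_sides arguments only mirror the constructor parameters above, and the gradient is computed in closed form as the difference of the two perturbed solutions divided by the step.

```python
from typing import List, Sequence

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def solve(cost: Sequence[float], vertices: Sequence[Vector]) -> Vector:
    """Hypothetical toy oracle: w*(cost) = argmin of cost^T w over the vertices."""
    return list(min(vertices, key=lambda w: dot(cost, w)))


def pg_loss(pred, true, vertices, sigma=0.1, two_sides=False):
    """Finite-difference PG loss and its gradient w.r.t. pred (minimization).

    Backward difference by default; central difference when two_sides=True.
    """
    minus = [p - sigma * t for p, t in zip(pred, true)]     # c_hat - sigma*c
    if two_sides:
        plus = [p + sigma * t for p, t in zip(pred, true)]  # c_hat + sigma*c
        step = 2 * sigma
    else:
        plus = list(pred)                                   # c_hat itself
        step = sigma
    w_plus, w_minus = solve(plus, vertices), solve(minus, vertices)
    # Difference of optimal values z*(plus) - z*(minus), divided by the step
    loss = (dot(plus, w_plus) - dot(minus, w_minus)) / step
    # Gradient of that quotient: difference of the two optimal solutions
    grad = [(a - b) / step for a, b in zip(w_plus, w_minus)]
    return loss, grad


if __name__ == "__main__":
    units = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]  # feasible set: pick exactly one item
    pred, true = [1.0, 1.05, 5.0], [1.0, 2.0, 3.0]
    print(pg_loss(pred, true, units))                  # loss ~1.5, grad [10.0, -10.0, 0.0]
    print(pg_loss(pred, true, units, two_sides=True))  # loss ~1.25, grad [5.0, -5.0, 0.0]
    print(pg_loss(pred, true, units, sigma=0.01))      # loss ~1.0, grad [0.0, 0.0, 0.0]
```

With sigma=0.01 the perturbation leaves the solution unchanged, so the loss equals the decision cost of 1.0 (up to rounding) but the gradient vanishes; the default step crosses the near-tie between items 1 and 2 and yields a gradient that favours item 1, the true optimum. This illustrates why the step size matters.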
- sigma
- two_sides = False
- forward(pred_cost, true_cost)
Forward pass: compute the PG loss of the predicted costs pred_cost, given the true costs true_cost.
- pyepo.func.jax.surrogate.smartPredictThenOptimizePlus
Alias of SPOPlus.
- pyepo.func.jax.surrogate.PG
Alias of perturbationGradient.