pyepo.func.jax.regularized
Regularized differentiable optimization function (L2 Frank-Wolfe)
Attributes
RFWO
RFY
Classes
regularizedFrankWolfeOpt | L2-Regularized Frank-Wolfe Optimizer (RFWO) -- differentiable smoothed solver.
regularizedFrankWolfeFenchelYoung | L2-Regularized Frank-Wolfe with Fenchel-Young loss (RFYL).
Module Contents
- class pyepo.func.jax.regularized.regularizedFrankWolfeOpt(optmodel, lambd=1.0, max_iter=10000, tol=1e-06, processes=1, solve_ratio=1.0, dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
L2-Regularized Frank-Wolfe Optimizer (RFWO) -- differentiable smoothed solver.
Returns the L2-regularized minimizer over conv(S), the convex hull of the feasible solution set S. The minimizer is computed by batched Frank-Wolfe, and the only oracle it calls is the standard linear optModel solve. The output is a regularized solution, not a loss: pair it with a task loss, or use regularizedFrankWolfeFenchelYoung instead. The Frank-Wolfe (FW) loop needs an accurate linear minimization oracle, so prefer the callback path with an exact solver (MPAX is only approximate here).
Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513
- lambd
- max_iter = 10000
- tol = 1e-06
- forward(pred_cost)
Forward pass: solves the L2-regularized problem for pred_cost and returns the regularized solution.
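The Frank-Wolfe loop described for RFWO can be sketched in JAX as shown below. This is a minimal sketch under stated assumptions, not PyEPO's implementation. It assumes the regularized problem is min over x in conv(S) of <c, x> + (lambd / 2) * ||x||^2. In place of an exact optModel solve, it uses a hypothetical `simplex_oracle` whose vertices S are the one-hot vectors of the probability simplex. It also chooses an exact line-search step, which the library may not do. Reading tol as a bound on the Frank-Wolfe duality gap is likewise an assumption.

```python
import jax
import jax.numpy as jnp


def simplex_oracle(grad):
    # Hypothetical stand-in for an exact optModel solve: minimizes <grad, s>
    # over the vertices of the probability simplex (one-hot vectors).
    return jax.nn.one_hot(jnp.argmin(grad), grad.shape[0], dtype=grad.dtype)


def l2_frank_wolfe(cost, oracle, lambd=1.0, max_iter=10000, tol=1e-6):
    """Sketch: argmin over conv(S) of <cost, x> + (lambd / 2) * ||x||^2.

    Frank-Wolfe with exact line search; stops once the FW duality gap
    drops below tol or after max_iter iterations (assumed semantics).
    """

    def cond_fn(state):
        k, _, gap = state
        return (k < max_iter) & (gap >= tol)

    def body_fn(state):
        k, x, _ = state
        grad = cost + lambd * x        # gradient of the regularized objective
        s = oracle(grad)               # the only oracle call: a linear solve
        d = s - x                      # Frank-Wolfe direction
        gap = -jnp.dot(grad, d)        # duality gap, bounds the suboptimality
        dd = jnp.dot(d, d)
        safe_dd = jnp.where(dd > 0, dd, 1.0)  # avoid 0 / 0 when d == 0
        # Exact line search for this quadratic objective, clipped to [0, 1].
        step = jnp.where(dd > 0, jnp.clip(gap / (lambd * safe_dd), 0.0, 1.0), 0.0)
        return k + 1, x + step * d, gap

    x0 = oracle(cost)                  # start from a vertex of conv(S)
    init = (jnp.array(0), x0, jnp.array(jnp.inf, dtype=cost.dtype))
    _, x, _ = jax.lax.while_loop(cond_fn, body_fn, init)
    return x


# Batched usage: vmap the solver over a batch of predicted costs.
costs = jnp.array([[0.1, 0.2, 0.9], [0.9, 0.2, 0.1]])
sols = jax.vmap(lambda c: l2_frank_wolfe(c, simplex_oracle, lambd=1.0))(costs)
print(sols)  # approximately [[0.55, 0.45, 0.0], [0.0, 0.45, 0.55]]
```

The loop only ever calls the oracle on a linear objective. Replacing `simplex_oracle` with any exact linear solver over S therefore leaves the loop itself unchanged. Note also that an approximate oracle corrupts both the search direction and the gap used for stopping.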
- class pyepo.func.jax.regularized.regularizedFrankWolfeFenchelYoung(optmodel, lambd=1.0, max_iter=10000, tol=1e-06, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
L2-Regularized Frank-Wolfe with Fenchel-Young loss (RFYL).
Pairs the RFWO regularized solver with the Fenchel-Young loss of the L2 regularizer. The result is a convex scalar loss that compares the predicted cost directly against the true optimal solution. By Danskin's theorem, the gradient with respect to the predicted cost is the residual w - r_sol. Here w is the true solution (true_sol) and r_sol is the regularized solution, so the backward pass needs no implicit differentiation.
Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513
- lambd
- max_iter = 10000
- tol = 1e-06
- forward(pred_cost, true_sol)
Forward pass: returns the Fenchel-Young loss between pred_cost and true_sol, aggregated according to reduction.
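The Fenchel-Young loss and its Danskin gradient can be sketched with `jax.custom_vjp`. This sketch rests on several assumptions rather than PyEPO's code:

- The regularizer is Omega(y) = (lambd / 2) * ||y||^2.
- The loss is <c, w> + Omega(w) - min over y in conv(S) of (<c, y> + Omega(y)).
- S is the set of simplex vertices. The regularized solution then has a closed form, the Euclidean projection of -c / lambd onto the simplex. RFYL itself uses a Frank-Wolfe solve instead.

The names `LAMBD`, `omega`, `project_simplex`, `reg_solve`, `fy_value`, and `fy_loss` are hypothetical.

```python
import jax
import jax.numpy as jnp

LAMBD = 1.0  # hypothetical regularization weight, playing the role of lambd


def omega(y):
    # Assumed L2 regularizer: (lambd / 2) * ||y||^2
    return 0.5 * LAMBD * jnp.dot(y, y)


def project_simplex(v):
    # Sort-based Euclidean projection onto the probability simplex.
    u = jnp.sort(v)[::-1]
    css = jnp.cumsum(u)
    k = jnp.arange(1, v.shape[0] + 1, dtype=v.dtype)
    rho = jnp.sum(u - (css - 1.0) / k > 0)  # size of the active support
    theta = (css[rho - 1] - 1.0) / rho
    return jnp.maximum(v - theta, 0.0)


def reg_solve(cost):
    # r_sol = argmin over the simplex of <cost, y> + omega(y), which for this
    # omega equals the projection of -cost / lambd (swapped in for Frank-Wolfe).
    return project_simplex(-cost / LAMBD)


def fy_value(cost, true_sol):
    r_sol = reg_solve(cost)
    f_c = jnp.dot(cost, r_sol) + omega(r_sol)  # min_y <cost, y> + omega(y)
    loss = jnp.dot(cost, true_sol) + omega(true_sol) - f_c
    return loss, r_sol


@jax.custom_vjp
def fy_loss(cost, true_sol):
    return fy_value(cost, true_sol)[0]


def fy_fwd(cost, true_sol):
    loss, r_sol = fy_value(cost, true_sol)
    return loss, (cost, true_sol, r_sol)


def fy_bwd(res, g):
    cost, true_sol, r_sol = res
    # Danskin: d loss / d cost = true_sol - r_sol; no solver differentiation.
    grad_cost = g * (true_sol - r_sol)
    grad_true = g * (cost + LAMBD * true_sol)  # d loss / d true_sol
    return grad_cost, grad_true


fy_loss.defvjp(fy_fwd, fy_bwd)

# Batched loss with a mean reduction (mirroring reduction='mean').
costs = jnp.array([[0.1, 0.2, 0.9], [0.9, 0.2, 0.1]])
true_sols = jnp.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
batch_loss = lambda c: jnp.mean(jax.vmap(fy_loss)(c, true_sols))
print(batch_loss(costs), jax.grad(batch_loss)(costs))
```

The backward pass only needs r_sol from the forward pass. The same custom_vjp pattern therefore carries over when reg_solve is replaced by an iterative solver such as Frank-Wolfe, in which case the gradient is exact only up to the solver's tolerance.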
- pyepo.func.jax.regularized.RFWO
- pyepo.func.jax.regularized.RFY