pyepo.func.jax.regularized

Regularized differentiable optimization function (L2 Frank-Wolfe)

Attributes

Classes

regularizedFrankWolfeOpt

L2-Regularized Frank-Wolfe Optimizer (RFWO): differentiable smoothed solver.

regularizedFrankWolfeFenchelYoung

L2-Regularized Frank-Wolfe with Fenchel-Young loss (RFYL).

Module Contents

class pyepo.func.jax.regularized.regularizedFrankWolfeOpt(optmodel, lambd=1.0, max_iter=10000, tol=1e-06, processes=1, solve_ratio=1.0, dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

L2-Regularized Frank-Wolfe Optimizer (RFWO): differentiable smoothed solver.

Computes the L2-regularized minimizer over conv(S), the convex hull of the feasible solution set S, with a batched Frank-Wolfe loop whose only oracle is the standard linear optModel solve. The output is a regularized solution, not a loss; pair it with a task loss, or use regularizedFrankWolfeFenchelYoung instead. Because the Frank-Wolfe loop needs an accurate linear minimization oracle, prefer the callback path with an exact solver; MPAX is only approximate here.
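A minimal sketch of the batched Frank-Wolfe idea in JAX, not PyEPO's implementation. It assumes the regularized problem is min over conv(S) of c^T y + (lambd / 2) ||y||^2 (the lambd / 2 scaling is an assumption), stops when the Frank-Wolfe duality gap falls below tol, and uses a hypothetical toy oracle lmo_one_hot in place of the optModel solve: S is the set of one-hot vectors, so conv(S) is the probability simplex. jax.lax.while_loop is not reverse-mode differentiable, so this covers only the forward solve.

```python
import jax
import jax.numpy as jnp


def lmo_one_hot(c):
    """Hypothetical linear minimization oracle: S = one-hot vectors.

    argmin_{s in S} c^T s is the vertex at the cheapest coordinate, so conv(S)
    is the probability simplex.
    """
    return jax.nn.one_hot(jnp.argmin(c), c.shape[0], dtype=c.dtype)


def regularized_fw(c, lmo, lambd=1.0, max_iter=10000, tol=1e-6):
    """Frank-Wolfe for min_{y in conv(S)} c^T y + (lambd / 2) * ||y||^2.

    The lambd / 2 scaling is an assumption of this sketch. S is accessed only
    through `lmo`, which solves a linear problem over S.
    """
    y0 = lmo(c)  # start at a vertex of conv(S)

    def cond(state):
        k, _, gap = state
        # stop on the iteration budget or once the duality gap is small
        return (k < max_iter) & (gap > tol)

    def body(state):
        k, y, _ = state
        g = c + lambd * y  # gradient of the regularized objective at y
        s = lmo(g)  # linear oracle call on the current gradient
        d = s - y  # Frank-Wolfe direction
        gap = -jnp.dot(g, d)  # FW duality gap, bounds the suboptimality of y
        # exact line search: the objective is quadratic along d
        denom = lambd * jnp.dot(d, d)
        step = jnp.clip(gap / jnp.where(denom > 0, denom, 1.0), 0.0, 1.0)
        return k + 1, y + step * d, gap

    init = (jnp.array(0), y0, jnp.array(jnp.inf, dtype=c.dtype))
    _, y, _ = jax.lax.while_loop(cond, body, init)
    return y


# Batched use: vmap the solver over a batch of cost vectors.
costs = jnp.array([[1.0, 1.5, 4.0], [3.0, 0.0, 0.5]])
sols = jax.vmap(lambda c: regularized_fw(c, lmo_one_hot, lambd=2.0))(costs)
```

Each row of sols is a point of the simplex rather than a single vertex; larger lambd spreads the mass over more vertices, which is what makes the solution a smooth function of the cost.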

Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513

lambd
max_iter = 10000
tol = 1e-06
forward(pred_cost)

Forward pass: returns the regularized solution for the batch of predicted costs pred_cost.

class pyepo.func.jax.regularized.regularizedFrankWolfeFenchelYoung(optmodel, lambd=1.0, max_iter=10000, tol=1e-06, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

L2-Regularized Frank-Wolfe with Fenchel-Young loss (RFYL).

Pairs the RFWO regularized solver with the Fenchel-Young loss of the L2 regularizer, a convex scalar loss that compares the predicted cost directly against the true optimal solution. By Danskin’s theorem, the gradient with respect to the predicted cost is the residual w - r_sol (true solution minus regularized solution), so the backward pass needs no implicit differentiation.
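The gradient claim can be written out as follows, assuming a minimization problem and the regularizer (lambd / 2) ||y||^2 (how PyEPO scales lambd is an assumption here), with r_sol = \hat{w}(c):

```latex
\begin{aligned}
\hat{w}(c) &= \operatorname*{arg\,min}_{y \in \operatorname{conv}(S)} \; c^\top y + \tfrac{\lambda}{2}\lVert y \rVert_2^2 \\
L(c; w) &= c^\top w + \tfrac{\lambda}{2}\lVert w \rVert_2^2 - \min_{y \in \operatorname{conv}(S)} \Big( c^\top y + \tfrac{\lambda}{2}\lVert y \rVert_2^2 \Big) \\
\nabla_c L(c; w) &= w - \hat{w}(c)
\end{aligned}
```

The min term is a pointwise minimum of functions affine in c, hence concave, so L is convex in c. The strongly convex regularizer makes the minimizer unique, so Danskin's theorem differentiates the min term by holding \hat{w}(c) fixed. For w in conv(S), L(c; w) >= 0, with equality exactly when \hat{w}(c) = w.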

Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513
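A toy check of the Danskin gradient in JAX, under stated assumptions: a minimization problem, regularizer (lambd / 2) ||y||^2, and a hypothetical feasible set S of one-hot vectors. For brevity the sketch swaps Frank-Wolfe for the closed-form Euclidean projection onto the simplex, which solves the same regularized problem in this special case. Wrapping the solution in jax.lax.stop_gradient holds the minimizer fixed, so jax.grad returns exactly w - r_sol without differentiating through the solver.

```python
import jax
import jax.numpy as jnp


def project_simplex(z):
    """Euclidean projection of z onto the probability simplex (sort-based)."""
    u = jnp.sort(z)[::-1]  # sort in descending order
    css = jnp.cumsum(u)
    k = jnp.arange(1, z.shape[0] + 1)
    # rho = largest k with u_k - (sum_{i<=k} u_i - 1) / k > 0 (always >= 1)
    rho = jnp.max(jnp.where(u - (css - 1.0) / k > 0, k, 0))
    tau = (css[rho - 1] - 1.0) / rho
    return jnp.maximum(z - tau, 0.0)


def regularized_solution(c, lambd):
    """argmin over the simplex of c^T y + (lambd / 2) ||y||^2 = proj(-c / lambd)."""
    return project_simplex(-c / lambd)


def fy_loss(c, w, lambd=1.0):
    """Fenchel-Young loss for the L2 regularizer (toy simplex case)."""
    # stop_gradient holds the minimizer fixed, as in Danskin's theorem
    r_sol = jax.lax.stop_gradient(regularized_solution(c, lambd))

    def objective(y):
        return jnp.dot(c, y) + 0.5 * lambd * jnp.dot(y, y)

    return objective(w) - objective(r_sol)


def batch_fy_loss(costs, true_sols, lambd=1.0):
    """Mean over the batch, mirroring reduction='mean'."""
    losses = jax.vmap(fy_loss, in_axes=(0, 0, None))(costs, true_sols, lambd)
    return jnp.mean(losses)


c = jnp.array([1.0, 1.5, 4.0])
w = jnp.array([0.0, 1.0, 0.0])  # true optimal solution (a vertex)
loss = fy_loss(c, w)
grad = jax.grad(fy_loss)(c, w)  # equals w - r_sol
```

Here r_sol is [0.75, 0.25, 0.0], so grad is w - r_sol = [-0.75, 0.75, 0.0]; the loss is zero only when the true solution coincides with the regularized one.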

lambd
max_iter = 10000
tol = 1e-06
forward(pred_cost, true_sol)

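Forward pass: returns the Fenchel-Young loss of pred_cost against true_sol, aggregated over the batch according to reduction.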

pyepo.func.jax.regularized.RFWO
pyepo.func.jax.regularized.RFY