pyepo.func.jax.regularized
==========================

.. py:module:: pyepo.func.jax.regularized

.. autoapi-nested-parse::

   Regularized differentiable optimization functions (L2 Frank-Wolfe).



Attributes
----------

.. autoapisummary::

   pyepo.func.jax.regularized.RFWO
   pyepo.func.jax.regularized.RFY


Classes
-------

.. autoapisummary::

   pyepo.func.jax.regularized.regularizedFrankWolfeOpt
   pyepo.func.jax.regularized.regularizedFrankWolfeFenchelYoung


Module Contents
---------------

.. py:class:: regularizedFrankWolfeOpt(optmodel, lambd=1.0, max_iter=10000, tol=1e-06, processes=1, solve_ratio=1.0, dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`


   L2-Regularized Frank-Wolfe Optimizer (RFWO): a differentiable, smoothed solver.

   It returns the L2-regularized minimizer over conv(S), the convex hull of the feasible solution set S, computed by batched Frank-Wolfe iterations whose only oracle is the standard linear ``optModel`` solve.

   The output is a regularized solution, not a loss: pair it with a task loss, or use ``regularizedFrankWolfeFenchelYoung`` instead.

   Every Frank-Wolfe step depends on an accurate linear minimization oracle, so prefer the callback path with an exact solver; MPAX returns only approximate solutions in this setting.

   Reference: Dalle et al. (2022)


   .. py:attribute:: lambd


   .. py:attribute:: max_iter
      :value: 10000


   .. py:attribute:: tol
      :value: 1e-06


   .. py:method:: forward(pred_cost)

      Forward pass.


.. py:class:: regularizedFrankWolfeFenchelYoung(optmodel, lambd=1.0, max_iter=10000, tol=1e-06, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`


   L2-Regularized Frank-Wolfe with Fenchel-Young loss (RFYL).

   It pairs the RFWO regularized solver with the Fenchel-Young loss of the L2 regularizer: a scalar loss, convex in the predicted cost, that compares the predicted cost directly with the true optimal solution. By Danskin's theorem, the gradient with respect to the predicted cost is the residual ``w - r_sol`` between the true solution ``w`` and the regularized solution ``r_sol``, so the backward pass needs no implicit differentiation.

   Reference: Dalle et al. (2022)


   .. py:attribute:: lambd


   .. py:attribute:: max_iter
      :value: 10000


   .. py:attribute:: tol
      :value: 1e-06


   .. py:method:: forward(pred_cost, true_sol)

      Forward pass.


.. py:data:: RFWO

.. py:data:: RFY
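The two classes can be illustrated with a small, unbatched NumPy sketch; it is not the pyepo implementation. It assumes a minimization model whose regularized problem is ``min over w in conv(S) of cost @ w + (lambd / 2) * ||w||^2`` (the exact scaling of ``lambd`` inside pyepo is an assumption). The helpers ``one_hot_oracle``, ``regularized_fw`` and ``fenchel_young_l2`` are hypothetical names: ``one_hot_oracle`` is a toy stand-in for the linear ``optModel`` solve where S is the set of one-hot vectors, and the exact line search is a choice made for this quadratic objective.

```python
import numpy as np


def one_hot_oracle(cost):
    """Hypothetical linear oracle: argmin_{s in S} cost @ s with S = one-hot vectors.

    Stands in for the exact linear ``optModel`` solve (assumption for illustration).
    """
    s = np.zeros_like(cost)
    s[np.argmin(cost)] = 1.0
    return s


def regularized_fw(cost, oracle, lambd=1.0, max_iter=10000, tol=1e-6):
    """Frank-Wolfe for min_{w in conv(S)} cost @ w + (lambd / 2) * ||w||^2 (assumed scaling)."""
    w = oracle(cost)                      # start from a vertex of conv(S)
    for _ in range(max_iter):
        grad = cost + lambd * w           # gradient of the regularized objective
        s = oracle(grad)                  # one linear minimization oracle call
        d = s - w                         # Frank-Wolfe direction
        gap = -grad @ d                   # Frank-Wolfe duality gap, always >= 0
        if gap <= tol:                    # gap bounds the primal suboptimality
            break
        # exact line search for the quadratic, clipped to the feasible step [0, 1]
        gamma = min(1.0, gap / (lambd * (d @ d)))
        w = w + gamma * d
    return w


def fenchel_young_l2(cost, true_sol, oracle, lambd=1.0):
    """Fenchel-Young loss of the L2 regularizer and its gradient w.r.t. cost (minimization)."""
    r_sol = regularized_fw(cost, oracle, lambd=lambd)
    loss = cost @ (true_sol - r_sol) + 0.5 * lambd * (true_sol @ true_sol - r_sol @ r_sol)
    grad = true_sol - r_sol               # Danskin's theorem: no implicit differentiation
    return loss, grad


cost = np.array([1.0, 1.2, 3.0])
true_sol = np.array([1.0, 0.0, 0.0])
loss, grad = fenchel_young_l2(cost, true_sol, one_hot_oracle)
print(loss, grad)
```

With this toy oracle, conv(S) is the probability simplex, so the regularized solution equals the Euclidean projection of ``-cost / lambd`` onto the simplex; for the values above it is ``[0.6, 0.4, 0.0]``, giving a loss of about ``0.16`` and a gradient of about ``[0.4, -0.4, 0.0]``. A batched JAX version would vectorize the same loop over instances and route the oracle call through the solver.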