pyepo.func.jax.perturbed
========================

.. py:module:: pyepo.func.jax.perturbed

.. autoapi-nested-parse::

   Perturbed optimization function


Attributes
----------

.. autoapisummary::

   pyepo.func.jax.perturbed.DPO
   pyepo.func.jax.perturbed.DPOMul
   pyepo.func.jax.perturbed.PFY
   pyepo.func.jax.perturbed.PFYMul
   pyepo.func.jax.perturbed.IMLE
   pyepo.func.jax.perturbed.AIMLE


Classes
-------

.. autoapisummary::

   pyepo.func.jax.perturbed.perturbedOpt
   pyepo.func.jax.perturbed.perturbedOptMul
   pyepo.func.jax.perturbed.perturbedFenchelYoung
   pyepo.func.jax.perturbed.perturbedFenchelYoungMul
   pyepo.func.jax.perturbed.implicitMLE
   pyepo.func.jax.perturbed.adaptiveImplicitMLE


Module Contents
---------------

.. py:class:: perturbedOpt(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, variance_reduction=True, solve_ratio=1.0, dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`

   Differentiable Perturbed Optimizer (DPO) -- additive-Gaussian variant.

   Estimates the expected solution
   :math:`\mathbb{E}_{\boldsymbol{\xi}}[\mathbf{w}^*(\hat{\mathbf{c}} + \sigma\boldsymbol{\xi})]`
   by Monte Carlo averaging over ``n_samples`` perturbations. The bare solver is
   piecewise constant in the cost, so its gradient is zero almost everywhere; the
   smoothed expectation gives an informative gradient instead. The module returns
   a solution, so pair it with a task loss.

   The forward accepts ``key=`` for explicit RNG control, which is required under
   ``jax.jit``.

   Reference: Berthet et al. (2020)

   .. py:attribute:: n_samples
      :value: 10

   .. py:attribute:: sigma

   .. py:attribute:: variance_reduction
      :value: True

   .. py:method:: forward(pred_cost, key=None)

      Forward pass


.. py:class:: perturbedOptMul(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, variance_reduction=True, solve_ratio=1.0, dataset=None)

   Bases: :py:obj:`perturbedOpt`

   Differentiable Perturbed Optimizer (DPO) -- multiplicative log-normal variant.

   As :class:`perturbedOpt`, but perturbs the cost multiplicatively with log-normal
   noise :math:`\exp(\sigma\boldsymbol{\xi} - \sigma^2/2)`. This factor is always
   positive and has unit mean, so each perturbed cost keeps the sign of
   :math:`\hat{\mathbf{c}}` and equals it in expectation.

   Reference: Dalle et al. (2022)


.. py:class:: perturbedFenchelYoung(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`

   Perturbed Fenchel-Young loss (PFYL) -- additive-Gaussian variant.

   Pairs a Monte Carlo estimate of the expected perturbed solution with the
   Fenchel-Young loss against the true optimum. It returns a loss (a scalar under
   the default ``reduction='mean'``) whose gradient is the residual
   :math:`\mathbf{w}^*(\mathbf{c}) - \mathbb{E}_{\boldsymbol{\xi}}[\mathbf{w}^*(\hat{\mathbf{c}} + \sigma\boldsymbol{\xi})]`.

   The forward accepts ``key=`` for explicit RNG control, which is required under
   ``jax.jit``.

   Reference: Berthet et al. (2020)

   .. py:attribute:: n_samples
      :value: 10

   .. py:attribute:: sigma

   .. py:method:: forward(pred_cost, true_sol, key=None)

      Forward pass


.. py:class:: perturbedFenchelYoungMul(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

   Bases: :py:obj:`perturbedFenchelYoung`

   Perturbed Fenchel-Young loss (PFYL) -- multiplicative log-normal variant.

   As :class:`perturbedFenchelYoung`, but perturbs the cost multiplicatively with
   log-normal noise :math:`\exp(\sigma\boldsymbol{\xi} - \sigma^2/2)`.

   Reference: Dalle et al. (2022)


.. py:class:: implicitMLE(optmodel, n_samples=10, sigma=1.0, lambd=10, kappa=5, n_iterations=10, two_sides=False, seed=135, processes=1, solve_ratio=1.0, dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`

   Implicit Maximum Likelihood Estimator (I-MLE) via perturb-and-MAP.

   Frames decision-focused learning as imitation: an upstream gradient
   :math:`\mathbf{d}` induces a virtual cost update
   :math:`\hat{\mathbf{c}} + \lambda \mathbf{d}`, with :math:`\lambda` set by
   ``lambd``. The gradient is then a directional finite difference between the
   smoothed solutions at the updated and original costs, and both solves share one
   Sum-of-Gamma noise realization.

   The forward accepts ``key=`` for explicit RNG control, which is required under
   ``jax.jit``.

   Reference: Niepert, Minervini & Franceschi (2021)

   .. py:attribute:: n_samples
      :value: 10

   .. py:attribute:: sigma

   .. py:attribute:: lambd

   .. py:attribute:: kappa

   .. py:attribute:: n_iterations
      :value: 10

   .. py:attribute:: two_sides
      :value: False

   .. py:method:: forward(pred_cost, key=None)

      Forward pass


.. py:class:: adaptiveImplicitMLE(optmodel, n_samples=10, sigma=1.0, kappa=5, n_iterations=10, two_sides=False, seed=135, processes=1, solve_ratio=1.0, dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`

   Adaptive Implicit MLE (AI-MLE): I-MLE with an online-tuned interpolation step.

   Replaces the fixed :math:`\lambda` of I-MLE with
   :math:`\lambda_t = \alpha_t \|\hat{\mathbf{c}}\| / \|\mathbf{d}\|`, where
   :math:`\alpha_t` is adapted online from a moving average of the gradient
   sparsity.

   Eager-only: the :math:`\alpha` update is a concrete side effect in the backward
   pass, so this module cannot be wrapped in ``jax.jit``, and its forward takes no
   ``key=`` argument.

   Reference: Minervini, Franceschi & Niepert (2023)

   .. py:attribute:: n_samples
      :value: 10

   .. py:attribute:: sigma

   .. py:attribute:: kappa

   .. py:attribute:: n_iterations
      :value: 10

   .. py:attribute:: two_sides
      :value: False

   .. py:attribute:: alpha
      :value: 1.0

   .. py:attribute:: grad_norm_avg
      :value: 1.0

   .. py:attribute:: step
      :value: 0.001

   .. py:method:: forward(pred_cost)

      Forward pass


.. py:data:: DPO

.. py:data:: DPOMul

.. py:data:: PFY

.. py:data:: PFYMul

.. py:data:: IMLE

.. py:data:: AIMLE
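
The Monte Carlo estimator behind :class:`perturbedOpt` can be illustrated with a minimal sketch. It assumes a hypothetical toy solver ``solve`` that minimizes :math:`\mathbf{c}^\top \mathbf{w}` over one-hot vectors (pick the cheapest item), and it uses NumPy rather than JAX; the names ``solve`` and ``perturbed_solution`` are illustrative and are not part of pyepo's API.

```python
import numpy as np


def solve(cost):
    # Hypothetical toy solver: minimize cost @ w over one-hot vectors w,
    # i.e. return the indicator of the single cheapest item.
    w = np.zeros_like(cost)
    w[np.argmin(cost)] = 1.0
    return w


def perturbed_solution(pred_cost, n_samples=10, sigma=1.0, rng=None):
    # Monte Carlo estimate of E_xi[w*(c_hat + sigma * xi)] with xi ~ N(0, I).
    # An explicit generator plays the role of the ``key=`` argument.
    rng = np.random.default_rng(135) if rng is None else rng
    xi = rng.standard_normal((n_samples,) + pred_cost.shape)
    sols = np.stack([solve(pred_cost + sigma * x) for x in xi])
    return sols.mean(axis=0)


c_hat = np.array([1.0, 1.2, 3.0])
print(solve(c_hat))  # [1. 0. 0.]: a step function of c_hat
print(perturbed_solution(c_hat, n_samples=20_000))  # mass spread over items
```

The hard solution jumps only when the ranking of costs changes, whereas the averaged solution moves smoothly with ``c_hat``, which is why it yields a usable gradient.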
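
The multiplicative log-normal perturbation used by :class:`perturbedOptMul` and :class:`perturbedFenchelYoungMul` is sketched below in NumPy. Only the noise step is shown; the hypothetical helper ``lognormal_perturb`` is illustrative and is not pyepo's implementation. The check confirms the two properties of the factor :math:`\exp(\sigma\xi - \sigma^2/2)`: unit mean and strict positivity.

```python
import numpy as np


def lognormal_perturb(pred_cost, n_samples=10, sigma=1.0, rng=None):
    # Hypothetical helper: multiply each cost by exp(sigma * xi - sigma**2 / 2)
    # with xi ~ N(0, I). The -sigma**2 / 2 shift makes the factor's mean exactly 1.
    rng = np.random.default_rng(135) if rng is None else rng
    xi = rng.standard_normal((n_samples,) + pred_cost.shape)
    return pred_cost * np.exp(sigma * xi - sigma**2 / 2)


c_hat = np.array([2.0, -1.0, 0.5])
samples = lognormal_perturb(c_hat, n_samples=200_000, sigma=0.5)
print(samples.mean(axis=0))  # close to c_hat
print(np.all(np.sign(samples) == np.sign(c_hat)))  # True
```

Because the noise scales with the cost, entries near zero are barely perturbed, unlike the additive variant, where every entry receives noise of the same scale ``sigma``.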
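
The PFYL gradient stated for :class:`perturbedFenchelYoung`, the residual between the true solution and the expected perturbed solution, can be exercised with a minimal NumPy sketch. It assumes a minimization problem solved by a hypothetical toy ``solve`` (pick the cheapest item) and runs plain gradient descent on the predicted cost vector itself instead of on a model; ``pfy_grad`` is an illustrative name, not pyepo's API.

```python
import numpy as np


def solve(cost):
    # Hypothetical toy solver: one-hot indicator of the cheapest item.
    w = np.zeros_like(cost)
    w[np.argmin(cost)] = 1.0
    return w


def pfy_grad(pred_cost, true_sol, n_samples=10, sigma=1.0, rng=None):
    # Residual w*(c) - E_xi[w*(c_hat + sigma * xi)], estimated by Monte Carlo.
    rng = np.random.default_rng(135) if rng is None else rng
    xi = rng.standard_normal((n_samples,) + pred_cost.shape)
    expected = np.mean([solve(pred_cost + sigma * x) for x in xi], axis=0)
    return true_sol - expected


c_true = np.array([3.0, 2.0, 1.0])
true_sol = solve(c_true)           # item 2 is optimal
c_hat = np.array([1.0, 2.0, 3.0])  # the prediction initially favours item 0
rng = np.random.default_rng(0)
for _ in range(50):
    # Descend on c_hat: the true item's cost falls, the chosen items' costs rise.
    c_hat = c_hat - 0.5 * pfy_grad(c_hat, true_sol, n_samples=100, rng=rng)
print(solve(c_hat))  # [0. 0. 1.]
```

Both terms of the residual are averages of one-hot vectors, so the gradient sums to zero, and the total of ``c_hat`` is conserved across these updates.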
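
The I-MLE backward described for :class:`implicitMLE` can be sketched in NumPy as a finite difference with shared noise. The Sum-of-Gamma form below, :math:`\frac{1}{\kappa}\big(\sum_{i=1}^{s}\mathrm{Gamma}(1/\kappa,\ \kappa/i) - \log s\big)` with :math:`s` taken from ``n_iterations`` and the whole draw scaled by ``sigma``, is our reading of Niepert et al. (2021); pyepo's exact scaling and sign conventions may differ. ``solve``, ``sum_of_gamma`` and ``imle_grad`` are hypothetical names, and ``two_sides=True`` is assumed to mean a central difference.

```python
import numpy as np


def solve(cost):
    # Hypothetical toy solver: one-hot indicator of the cheapest item.
    w = np.zeros_like(cost)
    w[np.argmin(cost)] = 1.0
    return w


def sum_of_gamma(shape, kappa=5, n_iterations=10, rng=None):
    # Assumed Sum-of-Gamma noise: Gamma(shape=1/kappa, scale=kappa/i) terms.
    rng = np.random.default_rng(135) if rng is None else rng
    total = sum(rng.gamma(1.0 / kappa, kappa / i, size=shape)
                for i in range(1, n_iterations + 1))
    return (total - np.log(n_iterations)) / kappa


def imle_grad(pred_cost, grad_output, n_samples=10, sigma=1.0, lambd=10.0,
              kappa=5, n_iterations=10, two_sides=False, rng=None):
    # One noise realization, reused by every smoothed solve below.
    noise = sigma * sum_of_gamma((n_samples,) + pred_cost.shape,
                                 kappa, n_iterations, rng)

    def smoothed(cost):
        return np.mean([solve(cost + eps) for eps in noise], axis=0)

    c_plus = pred_cost + lambd * grad_output  # virtual cost update
    if two_sides:
        c_minus = pred_cost - lambd * grad_output
        return (smoothed(c_plus) - smoothed(c_minus)) / (2 * lambd)
    return (smoothed(c_plus) - smoothed(pred_cost)) / lambd


c_hat = np.array([1.0, 1.2, 3.0])
d = np.array([1.0, 0.0, 0.0])  # upstream gradient: the loss wants less of item 0
print(imle_grad(c_hat, d, n_samples=1000))
```

Sharing the noise means that a zero upstream gradient gives an exactly zero result, since both smoothed solves then see identical perturbed costs.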
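
The adaptive step of :class:`adaptiveImplicitMLE`, :math:`\lambda_t = \alpha_t \|\hat{\mathbf{c}}\| / \|\mathbf{d}\|`, is computed below for a single step. This sketch assumes Euclidean norms and takes :math:`\alpha_t` as given; it does not reproduce the online :math:`\alpha` update. ``adaptive_lambda`` and its ``eps`` guard are hypothetical.

```python
import numpy as np


def adaptive_lambda(pred_cost, grad_output, alpha=1.0, eps=1e-12):
    # lambda_t = alpha_t * ||c_hat|| / ||d||, Euclidean norms assumed;
    # eps (hypothetical) only guards against an all-zero upstream gradient.
    return alpha * np.linalg.norm(pred_cost) / max(np.linalg.norm(grad_output), eps)


c_hat = np.array([3.0, 4.0])  # ||c_hat|| = 5
d = np.array([0.0, 2.0])      # ||d|| = 2
lam = adaptive_lambda(c_hat, d)
print(lam)  # 2.5
```

With this choice, the virtual update :math:`\lambda_t \mathbf{d}` always has norm :math:`\alpha_t \|\hat{\mathbf{c}}\|`, independent of the magnitude of the upstream gradient. The interpolation step is therefore tied to the scale of the predicted costs rather than to that of the loss.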