pyepo.func.jax.perturbed

Perturbed optimization function

Attributes

Classes

perturbedOpt

Differentiable Perturbed Optimizer (DPO) -- additive-Gaussian variant.

perturbedOptMul

Differentiable Perturbed Optimizer (DPO) -- multiplicative log-normal variant.

perturbedFenchelYoung

Perturbed Fenchel-Young loss (PFYL) -- additive-Gaussian variant.

perturbedFenchelYoungMul

Perturbed Fenchel-Young loss (PFYL) -- multiplicative log-normal variant.

implicitMLE

Implicit Maximum Likelihood Estimator (I-MLE) via perturb-and-MAP.

adaptiveImplicitMLE

Adaptive Implicit MLE (AI-MLE): I-MLE with an online-tuned interpolation step.

Module Contents

class pyepo.func.jax.perturbed.perturbedOpt(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, variance_reduction=True, solve_ratio=1.0, dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

Differentiable Perturbed Optimizer (DPO) – additive-Gaussian variant.

Estimates the expected solution \(\mathbb{E}_{\boldsymbol{\xi}}[\mathbf{w}^*(\hat{\mathbf{c}} + \sigma\boldsymbol{\xi})]\) by Monte Carlo averaging over n_samples draws, which gives an informative gradient where the bare solver's gradient is zero almost everywhere. Returns a solution rather than a loss, so pair it with a task loss. forward accepts key= for explicit RNG control, which is required under jax.jit.

Reference: Berthet et al. (2020) https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html

n_samples = 10
sigma
variance_reduction = True
forward(pred_cost, key=None)

Forward pass: returns the Monte Carlo estimate of the expected perturbed solution for pred_cost.
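The Monte Carlo estimate described for perturbedOpt can be sketched in plain JAX. This is a minimal illustration, not PyEPO's implementation: toy_solver is a hypothetical stand-in for optmodel (it picks the single cheapest item), and perturbed_solution is an illustrative name. It also shows why an explicit key matters under jax.jit: the key is passed as an ordinary argument, so the jitted function stays pure.

```python
import jax
import jax.numpy as jnp


def toy_solver(cost):
    # Hypothetical stand-in for an optimization model: pick the single
    # cheapest item, i.e. minimize cost @ w over one-hot vectors w.
    return jax.nn.one_hot(jnp.argmin(cost), cost.shape[0])


def perturbed_solution(pred_cost, key, n_samples=10, sigma=1.0):
    # Draw xi ~ N(0, I), solve at each perturbed cost, and average.
    xi = jax.random.normal(key, (n_samples, pred_cost.shape[0]))
    sols = jax.vmap(toy_solver)(pred_cost + sigma * xi)
    return sols.mean(axis=0)


# n_samples fixes an array shape, so it must be static under jit;
# the key is an ordinary argument, which keeps the function pure.
jitted = jax.jit(perturbed_solution, static_argnames=("n_samples",))

key = jax.random.PRNGKey(135)
pred_cost = jnp.array([1.0, 1.2, 3.0])
w_bar = jitted(pred_cost, key, n_samples=1000, sigma=0.5)
```

Here w_bar is a soft selection over the three items: the averaged one-hot solutions sum to one, and cheaper items receive more mass.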

class pyepo.func.jax.perturbed.perturbedOptMul(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, variance_reduction=True, solve_ratio=1.0, dataset=None)

Bases: perturbedOpt

Differentiable Perturbed Optimizer (DPO) – multiplicative log-normal variant.

As perturbedOpt, but perturbs the cost multiplicatively with log-normal noise \(\exp(\sigma\boldsymbol{\xi} - \sigma^2/2)\).

Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513
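As a rough illustration of the multiplicative noise above, the sketch below scales a cost vector by \(\exp(\sigma\boldsymbol{\xi} - \sigma^2/2)\). The function name lognormal_perturb and the sample values are assumptions made for this example, not part of PyEPO. The \(-\sigma^2/2\) shift gives the factor unit mean, so the perturbed cost is unbiased, and the factor is always positive, so the sign of each cost is preserved.

```python
import jax
import jax.numpy as jnp


def lognormal_perturb(pred_cost, key, n_samples=10, sigma=1.0):
    # xi ~ N(0, I); exp(sigma * xi - sigma^2 / 2) is log-normal with mean 1.
    xi = jax.random.normal(key, (n_samples, pred_cost.shape[0]))
    return pred_cost * jnp.exp(sigma * xi - 0.5 * sigma**2)


key = jax.random.PRNGKey(0)
pred_cost = jnp.array([1.0, 2.0, 4.0])
perturbed = lognormal_perturb(pred_cost, key, n_samples=20000, sigma=0.5)

# The sample mean stays close to pred_cost and every sample is positive.
sample_mean = perturbed.mean(axis=0)
```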

class pyepo.func.jax.perturbed.perturbedFenchelYoung(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

Perturbed Fenchel-Young loss (PFYL) – additive-Gaussian variant.

Pairs a Monte Carlo estimate of the expected perturbed solution with the Fenchel-Young loss against the true optimum, returning a scalar loss whose gradient is the residual \(\mathbf{w}^*(\mathbf{c}) - \mathbb{E}_{\boldsymbol{\xi}}[\mathbf{w}^*(\hat{\mathbf{c}} + \sigma\boldsymbol{\xi})]\). forward accepts key= for explicit RNG control, which is required under jax.jit.

Reference: Berthet et al. (2020) https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html

n_samples = 10
sigma
forward(pred_cost, true_sol, key=None)

Forward pass: returns the perturbed Fenchel-Young loss of pred_cost against true_sol.
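The residual gradient described for perturbedFenchelYoung can be sketched as follows. This is an illustration under stated assumptions rather than PyEPO's code: toy_solver is a hypothetical minimizing solver over one-hot vectors, pfy_residual is an illustrative name, and the step size and iteration count are arbitrary. The loop applies plain gradient descent on the predicted cost using the residual as the gradient.

```python
import jax
import jax.numpy as jnp


def toy_solver(cost):
    # Hypothetical stand-in for optmodel: one-hot vector of the cheapest item.
    return jax.nn.one_hot(jnp.argmin(cost), cost.shape[0])


def pfy_residual(pred_cost, true_sol, key, n_samples=100, sigma=0.5):
    # Residual w*(c) - E[w*(c_hat + sigma * xi)], estimated by Monte Carlo.
    xi = jax.random.normal(key, (n_samples, pred_cost.shape[0]))
    w_bar = jax.vmap(toy_solver)(pred_cost + sigma * xi).mean(axis=0)
    return true_sol - w_bar


true_sol = jnp.array([1.0, 0.0, 0.0])  # the true optimum selects item 0
c_hat = jnp.array([3.0, 2.0, 1.0])     # the initial prediction prefers item 2
key = jax.random.PRNGKey(0)
for _ in range(20):
    # A fresh subkey per step keeps the noise draws independent.
    key, subkey = jax.random.split(key)
    c_hat = c_hat - 0.5 * pfy_residual(c_hat, true_sol, subkey)
```

Each step lowers the predicted cost of the item in true_sol and raises the cost of items the perturbed solver selects instead, so after a few iterations the cheapest predicted item is item 0.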

class pyepo.func.jax.perturbed.perturbedFenchelYoungMul(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

Bases: perturbedFenchelYoung

Perturbed Fenchel-Young loss (PFYL) – multiplicative log-normal variant.

As perturbedFenchelYoung, but perturbs the cost multiplicatively with log-normal noise \(\exp(\sigma\boldsymbol{\xi} - \sigma^2/2)\).

Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513

class pyepo.func.jax.perturbed.implicitMLE(optmodel, n_samples=10, sigma=1.0, lambd=10, kappa=5, n_iterations=10, two_sides=False, seed=135, processes=1, solve_ratio=1.0, dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

Implicit Maximum Likelihood Estimator (I-MLE) via perturb-and-MAP.

Frames decision-focused learning as imitation: an upstream gradient \(\mathbf{d}\) induces a virtual update \(\hat{\mathbf{c}} + \lambda \mathbf{d}\), and the gradient is a directional finite difference between the smoothed solutions at the updated and original costs. Both solves share one Sum-of-Gamma noise realization. forward accepts key= for explicit RNG control, which is required under jax.jit.

Reference: Niepert, Minervini & Franceschi (2021) https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html

n_samples = 10
sigma
lambd
kappa
n_iterations = 10
two_sides = False
forward(pred_cost, key=None)

Forward pass
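The finite-difference gradient described for implicitMLE can be sketched as below. Several things here are assumptions made for illustration: toy_solver is a hypothetical minimizing solver, Gaussian noise is swapped in for the Sum-of-Gamma noise the class uses, and dividing the difference by lambd is one common scaling that may not match PyEPO's. The point of the sketch is the shared noise realization across the two solves.

```python
import jax
import jax.numpy as jnp


def toy_solver(cost):
    # Hypothetical stand-in for optmodel: one-hot vector of the cheapest item.
    return jax.nn.one_hot(jnp.argmin(cost), cost.shape[0])


def imle_gradient(pred_cost, upstream_grad, key,
                  n_samples=100, sigma=0.5, lambd=10.0):
    # ASSUMPTION: Gaussian noise replaces Sum-of-Gamma noise for brevity.
    eps = sigma * jax.random.normal(key, (n_samples, pred_cost.shape[0]))
    # Virtual update of the cost along the upstream gradient d.
    target_cost = pred_cost + lambd * upstream_grad
    solve = jax.vmap(toy_solver)
    # The same eps is reused for both solves (shared noise realization).
    w_target = solve(target_cost + eps).mean(axis=0)
    w_pred = solve(pred_cost + eps).mean(axis=0)
    # ASSUMPTION: finite difference scaled by 1 / lambd.
    return (w_target - w_pred) / lambd


pred_cost = jnp.array([2.0, 1.0, 3.0])
upstream_grad = jnp.array([0.0, 1.0, 0.0])  # the task loss grows with w[1]
grad = imle_gradient(pred_cost, upstream_grad, jax.random.PRNGKey(0))
```

In this example the gradient entry for item 1 is negative, so a descent step raises that item's predicted cost and makes the solver less likely to select it.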

class pyepo.func.jax.perturbed.adaptiveImplicitMLE(optmodel, n_samples=10, sigma=1.0, kappa=5, n_iterations=10, two_sides=False, seed=135, processes=1, solve_ratio=1.0, dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

Adaptive Implicit MLE (AI-MLE): I-MLE with an online-tuned interpolation step.

Replaces I-MLE’s fixed \(\lambda\) with \(\lambda_t = \alpha_t \|\hat{\mathbf{c}}\| / \|\mathbf{d}\|\), where \(\alpha_t\) is adapted online from a moving average of the gradient sparsity. Eager-only: the \(\alpha_t\) update is a concrete side effect in the backward pass, so this class cannot be wrapped in jax.jit.

Reference: Minervini, Franceschi & Niepert (2023) https://ojs.aaai.org/index.php/AAAI/article/view/26103

n_samples = 10
sigma
kappa
n_iterations = 10
two_sides = False
alpha = 1.0
grad_norm_avg = 1.0
step = 0.001
forward(pred_cost)

Forward pass
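The adaptive step \(\lambda_t = \alpha_t \|\hat{\mathbf{c}}\| / \|\mathbf{d}\|\) can be written out as a small worked example. The function name adaptive_lambda and the eps guard against a zero upstream gradient are assumptions for this sketch; the online update of \(\alpha_t\) itself is not shown because the text above does not specify its rule.

```python
import jax.numpy as jnp


def adaptive_lambda(pred_cost, upstream_grad, alpha, eps=1e-12):
    # lambda_t = alpha_t * ||c_hat|| / ||d||; eps (an assumption) avoids
    # division by zero when the upstream gradient vanishes.
    c_norm = jnp.linalg.norm(pred_cost)
    d_norm = jnp.linalg.norm(upstream_grad)
    return alpha * c_norm / (d_norm + eps)


pred_cost = jnp.array([3.0, 4.0])      # ||c_hat|| = 5
upstream_grad = jnp.array([0.0, 2.0])  # ||d|| = 2
lam = adaptive_lambda(pred_cost, upstream_grad, alpha=1.0)  # 5 / 2 = 2.5
```

Normalizing by both norms keeps the size of the virtual update \(\lambda_t \mathbf{d}\) proportional to the scale of the predicted cost, regardless of how large the upstream gradient is.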

pyepo.func.jax.perturbed.DPO
pyepo.func.jax.perturbed.DPOMul
pyepo.func.jax.perturbed.PFY
pyepo.func.jax.perturbed.PFYMul
pyepo.func.jax.perturbed.IMLE
pyepo.func.jax.perturbed.AIMLE