pyepo.func.jax.perturbed
Perturbed optimization function
Attributes
DPO, DPOMul, PFY, PFYMul, IMLE, AIMLE
Classes
perturbedOpt | Differentiable Perturbed Optimizer (DPO) -- additive-Gaussian variant.
perturbedOptMul | Differentiable Perturbed Optimizer (DPO) -- multiplicative log-normal variant.
perturbedFenchelYoung | Perturbed Fenchel-Young loss (PFYL) -- additive-Gaussian variant.
perturbedFenchelYoungMul | Perturbed Fenchel-Young loss (PFYL) -- multiplicative log-normal variant.
implicitMLE | Implicit Maximum Likelihood Estimator (I-MLE) via perturb-and-MAP.
adaptiveImplicitMLE | Adaptive Implicit MLE (AI-MLE): I-MLE with an online-tuned interpolation step.
Module Contents
- class pyepo.func.jax.perturbed.perturbedOpt(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, variance_reduction=True, solve_ratio=1.0, dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
Differentiable Perturbed Optimizer (DPO) -- additive-Gaussian variant.
Estimates the expected solution \(\mathbb{E}_{\boldsymbol{\xi}}[\mathbf{w}^*(\hat{\mathbf{c}} + \sigma\boldsymbol{\xi})]\), with \(\boldsymbol{\xi} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\), by Monte Carlo averaging over n_samples noise draws. The optimal solution of a linear objective is piecewise constant in the cost vector, so the bare solver has a zero gradient almost everywhere; the expectation is smooth in \(\hat{\mathbf{c}}\) and gives an informative gradient. The module returns a solution, not a loss, so pair it with a task loss. forward accepts key= for explicit RNG control; the key is required under jax.jit.
Reference: Berthet et al. (2020) https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html
- n_samples = 10
- sigma
- variance_reduction = True
- forward(pred_cost, key=None)
Forward pass: returns the Monte Carlo estimate of the expected perturbed solution for pred_cost.
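For intuition, the Monte Carlo estimate above can be sketched in plain JAX. This is not PyEPO's implementation: solver is a hypothetical stand-in for optmodel that picks the cheapest of n items, and expected_solution and jacobian_estimate are illustrative names. jacobian_estimate uses the Gaussian score-function identity from Berthet et al. (2020); whether perturbedOpt computes its backward exactly this way (for example, what variance_reduction changes) is an assumption not settled here. Passing the key explicitly makes the jitted call reproducible.

```python
from functools import partial

import jax
import jax.numpy as jnp


def solver(cost):
    # Hypothetical toy solver standing in for optmodel: minimize <cost, w> over
    # the n one-hot vectors, i.e. pick the cheapest item.
    return jax.nn.one_hot(jnp.argmin(cost, axis=-1), cost.shape[-1])


@partial(jax.jit, static_argnames="n_samples")
def expected_solution(pred_cost, key, n_samples=10, sigma=1.0):
    # One standard Gaussian draw xi_k per sample, same shape as pred_cost.
    xi = jax.random.normal(key, (n_samples,) + pred_cost.shape)
    # Solve each perturbed problem, then average:
    # Monte Carlo estimate of E_xi[w*(c_hat + sigma * xi)].
    return solver(pred_cost[None] + sigma * xi).mean(axis=0)


def jacobian_estimate(pred_cost, key, n_samples=10, sigma=1.0):
    # Gaussian score-function estimator (Berthet et al., 2020), for a 1-D pred_cost:
    # J[i, j] ~= mean_k w*_i(c_hat + sigma * xi_k) * xi_k[j] / sigma
    xi = jax.random.normal(key, (n_samples,) + pred_cost.shape)
    sols = solver(pred_cost[None] + sigma * xi)
    return jnp.einsum("ki,kj->ij", sols, xi) / (n_samples * sigma)


key = jax.random.PRNGKey(0)
c_hat = jnp.array([3.0, 1.0, 2.0])
print(expected_solution(c_hat, key, n_samples=100))  # a probability vector over the 3 items
print(jacobian_estimate(c_hat, key, n_samples=100).shape)  # (3, 3)
```

Reusing the same key returns the same estimate, which is why an explicit key is needed once the function is traced by jax.jit.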
- class pyepo.func.jax.perturbed.perturbedOptMul(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, variance_reduction=True, solve_ratio=1.0, dataset=None)
Bases: perturbedOpt
Differentiable Perturbed Optimizer (DPO) -- multiplicative log-normal variant.
As perturbedOpt, but perturbs the cost multiplicatively with log-normal noise \(\exp(\sigma\boldsymbol{\xi} - \sigma^2/2)\). The \(-\sigma^2/2\) shift gives the multiplier unit mean, and because the multiplier is always positive, every cost entry keeps its sign.
Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513
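The multiplicative perturbation can be sketched in a few lines of JAX. lognormal_perturb is a hypothetical helper, and the sketch assumes the multiplier is applied entry by entry to pred_cost; solving each perturbed cost and averaging would give the multiplicative counterpart of the expected solution.

```python
import jax
import jax.numpy as jnp


def lognormal_perturb(pred_cost, key, n_samples=10, sigma=1.0):
    # Multiplicative noise c_hat * exp(sigma * xi - sigma**2 / 2), xi ~ N(0, I).
    # Since E[exp(sigma * xi)] = exp(sigma**2 / 2), the multiplier has mean 1, and
    # because it is always positive, every entry of c_hat keeps its sign.
    xi = jax.random.normal(key, (n_samples,) + pred_cost.shape)
    return pred_cost[None] * jnp.exp(sigma * xi - sigma**2 / 2)


key = jax.random.PRNGKey(0)
samples = lognormal_perturb(jnp.array([-2.0, 3.0]), key, n_samples=5, sigma=0.5)
print(samples.shape)  # (5, 2); the first column stays negative, the second positive
```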
- class pyepo.func.jax.perturbed.perturbedFenchelYoung(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
Perturbed Fenchel-Young loss (PFYL) -- additive-Gaussian variant.
Combines a Monte Carlo estimate of the expected perturbed solution with the Fenchel-Young loss against the true optimal solution, returning a scalar loss whose gradient with respect to \(\hat{\mathbf{c}}\) is the residual \(\mathbf{w}^*(\mathbf{c}) - \mathbb{E}_{\boldsymbol{\xi}}[\mathbf{w}^*(\hat{\mathbf{c}} + \sigma\boldsymbol{\xi})]\). The gradient vanishes when the expected perturbed solution matches the true optimum, and unlike perturbedOpt, no separate task loss is needed. forward accepts key= for explicit RNG control; the key is required under jax.jit.
Reference: Berthet et al. (2020) https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html
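- n_samples = 10
- sigma
- forward(pred_cost, true_sol, key=None)
Forward pass: returns the loss of pred_cost against the true optimal solution true_sol, aggregated according to reduction.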
- class pyepo.func.jax.perturbed.perturbedFenchelYoungMul(optmodel, n_samples=10, sigma=1.0, processes=1, seed=135, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)
Bases: perturbedFenchelYoung
Perturbed Fenchel-Young loss (PFYL) -- multiplicative log-normal variant.
As perturbedFenchelYoung, but perturbs the cost multiplicatively with log-normal noise \(\exp(\sigma\boldsymbol{\xi} - \sigma^2/2)\).
Reference: Dalle et al. (2022) https://arxiv.org/abs/2207.13513
- class pyepo.func.jax.perturbed.implicitMLE(optmodel, n_samples=10, sigma=1.0, lambd=10, kappa=5, n_iterations=10, two_sides=False, seed=135, processes=1, solve_ratio=1.0, dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
Implicit Maximum Likelihood Estimator (I-MLE) via perturb-and-MAP.
Frames decision-focused learning as imitation: the upstream gradient \(\mathbf{d}\) induces a virtual target cost \(\hat{\mathbf{c}} + \lambda \mathbf{d}\) (with \(\lambda\) = lambd), and the gradient is a directional finite difference between the smoothed solutions at the target and original costs. Both smoothed solutions share one realization of Sum-of-Gamma noise, whose shape is set by kappa and n_iterations and whose scale is set by sigma. forward accepts key= for explicit RNG control; the key is required under jax.jit.
Reference: Niepert, Minervini & Franceschi (2021) https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html
- n_samples = 10
- sigma
- lambd
- kappa
- n_iterations = 10
- two_sides = False
- forward(pred_cost, key=None)
Forward pass: returns the smoothed solution for pred_cost, averaged over n_samples Sum-of-Gamma perturbations.
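The two ingredients described above can be sketched in plain JAX: a Sum-of-Gamma sampler following the construction in Niepert et al. (2021) with temperature 1, and a jax.custom_vjp whose backward takes the finite difference while reusing the forward noise. solver, sum_of_gamma, smoothed, and imle_solve are hypothetical names. The sketch assumes a minimization problem, a one-sided difference divided by lambd, and noise passed in explicitly instead of a key; scaling details and the two_sides option may differ in pyepo.

```python
from functools import partial

import jax
import jax.numpy as jnp


def solver(cost):
    # Hypothetical toy solver: minimize <cost, w> over one-hot vectors w.
    return jax.nn.one_hot(jnp.argmin(cost, axis=-1), cost.shape[-1])


def sum_of_gamma(key, shape, kappa=5, n_iterations=10):
    # Sum-of-Gamma noise with temperature 1:
    # (1 / kappa) * (sum_{i=1}^{s} Gamma(shape=1/kappa, scale=kappa/i) - log s)
    keys = jax.random.split(key, n_iterations)
    total = jnp.zeros(shape)
    for i in range(1, n_iterations + 1):
        total = total + (kappa / i) * jax.random.gamma(keys[i - 1], 1.0 / kappa, shape)
    return (total - jnp.log(n_iterations)) / kappa


def smoothed(cost, eps, sigma):
    # Average of solver outputs over the noise samples eps[k] (shape: n_samples x n).
    return solver(cost[None] + sigma * eps).mean(axis=0)


@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def imle_solve(cost, eps, sigma, lambd):
    return smoothed(cost, eps, sigma)


def _imle_fwd(cost, eps, sigma, lambd):
    # Save the cost and the noise so the backward pass reuses the same realization.
    return smoothed(cost, eps, sigma), (cost, eps)


def _imle_bwd(sigma, lambd, res, d):
    cost, eps = res
    # Finite difference between the virtual target cost c_hat + lambd * d and c_hat.
    grad = (smoothed(cost + lambd * d, eps, sigma) - smoothed(cost, eps, sigma)) / lambd
    # The noise is treated as a constant input, so its cotangent is zero.
    return grad, jnp.zeros_like(eps)


imle_solve.defvjp(_imle_fwd, _imle_bwd)

eps = sum_of_gamma(jax.random.PRNGKey(0), (10, 3))
c_hat = jnp.array([3.0, 1.0, 2.0])
# Downstream loss that penalizes choosing item 1.
grad = jax.grad(lambda c: imle_solve(c, eps, 0.01, 10.0)[1])(c_hat)
print(grad)  # approximately [0, -0.1, 0.1]: a descent step raises the cost of item 1
```

Sharing eps between the two smoothed solutions means the difference reflects only the cost change, not fresh sampling noise.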
- class pyepo.func.jax.perturbed.adaptiveImplicitMLE(optmodel, n_samples=10, sigma=1.0, kappa=5, n_iterations=10, two_sides=False, seed=135, processes=1, solve_ratio=1.0, dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
Adaptive Implicit MLE (AI-MLE): I-MLE with an online-tuned interpolation step.
Replaces I-MLE's fixed lambd with \(\lambda_t = \alpha_t \|\hat{\mathbf{c}}\| / \|\mathbf{d}\|\), where \(\alpha_t\) is adapted online from a moving average of the gradient sparsity. The virtual step \(\lambda_t \mathbf{d}\) therefore has norm \(\alpha_t \|\hat{\mathbf{c}}\|\), independent of the scale of the upstream gradient. Eager-only: the alpha update is a concrete side effect in the backward pass, so this module is not jax.jit-able and its forward takes no key argument.
Reference: Minervini, Franceschi & Niepert (2023) https://ojs.aaai.org/index.php/AAAI/article/view/26103
- n_samples = 10
- sigma
- kappa
- n_iterations = 10
- two_sides = False
- alpha = 1.0
- grad_norm_avg = 1.0
- step = 0.001
- forward(pred_cost)
Forward pass: returns the smoothed solution for pred_cost, as in implicitMLE.
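The adaptive step size can be checked with a few lines of JAX. adaptive_lambda is a hypothetical helper; the sketch assumes L2 norms over the whole cost and gradient arrays (whether pyepo takes norms per instance or per batch is not stated here) and adds a small guard against an all-zero upstream gradient. It leaves out the online update of alpha_t.

```python
import jax.numpy as jnp


def adaptive_lambda(pred_cost, d, alpha, tiny=1e-12):
    # lambda_t = alpha_t * ||c_hat|| / ||d|| (L2 norms over the whole array).
    # tiny is a hypothetical guard against an all-zero upstream gradient d.
    return alpha * jnp.linalg.norm(pred_cost) / jnp.maximum(jnp.linalg.norm(d), tiny)


c_hat = jnp.array([3.0, 4.0])  # ||c_hat|| = 5
d = jnp.array([0.0, 0.5])      # ||d|| = 0.5
lam = adaptive_lambda(c_hat, d, alpha=1.0)
print(lam)  # 10.0
# The virtual step lambda_t * d has norm alpha_t * ||c_hat||, whatever the scale of d.
print(jnp.linalg.norm(lam * d))  # 5.0
```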
- pyepo.func.jax.perturbed.DPO¶
- pyepo.func.jax.perturbed.DPOMul¶
- pyepo.func.jax.perturbed.PFY¶
- pyepo.func.jax.perturbed.PFYMul¶
- pyepo.func.jax.perturbed.IMLE¶
- pyepo.func.jax.perturbed.AIMLE¶