pyepo.func.jax.contrastive
Noise contrastive estimation loss function
Attributes
NCE
CMAP
Classes
noiseContrastiveEstimation | Noise Contrastive Estimation (NCE): contrastive loss against a cached pool of solutions.
contrastiveMAP | Contrastive Maximum-a-Posteriori (CMAP): max-margin special case of NCE.
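The relationship between the two losses can be written out for a single instance. The notation here is assumed for illustration and does not come from this page: \hat{c} is the predicted cost vector, w^* the true optimal solution, and \Gamma the cached solution pool. For a minimization problem, NCE averages the margin over the pool, while CMAP takes its maximum, which is attained at the pool member with the smallest predicted-cost objective (presumably the sign of each difference is reversed for maximization):

```latex
\ell_{\mathrm{NCE}}(\hat{c}, w^*) = \frac{1}{|\Gamma|} \sum_{s \in \Gamma} \left( \hat{c}^\top w^* - \hat{c}^\top s \right)

\ell_{\mathrm{CMAP}}(\hat{c}, w^*) = \max_{s \in \Gamma} \left( \hat{c}^\top w^* - \hat{c}^\top s \right)
                                  = \hat{c}^\top w^* - \min_{s \in \Gamma} \hat{c}^\top s
```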
Module Contents
- class pyepo.func.jax.contrastive.noiseContrastiveEstimation(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
Noise Contrastive Estimation (NCE): contrastive loss against a cached pool of solutions.
For each instance, averages the predicted-cost margin between the true optimal solution and every solution in the pool.
Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390
- forward(pred_cost, true_sol)
Forward pass: computes the loss from the predicted costs and the true optimal solutions.
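A minimal, framework-agnostic sketch of the NCE computation described above, written in plain Python rather than JAX. It is not the pyepo.func.jax implementation: the function names `nce_loss` and `dot` are hypothetical, the pool is a fixed list passed in by hand (how the library builds and refreshes its cached pool is not modeled), only a minimization problem is assumed, and only the `"mean"` and `"sum"` reductions are handled.

```python
from typing import List, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def nce_loss(pred_cost: List[Vector], true_sol: List[Vector],
             pool: List[Vector], reduction: str = "mean") -> float:
    """Hypothetical NCE sketch for a minimization problem.

    For each instance, averages (c_hat . w_star - c_hat . s) over every
    solution s in the pool, then reduces the per-instance values over the batch.
    """
    if not pool:
        raise ValueError("pool must be non-empty")
    per_instance = []
    for c_hat, w_star in zip(pred_cost, true_sol):
        # Predicted-cost objective of the true optimal solution.
        obj_true = dot(c_hat, w_star)
        # Margin against every pool member; negative means w_star already wins.
        margins = [obj_true - dot(c_hat, s) for s in pool]
        per_instance.append(sum(margins) / len(margins))
    if reduction == "mean":
        return sum(per_instance) / len(per_instance)
    if reduction == "sum":
        return sum(per_instance)
    raise ValueError(f"unsupported reduction: {reduction!r}")


if __name__ == "__main__":
    pred = [[1.0, 2.0], [3.0, 1.0]]   # two instances, two decision variables
    true = [[1.0, 0.0], [1.0, 0.0]]   # true optimal solution of each instance
    pool = [[1.0, 0.0], [0.0, 1.0]]   # cached feasible solutions
    print(nce_loss(pred, true, pool))                   # 0.25
    print(nce_loss(pred, true, pool, reduction="sum"))  # 0.5
```

Averaging over the whole pool gives a gradient signal from every cached solution at once, at the cost of diluting the contribution of the few solutions that actually beat the true optimum under the predicted costs.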
- class pyepo.func.jax.contrastive.contrastiveMAP(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)
Bases: pyepo.func.jax.abcmodule.optModule
Contrastive Maximum-a-Posteriori (CMAP): max-margin special case of NCE.
Instead of averaging over the pool, keeps only the most-violating pool member as the negative (for a minimization problem, the one with the smallest predicted-cost objective).
Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390
- forward(pred_cost, true_sol)
Forward pass: computes the loss from the predicted costs and the true optimal solutions.
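The CMAP variant can be sketched the same way, again in plain Python and not as the pyepo.func.jax implementation. The names `cmap_loss` and `dot` are hypothetical, the pool is passed in by hand, and a minimization problem is assumed. Taking the maximum margin over the pool is the same as subtracting the smallest predicted-cost objective among pool members, which is what the sketch computes directly.

```python
from typing import List, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def cmap_loss(pred_cost: List[Vector], true_sol: List[Vector],
              pool: List[Vector], reduction: str = "mean") -> float:
    """Hypothetical CMAP sketch for a minimization problem.

    For each instance, uses only the most-violating pool member, i.e. the one
    with the smallest predicted-cost objective, as the negative.
    """
    if not pool:
        raise ValueError("pool must be non-empty")
    per_instance = []
    for c_hat, w_star in zip(pred_cost, true_sol):
        obj_true = dot(c_hat, w_star)
        # max_s (obj_true - c_hat . s) == obj_true - min_s c_hat . s
        obj_best_pool = min(dot(c_hat, s) for s in pool)
        per_instance.append(obj_true - obj_best_pool)
    if reduction == "mean":
        return sum(per_instance) / len(per_instance)
    if reduction == "sum":
        return sum(per_instance)
    raise ValueError(f"unsupported reduction: {reduction!r}")


if __name__ == "__main__":
    pred = [[1.0, 2.0], [3.0, 1.0]]
    true = [[1.0, 0.0], [1.0, 0.0]]
    pool = [[1.0, 0.0], [0.0, 1.0]]
    print(cmap_loss(pred, true, pool))                   # 1.0
    print(cmap_loss(pred, true, pool, reduction="sum"))  # 2.0
```

On the same toy data the second instance contributes 2.0 here but only 1.0 under NCE, because CMAP ignores the non-violating pool member instead of averaging it in.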
- pyepo.func.jax.contrastive.NCE
- pyepo.func.jax.contrastive.CMAP