pyepo.func.jax.contrastive

Noise contrastive estimation loss functions

Attributes

NCE

CMAP

Classes

noiseContrastiveEstimation

Noise Contrastive Estimation (NCE): contrastive loss against a cached pool.

contrastiveMAP

Contrastive Maximum-a-Posteriori (CMAP): max-margin special case of NCE.

Module Contents

class pyepo.func.jax.contrastive.noiseContrastiveEstimation(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

Noise Contrastive Estimation (NCE): contrastive loss against a cached pool.

For each instance, averages the margin between the objective value of the true optimal solution under the predicted costs and that of every solution in the pool.
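Written out for a minimization problem, with notation introduced here only for illustration ($\hat{c}$ is the predicted cost vector, $w^*$ the true optimal solution, and $\Gamma$ the cached pool), the per-instance NCE loss reads as follows; under maximization each margin flips sign.

$$
\ell_{\mathrm{NCE}}(\hat{c}, w^*) = \frac{1}{|\Gamma|} \sum_{w \in \Gamma} \left( \hat{c}^\top w^* - \hat{c}^\top w \right)
$$

Minimizing this pushes the predicted objective of $w^*$ below the average predicted objective of the pool, so the true optimum looks good under $\hat{c}$ without solving the optimization problem for every instance.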

Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390
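This page describes the pool only as cached. A minimal sketch of one plausible caching step follows, modeled on the solution caching in Mulamba et al. (2021). It assumes, without confirmation from this page, that solve_ratio is the probability of solving with the predicted costs on a given batch and that duplicate solutions are discarded. maybe_grow_pool is a hypothetical helper, not part of pyepo, and new_sols stands in for whatever the solver returns.

```python
import jax
import jax.numpy as jnp


def maybe_grow_pool(key, pool, new_sols, solve_ratio):
    """Hypothetical sketch of solution caching (not the library's code).

    With probability `solve_ratio`, add freshly solved solutions to the pool
    and drop duplicate rows; otherwise return the pool unchanged.
    """
    # One Bernoulli(solve_ratio) draw per batch: uniform lies in [0, 1)
    if jax.random.uniform(key) >= solve_ratio:
        return pool
    # Stack cached and new solutions, keeping only distinct rows
    return jnp.unique(jnp.concatenate([pool, new_sols], axis=0), axis=0)


# Example: [1, 0] is already cached, so only [1, 1] is new
key = jax.random.PRNGKey(0)
pool = jnp.array([[1.0, 0.0], [0.0, 1.0]])
new_sols = jnp.array([[1.0, 0.0], [1.0, 1.0]])
grown = maybe_grow_pool(key, pool, new_sols, solve_ratio=1.0)  # always solves
kept = maybe_grow_pool(key, pool, new_sols, solve_ratio=0.0)   # never solves
```

Skipping the solver on most batches is what makes the contrastive losses cheap: the pool substitutes for fresh optimal solutions under the predicted costs.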

forward(pred_cost, true_sol)

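Forward pass: computes the NCE loss of pred_cost against the true optimal solutions true_sol and the cached solution pool.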
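A standalone JAX sketch of the NCE computation for a batch follows. It assumes a single pool shared by all instances and reduction values "mean", "sum", and "none". The function nce_loss and its explicit pool argument are hypothetical; the class keeps its pool internally, and its forward takes only pred_cost and true_sol.

```python
import jax.numpy as jnp


def nce_loss(pred_cost, true_sol, pool, minimize=True, reduction="mean"):
    """Hypothetical sketch of the NCE loss against a shared solution pool.

    pred_cost: (batch, d) predicted cost vectors
    true_sol:  (batch, d) true optimal solutions
    pool:      (n, d) cached feasible solutions shared by the batch
    """
    # Predicted objective of each true optimum: shape (batch, 1)
    obj_true = jnp.einsum("bd,bd->b", pred_cost, true_sol)[:, None]
    # Predicted objective of every pool member: shape (batch, n)
    obj_pool = jnp.einsum("bd,nd->bn", pred_cost, pool)
    # Margin per pool member; the sign flips for maximization
    margin = obj_true - obj_pool if minimize else obj_pool - obj_true
    # Average over the pool: one loss value per instance
    per_instance = margin.mean(axis=1)
    # Reduction names are assumptions, not the library's Reduction values
    if reduction == "mean":
        return per_instance.mean()
    if reduction == "sum":
        return per_instance.sum()
    if reduction == "none":
        return per_instance
    raise ValueError(f"unknown reduction: {reduction!r}")


pred_cost = jnp.array([[1.0, 2.0, 3.0]])
true_sol = jnp.array([[1.0, 0.0, 0.0]])
pool = jnp.eye(3)  # the three unit vectors
loss = nce_loss(pred_cost, true_sol, pool)
```

Here the margins are 0, -1, and -2, so the mean loss is -1.0: the true optimum already beats the pool on average under the predicted costs.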

class pyepo.func.jax.contrastive.contrastiveMAP(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

Bases: pyepo.func.jax.abcmodule.optModule

Contrastive Maximum-a-Posteriori (CMAP): max-margin special case of NCE.

Keeps only the most-violating pool member as the negative: for a minimization problem, the pool solution with the smallest objective value under the predicted costs.
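Writing $\hat{c}$ for the predicted cost vector, $w^*$ for the true optimal solution, and $\Gamma$ for the cached pool (notation introduced here for illustration), and assuming minimization, the per-instance CMAP loss replaces the average over the pool with a maximum:

$$
\ell_{\mathrm{CMAP}}(\hat{c}, w^*) = \max_{w \in \Gamma} \left( \hat{c}^\top w^* - \hat{c}^\top w \right) = \hat{c}^\top w^* - \min_{w \in \Gamma} \hat{c}^\top w
$$

Because a maximum is never smaller than a mean, the CMAP loss of an instance is at least its NCE loss over the same pool.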

Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390

forward(pred_cost, true_sol)

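Forward pass: computes the CMAP loss of pred_cost against the true optimal solutions true_sol, using the most-violating member of the cached solution pool.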
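A standalone JAX sketch of the CMAP computation for a batch follows, under the same assumptions as a plain NCE sketch: one pool shared by all instances and reduction values "mean", "sum", and "none". cmap_loss and its explicit pool argument are hypothetical, and the maximization branch is an inference from the minimization description rather than something this page states.

```python
import jax.numpy as jnp


def cmap_loss(pred_cost, true_sol, pool, minimize=True, reduction="mean"):
    """Hypothetical sketch of the CMAP loss against a shared solution pool.

    pred_cost: (batch, d) predicted cost vectors
    true_sol:  (batch, d) true optimal solutions
    pool:      (n, d) cached feasible solutions shared by the batch
    """
    # Predicted objective of each true optimum: shape (batch,)
    obj_true = jnp.einsum("bd,bd->b", pred_cost, true_sol)
    # Predicted objective of every pool member: shape (batch, n)
    obj_pool = jnp.einsum("bd,nd->bn", pred_cost, pool)
    if minimize:
        # Most-violating negative has the smallest predicted objective
        per_instance = obj_true - obj_pool.min(axis=1)
    else:
        # Under maximization it has the largest predicted objective
        per_instance = obj_pool.max(axis=1) - obj_true
    # Reduction names are assumptions, not the library's Reduction values
    if reduction == "mean":
        return per_instance.mean()
    if reduction == "sum":
        return per_instance.sum()
    if reduction == "none":
        return per_instance
    raise ValueError(f"unknown reduction: {reduction!r}")


# The true optimum [1, 0, 0] is the worst pool member under these costs
pred_cost = jnp.array([[3.0, 2.0, 1.0]])
true_sol = jnp.array([[1.0, 0.0, 0.0]])
pool = jnp.eye(3)
loss = cmap_loss(pred_cost, true_sol, pool)
```

The predicted objectives of the pool are 3, 2, and 1, so the loss is 3 - 1 = 2.0, twice the NCE average margin of 1.0 for the same inputs. Only the single worst negative contributes a gradient, which is what makes CMAP the max-margin variant.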

pyepo.func.jax.contrastive.NCE
pyepo.func.jax.contrastive.CMAP