pyepo.func.jax.contrastive
==========================

.. py:module:: pyepo.func.jax.contrastive

.. autoapi-nested-parse::

   Noise contrastive estimation loss functions.


Attributes
----------

.. autoapisummary::

   pyepo.func.jax.contrastive.NCE
   pyepo.func.jax.contrastive.CMAP


Classes
-------

.. autoapisummary::

   pyepo.func.jax.contrastive.noiseContrastiveEstimation
   pyepo.func.jax.contrastive.contrastiveMAP


Module Contents
---------------

.. py:class:: noiseContrastiveEstimation(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`

   Noise Contrastive Estimation (NCE): a contrastive loss computed against a cached pool of solutions.

   The loss averages the predicted-cost margin between the true optimal solution and every member of the pool.

   Reference: Mulamba et al. (2021) ``_

   .. py:method:: forward(pred_cost, true_sol)

      Forward pass: compute the NCE loss for ``pred_cost`` against ``true_sol``.


.. py:class:: contrastiveMAP(optmodel, processes=1, solve_ratio=1.0, reduction: pyepo.func.runtime.Reduction = 'mean', dataset=None)

   Bases: :py:obj:`pyepo.func.jax.abcmodule.optModule`

   Contrastive Maximum-a-Posteriori (CMAP): the max-margin special case of NCE.

   Only the most-violating pool member (the one with the smallest predicted-cost objective) is kept as the negative example.

   Reference: Mulamba et al. (2021) ``_

   .. py:method:: forward(pred_cost, true_sol)

      Forward pass: compute the CMAP loss for ``pred_cost`` against ``true_sol``.


.. py:data:: NCE

.. py:data:: CMAP
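As a rough illustration of the arithmetic in the two class descriptions, the following sketch computes both losses with ``jax.numpy`` for a minimization problem. It is not PyEPO's implementation: the function names ``nce_loss`` and ``cmap_loss``, the ``sol_pool`` array of shape ``(P, n)``, and the batch-mean reduction are hypothetical choices made for illustration only.

.. code-block:: python

   import jax
   import jax.numpy as jnp


   def nce_loss(pred_cost, true_sol, sol_pool):
       """Hypothetical NCE sketch (minimization assumed).

       pred_cost: (B, n) predicted cost vectors
       true_sol:  (B, n) true optimal solutions
       sol_pool:  (P, n) cached feasible solutions (assumed pool layout)
       """
       # Predicted-cost objective of each true optimum: shape (B,)
       true_obj = jnp.sum(pred_cost * true_sol, axis=1)
       # Predicted-cost objective of every pool member: shape (B, P)
       pool_obj = pred_cost @ sol_pool.T
       # Average the margin over the pool, then take the batch mean
       return jnp.mean(jnp.mean(true_obj[:, None] - pool_obj, axis=1))


   def cmap_loss(pred_cost, true_sol, sol_pool):
       """Hypothetical CMAP sketch: keep only the most-violating pool member."""
       true_obj = jnp.sum(pred_cost * true_sol, axis=1)
       pool_obj = pred_cost @ sol_pool.T
       # Under minimization, the most-violating member has the smallest objective
       return jnp.mean(true_obj - jnp.min(pool_obj, axis=1))


   # Toy instance: one cost vector, three one-hot solutions in the pool
   pred_cost = jnp.array([[1.0, 2.0, 3.0]])
   true_sol = jnp.array([[0.0, 1.0, 0.0]])
   sol_pool = jnp.eye(3)

   print(float(nce_loss(pred_cost, true_sol, sol_pool)))   # -> 0.0
   print(float(cmap_loss(pred_cost, true_sol, sol_pool)))  # -> 1.0

   # The gradient w.r.t. pred_cost needs no solver call once the pool is fixed;
   # for CMAP it equals true_sol minus the selected pool member, i.e. [[-1, 1, 0]]
   grad = jax.grad(cmap_loss)(pred_cost, true_sol, sol_pool)

Because the pool is fixed between solver calls, both losses are linear in ``pred_cost`` (apart from the ``min`` in CMAP), which is what makes this family cheap to differentiate.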