pyepo.func.contrastive

Noise contrastive estimation loss functions

Classes

noiseContrastiveEstimation

Noise Contrastive Estimation (NCE) -- contrastive loss against a cached solution pool.

contrastiveMAP

Contrastive Maximum-a-Posteriori (CMAP) -- max-margin special case of NCE.

Module Contents

class pyepo.func.contrastive.noiseContrastiveEstimation(optmodel: pyepo.model.opt.optModel, processes: int = 1, solve_ratio: float = 1.0, reduction: pyepo.func.abcmodule.Reduction = 'mean', dataset: pyepo.data.dataset.optDataset | None = None)

Bases: pyepo.func.abcmodule.optModule

Noise Contrastive Estimation (NCE) – contrastive loss against a cached solution pool.

Averages the predicted-cost margin between the true optimal solution and every member of the cached solution pool \(\Gamma\) (written here for a minimization problem): \(\mathcal{L} = \tfrac{1}{|\Gamma|}\sum_{\mathbf{w} \in \Gamma} (\hat{\mathbf{c}}^\top \mathbf{w}^*(\mathbf{c}) - \hat{\mathbf{c}}^\top \mathbf{w})\). Because the loss is linear in \(\hat{\mathbf{c}}\), its gradient \(\mathbf{w}^*(\mathbf{c}) - \tfrac{1}{|\Gamma|}\sum_{\mathbf{w} \in \Gamma} \mathbf{w}\) is available in closed form and the backward pass makes no solver calls. The only solver work is the occasional pool refresh, so that is what dominates the per-step cost. Setting solve_ratio < 1 refreshes the pool less often; at construction the pool is seeded with the solutions stored in dataset.
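The loss and its closed-form gradient can be checked numerically. The sketch below is a minimal NumPy illustration, not PyEPO's implementation: it assumes a minimization problem, a batch of predicted costs of shape (B, d), a pool of shape (n, d), and the 'mean' reduction over the batch. The names nce_loss and nce_grad are hypothetical.

```python
import numpy as np

def nce_loss(pred_cost, true_sol, pool):
    """NCE loss for a minimization problem with 'mean' reduction over the batch.

    pred_cost: (B, d) predicted costs c_hat
    true_sol:  (B, d) true optimal solutions w*(c)
    pool:      (n, d) cached solutions Gamma
    """
    obj_true = np.einsum("bd,bd->b", pred_cost, true_sol)      # c_hat^T w*            -> (B,)
    obj_pool = pred_cost @ pool.T                               # c_hat^T w, w in Gamma -> (B, n)
    per_instance = (obj_true[:, None] - obj_pool).mean(axis=1)  # average margin over Gamma
    return per_instance.mean()                                  # mean over the batch

def nce_grad(pred_cost, true_sol, pool):
    """Closed-form gradient of nce_loss w.r.t. pred_cost: (w* - mean(Gamma)) / B.

    It does not depend on pred_cost at all, so no solver call is needed here.
    """
    return (true_sol - pool.mean(axis=0)) / pred_cost.shape[0]
```

Since the gradient involves only w* and the pool mean, the backward pass reduces to a subtraction regardless of how expensive the underlying optimization problem is.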

Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390

forward(pred_cost: torch.Tensor, true_sol: torch.Tensor) → torch.Tensor

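Forward pass: computes the loss from a batch of predicted costs pred_cost and true optimal solutions true_sol, aggregated according to reduction.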
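The pool refresh that solve_ratio controls can be pictured as a small caching step. This is a hedged sketch of solution caching in general, not PyEPO's code: it assumes one random draw per batch that triggers solving with probability solve_ratio, uses a hypothetical brute_force_solve that enumerates a tiny feasible set in place of the real optmodel solver, and deduplicates the pool with np.unique. The function names are illustrative only.

```python
import numpy as np

def brute_force_solve(cost, feasible):
    """Hypothetical stand-in for a solver: argmin of cost^T w over an enumerated feasible set."""
    return feasible[np.argmin(feasible @ cost)]

def refresh_pool(pool, pred_cost, feasible, solve_ratio, rng):
    """With probability solve_ratio, solve every instance with the predicted costs
    and add the resulting solutions to the cached pool (assumed behaviour)."""
    if rng.uniform() > solve_ratio:
        return pool  # skip the solver this step and reuse the cache as-is
    new_sols = np.stack([brute_force_solve(c, feasible) for c in pred_cost])
    merged = np.concatenate([pool, new_sols], axis=0)
    return np.unique(merged, axis=0)  # keep each distinct solution once
```

Steps that skip the solver still produce a valid loss, because the loss only needs the current pool; lowering solve_ratio trades pool freshness for fewer solver calls.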

class pyepo.func.contrastive.contrastiveMAP(optmodel: pyepo.model.opt.optModel, processes: int = 1, solve_ratio: float = 1.0, reduction: pyepo.func.abcmodule.Reduction = 'mean', dataset: pyepo.data.dataset.optDataset | None = None)

Bases: pyepo.func.abcmodule.optModule

Contrastive Maximum-a-Posteriori (CMAP) – max-margin special case of NCE.

Keeps only the most-violating member of the cached pool \(\Gamma\), i.e. the one with the smallest predicted-cost objective, as the negative (again written for a minimization problem): \(\mathcal{L} = \hat{\mathbf{c}}^\top \mathbf{w}^*(\mathbf{c}) - \min_{\mathbf{w} \in \Gamma} \hat{\mathbf{c}}^\top \mathbf{w}\). Using a single hardest negative instead of an average over \(\Gamma\) makes it simpler than NCE, and it is often equally effective. Pool semantics (solve_ratio, dataset) are the same as for NCE.
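Under the same assumptions as an NCE-style sketch (minimization, 'mean' reduction, NumPy arrays, hypothetical function names), the CMAP loss swaps the average over the pool for a minimum. The loss is piecewise linear in \(\hat{\mathbf{c}}\), so its (sub)gradient is \(\mathbf{w}^*(\mathbf{c}) - \hat{\mathbf{w}}\), where \(\hat{\mathbf{w}}\) is the pool member attaining the minimum. One consequence worth noting: when \(\mathbf{w}^*(\mathbf{c})\) itself is in \(\Gamma\), as it is when the pool is seeded from the training dataset, the minimum is at most \(\hat{\mathbf{c}}^\top \mathbf{w}^*(\mathbf{c})\) and the loss is nonnegative.

```python
import numpy as np

def cmap_loss(pred_cost, true_sol, pool):
    """CMAP loss for a minimization problem with 'mean' reduction over the batch."""
    obj_true = np.einsum("bd,bd->b", pred_cost, true_sol)  # c_hat^T w*            -> (B,)
    obj_pool = pred_cost @ pool.T                           # c_hat^T w, w in Gamma -> (B, n)
    return (obj_true - obj_pool.min(axis=1)).mean()         # margin to the hardest negative

def cmap_grad(pred_cost, true_sol, pool):
    """(Sub)gradient w.r.t. pred_cost: (w* - w_hat) / B, where w_hat attains the minimum.

    Exact wherever the minimizing pool member is unique.
    """
    hardest = (pred_cost @ pool.T).argmin(axis=1)           # index of w_hat per instance
    return (true_sol - pool[hardest]) / pred_cost.shape[0]
```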

Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390

forward(pred_cost: torch.Tensor, true_sol: torch.Tensor) → torch.Tensor

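Forward pass: computes the loss from a batch of predicted costs pred_cost and true optimal solutions true_sol, aggregated according to reduction.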