pyepo.func.contrastive
Noise contrastive estimation loss function
Classes
| Class | Description |
|---|---|
| noiseContrastiveEstimation | Noise Contrastive Estimation (NCE): contrastive loss against a cached solution pool. |
| contrastiveMAP | Contrastive Maximum-a-Posteriori (CMAP): max-margin special case of NCE. |
Module Contents
- class pyepo.func.contrastive.noiseContrastiveEstimation(optmodel: pyepo.model.opt.optModel, processes: int = 1, solve_ratio: float = 1.0, reduction: pyepo.func.abcmodule.Reduction = 'mean', dataset: pyepo.data.dataset.optDataset | None = None)
Bases: pyepo.func.abcmodule.optModule
Noise Contrastive Estimation (NCE): contrastive loss against a cached solution pool.
Averages the predicted-cost margin between the true optimum and every member of the cached pool \(\Gamma\): \(\mathcal{L} = \tfrac{1}{|\Gamma|}\sum_{\mathbf{w} \in \Gamma} \left(\hat{\mathbf{c}}^\top \mathbf{w}^*(\mathbf{c}) - \hat{\mathbf{c}}^\top \mathbf{w}\right)\). Because the loss is linear in \(\hat{\mathbf{c}}\), its gradient has the closed form \(\mathbf{w}^*(\mathbf{c}) - \tfrac{1}{|\Gamma|}\sum_{\mathbf{w} \in \Gamma} \mathbf{w}\), so the backward pass makes no solver call; solver work happens only when the forward pass refreshes the pool by solving under \(\hat{\mathbf{c}}\). Pass solve_ratio < 1 to refresh the pool on only about that fraction of forward passes (the default, 1.0, refreshes on every pass); the pool is seeded from dataset at construction.
Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390
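- forward(pred_cost: torch.Tensor, true_sol: torch.Tensor) → torch.Tensor
Forward pass: compute the NCE loss for a batch of predicted costs pred_cost and true optimal solutions true_sol, aggregated according to reduction.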
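The NCE formula, its closed-form gradient, and the solve_ratio pool refresh described above can be illustrated with a minimal NumPy sketch. This is a sketch of the math under stated assumptions, not PyEPO's implementation: it assumes a minimization problem, a pool stored as a (n, d) array, and a toy "pick one item" solver; the names nce_loss, nce_grad, maybe_refresh_pool and solve are hypothetical.

```python
import numpy as np


def nce_loss(pred_cost, true_sol, pool):
    """Per-sample NCE loss for a minimization problem (illustrative sketch).

    pred_cost: (batch, d) predicted costs c_hat
    true_sol:  (batch, d) true optimal solutions w*(c)
    pool:      (n, d) cached solution pool Gamma
    """
    obj_true = np.einsum("bd,bd->b", pred_cost, true_sol)[:, None]  # c_hat^T w*, (batch, 1)
    obj_pool = pred_cost @ pool.T                                    # c_hat^T w per pool member, (batch, n)
    return (obj_true - obj_pool).mean(axis=1)                        # average margin, (batch,)


def nce_grad(true_sol, pool):
    """Closed-form gradient of the per-sample loss w.r.t. c_hat: w* - mean(Gamma).

    The loss is linear in c_hat, so computing this needs no solver call.
    """
    return true_sol - pool.mean(axis=0, keepdims=True)


def maybe_refresh_pool(pool, pred_cost, solve, solve_ratio, rng):
    """With probability solve_ratio, solve under each c_hat and add the results to the pool.

    `solve` is a hypothetical callable mapping one cost vector to an optimal solution.
    """
    if rng.uniform() <= solve_ratio:
        new_sols = np.stack([solve(c) for c in pred_cost])
        pool = np.unique(np.concatenate([pool, new_sols]), axis=0)  # drop duplicate rows
    return pool


def solve(c):
    """Hypothetical toy solver: pick exactly one item, minimizing cost (one-hot solution)."""
    w = np.zeros_like(c)
    w[np.argmin(c)] = 1.0
    return w


rng = np.random.default_rng(0)
pool = np.eye(3)[[0, 1]]                 # seed pool: items 0 and 1
pred_cost = np.array([[4.0, 1.0, 2.0]])
true_sol = np.array([[0.0, 0.0, 1.0]])   # item 2 is optimal under the true costs
print(nce_loss(pred_cost, true_sol, pool))  # [-0.5], i.e. ((2 - 4) + (2 - 1)) / 2
pool = maybe_refresh_pool(pool, np.array([[4.0, 3.0, 2.0]]), solve, 1.0, rng)
print(pool.shape)  # (3, 3): the solution for item 2 was added to the pool
```

Only the refresh step calls the solver; the loss and gradient are plain matrix products over the pool, which is why lowering solve_ratio trades pool freshness for speed.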
- class pyepo.func.contrastive.contrastiveMAP(optmodel: pyepo.model.opt.optModel, processes: int = 1, solve_ratio: float = 1.0, reduction: pyepo.func.abcmodule.Reduction = 'mean', dataset: pyepo.data.dataset.optDataset | None = None)
Bases: pyepo.func.abcmodule.optModule
Contrastive Maximum-a-Posteriori (CMAP): max-margin special case of NCE.
Keeps only the most-violating member of the cached pool \(\Gamma\) (the one with the smallest predicted-cost objective) as the negative: \(\mathcal{L} = \hat{\mathbf{c}}^\top \mathbf{w}^*(\mathbf{c}) - \min_{\mathbf{w} \in \Gamma} \hat{\mathbf{c}}^\top \mathbf{w}\). It is simpler than NCE and is often as effective in practice. Pool semantics (solve_ratio, dataset) are identical to NCE.
Reference: Mulamba et al. (2021) https://www.ijcai.org/proceedings/2021/390
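- forward(pred_cost: torch.Tensor, true_sol: torch.Tensor) → torch.Tensor
Forward pass: compute the CMAP loss for a batch of predicted costs pred_cost and true optimal solutions true_sol, aggregated according to reduction.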
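The CMAP formula above can be sketched in NumPy the same way. This is an illustration of the math, not PyEPO's implementation: it assumes a minimization problem and a pool stored as a (n, d) array, and the names cmap_loss and cmap_grad are hypothetical.

```python
import numpy as np


def cmap_loss(pred_cost, true_sol, pool):
    """Per-sample CMAP loss for a minimization problem (illustrative sketch).

    Uses only the pool member with the smallest predicted objective as the negative.
    """
    obj_true = np.einsum("bd,bd->b", pred_cost, true_sol)  # c_hat^T w*, (batch,)
    obj_pool = pred_cost @ pool.T                          # c_hat^T w per pool member, (batch, n)
    return obj_true - obj_pool.min(axis=1)


def cmap_grad(true_sol, pred_cost, pool):
    """Gradient w.r.t. c_hat (a subgradient at ties): w* - argmin_{w in Gamma} c_hat^T w."""
    hardest = pool[np.argmin(pred_cost @ pool.T, axis=1)]  # most-violating member per sample, (batch, d)
    return true_sol - hardest


pool = np.eye(3)[[0, 1]]                 # pool: items 0 and 1
pred_cost = np.array([[4.0, 1.0, 2.0]])
true_sol = np.array([[0.0, 0.0, 1.0]])   # item 2 is optimal under the true costs
print(cmap_loss(pred_cost, true_sol, pool))  # [1.], i.e. 2 - min(4, 1)
```

Because a maximum over the pool is never smaller than the mean, the CMAP loss is always at least the NCE loss computed on the same pool.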