pyepo.func.cave

Cone-aligned vector estimation (CaVE) loss for binary linear programs

Classes

coneAlignedCosine

Cone-Aligned Vector Estimation (CaVE) loss for binary linear programs.

Module Contents

class pyepo.func.cave.coneAlignedCosine(optmodel: pyepo.model.opt.optModel, max_iter: int = 3, solve_ratio: float = 1.0, inner_ratio: float = 0.2, processes: int = 1, reduction: pyepo.func.abcmodule.Reduction = 'mean')

Bases: pyepo.func.abcmodule.optModule

Cone-Aligned Vector Estimation (CaVE) loss for binary linear programs.

For each training instance, the sense-flipped predicted cost vector \(-\hat{\mathbf{c}}\) is projected onto the polyhedral cone spanned by the binding-constraint normals at the true optimal vertex; the loss is \(1 - \cos(-\hat{\mathbf{c}}, \mathrm{proj})\). Because the supervision is the cone of binding normals (not the optimal solution itself), CaVE side-steps the zero-gradient pathology of solver layers without requiring a perturbation or solution pool. Defined for binary linear programs only.

PyEPO uses Clarabel as the interior-point QP solver for the cone projection.
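The loss described above can be illustrated on a toy cone. The following is a minimal NumPy sketch, not PyEPO's implementation: it replaces the Clarabel interior-point QP with a plain projected-gradient solve of the non-negative least-squares problem, and the helper names `project_onto_cone` and `cave_cosine_loss` are hypothetical. It assumes the binding-constraint normals are given as the rows of a matrix.

```python
import numpy as np


def project_onto_cone(normals, v, n_steps=500):
    """Project v onto the cone {normals.T @ lam : lam >= 0}.

    Solves min_{lam >= 0} ||normals.T @ lam - v||^2 by projected gradient
    descent (a stand-in for the interior-point QP solver; assumption).
    """
    # Step size 1/L, with L the largest eigenvalue of normals @ normals.T.
    lip = np.linalg.norm(normals, 2) ** 2
    lam = np.zeros(normals.shape[0])
    for _ in range(n_steps):
        grad = normals @ (normals.T @ lam - v)
        lam = np.maximum(0.0, lam - grad / lip)  # keep multipliers >= 0
    return normals.T @ lam


def cave_cosine_loss(pred_cost, normals, eps=1e-12):
    """Return 1 - cos(-pred_cost, proj) for a single instance."""
    v = -pred_cost  # sense flip, as in the description above
    proj = project_onto_cone(normals, v)
    cos = v @ proj / (np.linalg.norm(v) * np.linalg.norm(proj) + eps)
    return 1.0 - cos


# Toy cone: the non-negative quadrant of R^2, spanned by e1 and e2.
normals = np.eye(2)
loss_outside = cave_cosine_loss(np.array([-1.0, 1.0]), normals)  # about 0.2929
loss_inside = cave_cosine_loss(np.array([-1.0, -2.0]), normals)  # about 0.0
```

When the flipped prediction already lies inside the cone, the projection is the vector itself and the loss vanishes; otherwise the loss grows with the angle between the prediction and the cone.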

Note

The default max_iter=3 is intentional: it is the CaVE+ preset from the paper. Three interior-point (IPM) iterations deliberately under-converge the QP so that the projection stays in the interior of the cone, yielding a richer gradient than a fully converged projection on the cone boundary. Raising max_iter changes the behavior of the loss.

For larger problems, set solve_ratio < 1 to enable the CaVE Hybrid preset from the paper. Each batch goes through the QP projection with probability solve_ratio and through a cheap heuristic (the normalized predicted cost blended with the average binding-constraint normal) with probability 1 - solve_ratio. This cuts the per-epoch cost without a measurable increase in regret.
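The per-batch choice between the two branches can be sketched as follows. This is an illustrative NumPy sketch under stated assumptions, not PyEPO's code: the helper names `choose_branch` and `heuristic_target` are hypothetical, and the blend weight `blend` and the direction of the blend are assumptions, since the exact formula is not given here.

```python
import numpy as np


def choose_branch(rng, solve_ratio):
    """Pick the QP projection with probability solve_ratio, else the heuristic."""
    return "qp" if rng.random() < solve_ratio else "heuristic"


def heuristic_target(pred_cost, normals, blend=0.2):
    """Blend the normalized, sense-flipped prediction with the mean normal.

    `blend` is a hypothetical weight; how it maps to the library's
    parameters is an assumption of this sketch.
    """
    v = -pred_cost / np.linalg.norm(pred_cost)
    avg = normals.mean(axis=0)
    avg = avg / np.linalg.norm(avg)
    return (1.0 - blend) * v + blend * avg


rng = np.random.default_rng(0)
# solve_ratio=1.0 always takes the QP branch; 0.0 always takes the heuristic.
always_qp = {choose_branch(rng, 1.0) for _ in range(100)}
always_heuristic = {choose_branch(rng, 0.0) for _ in range(100)}

target = heuristic_target(np.array([-1.0, 1.0]), np.eye(2))
```

Because the heuristic needs only a normalization and a mean, its cost is negligible next to a QP solve, which is where the per-epoch saving comes from.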

Training data must come from pyepo.data.dataset.optDatasetConstrs (Gurobi-backed) and be collated with collate_tight_constraints.

Reference: Tang & Khalil (2024) https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12

max_iter = 3
solve_ratio
inner_ratio
forward(pred_cost: torch.Tensor, tight_ctrs: torch.Tensor) → torch.Tensor

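Forward pass: computes the CaVE loss from the predicted cost vectors (pred_cost) and the binding-constraint normals (tight_ctrs).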