pyepo.func.cave
Cone-aligned vector estimation (CaVE) loss for binary linear programs
Classes
coneAlignedCosine | Cone-Aligned Vector Estimation (CaVE) loss for binary linear programs.
Module Contents
- class pyepo.func.cave.coneAlignedCosine(optmodel: pyepo.model.opt.optModel, max_iter: int = 3, solve_ratio: float = 1.0, inner_ratio: float = 0.2, processes: int = 1, reduction: pyepo.func.abcmodule.Reduction = 'mean')
Bases: pyepo.func.abcmodule.optModule
Cone-Aligned Vector Estimation (CaVE) loss for binary linear programs.
For each training instance, the sense-flipped predicted cost vector \(-\hat{\mathbf{c}}\) is projected onto the polyhedral cone spanned by the binding-constraint normals at the true optimal vertex, and the loss is \(1 - \cos(-\hat{\mathbf{c}}, \mathrm{proj})\), where \(\mathrm{proj}\) denotes that projection. Because the supervision signal is the cone of binding normals rather than the optimal solution itself, CaVE sidesteps the zero-gradient pathology of solver layers without requiring perturbation or a solution pool. It is defined for binary linear programs only.
PyEPO uses Clarabel as the interior-point QP solver for the cone projection.
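The projection and loss can be illustrated with a small NumPy sketch. It assumes constraints are written as \(a^\top x \le b\), with the rows of `normals` being the binding \(a\) vectors, so that for a minimization problem \(-\mathbf{c}\) lies in their cone exactly when the true solution is optimal. Instead of Clarabel's interior-point method, it solves the equivalent nonnegative least-squares problem by projected gradient descent and runs to (approximate) convergence, so it mimics a fully converged boundary projection rather than the under-converged CaVE+ preset. The function names `project_onto_cone` and `cave_loss` are hypothetical and not part of PyEPO's API.

```python
import numpy as np


def project_onto_cone(normals: np.ndarray, v: np.ndarray, iters: int = 500) -> np.ndarray:
    """Euclidean projection of v onto cone{normals.T @ lam : lam >= 0}.

    Solves min_{lam >= 0} 0.5 * ||normals.T @ lam - v||^2 by projected
    gradient descent (an illustrative stand-in, not PyEPO's Clarabel QP).
    """
    # Lipschitz constant of the gradient: largest eigenvalue of N @ N.T
    lipschitz = np.linalg.norm(normals @ normals.T, 2)
    step = 1.0 / max(lipschitz, 1e-12)
    lam = np.zeros(normals.shape[0])
    for _ in range(iters):
        grad = normals @ (normals.T @ lam - v)
        lam = np.maximum(lam - step * grad, 0.0)  # keep multipliers nonnegative
    return normals.T @ lam


def cave_loss(pred_cost, normals, minimize: bool = True, eps: float = 1e-12) -> float:
    """1 - cos(v, proj(v)) with v = -c_hat for minimization (c_hat for maximization)."""
    c_hat = np.asarray(pred_cost, dtype=float)
    v = -c_hat if minimize else c_hat  # sense flip (assumes a^T x <= b constraints)
    proj = project_onto_cone(np.asarray(normals, dtype=float), v)
    cos = v @ proj / (np.linalg.norm(v) * np.linalg.norm(proj) + eps)
    return float(1.0 - cos)
```

As a usage example, take a box-constrained problem whose true solution is \(x^* = (1, 0)\). The binding normals are \((1, 0)\) for \(x_1 \le 1\) and \((0, -1)\) for \(-x_2 \le 0\). For \(\hat{\mathbf{c}} = (-1, 2)\), the vector \(-\hat{\mathbf{c}}\) already lies in the cone and the loss is zero; for \(\hat{\mathbf{c}} = (1, 2)\), \(-\hat{\mathbf{c}}\) projects to \((0, -2)\) and the loss is \(1 - 2/\sqrt{5} \approx 0.106\).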
Note
The default max_iter=3 is intentional: it is the CaVE+ preset from the paper. Three IPM steps deliberately under-converge the QP so the projection stays interior to the cone, yielding a richer gradient than a fully converged boundary projection. Raising max_iter changes the loss behavior.

For larger problems, set solve_ratio < 1 to enable the CaVE Hybrid preset from the paper: each batch goes through the QP projection with probability solve_ratio and through a cheap heuristic (normalized predicted cost blended with the average binding-constraint normal) with probability 1 - solve_ratio, cutting the per-epoch cost without measurable regret loss.

Training data must come from pyepo.data.dataset.optDatasetConstrs (Gurobi-backed) and be collated with collate_tight_constraints.

Reference: Tang & Khalil (2024) https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12
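The hybrid dispatch described in the note can be sketched as a single coin flip per batch. This is an illustration under assumptions, not PyEPO's implementation: `project` is any cone-projection callable (for example a QP solve), `blend` is a hypothetical mixing weight, the heuristic formula is only a guess at the "normalized predicted cost blended with the average binding-constraint normal" description, and nothing here claims that inner_ratio plays the role of `blend`. A minimization sense is assumed.

```python
import random

import numpy as np


def heuristic_target(v: np.ndarray, normals: np.ndarray, blend: float) -> np.ndarray:
    """Cheap stand-in for the cone projection (hypothetical form).

    Blends the normalized sense-flipped cost with the mean binding normal.
    """
    v_unit = v / (np.linalg.norm(v) + 1e-12)
    avg_normal = normals.mean(axis=0)
    return blend * v_unit + (1.0 - blend) * avg_normal


def hybrid_batch_loss(pred_costs, normals_list, solve_ratio, project, blend=0.5, rng=None):
    """Mean 1 - cos loss over a batch; returns (loss, whether the QP path was used)."""
    if rng is None:
        rng = random.Random()
    use_qp = rng.random() < solve_ratio  # one draw decides the path for the whole batch
    losses = []
    for c_hat, normals in zip(pred_costs, normals_list):
        v = -np.asarray(c_hat, dtype=float)  # minimization sense flip
        normals = np.asarray(normals, dtype=float)
        target = project(normals, v) if use_qp else heuristic_target(v, normals, blend)
        cos = v @ target / (np.linalg.norm(v) * np.linalg.norm(target) + 1e-12)
        losses.append(1.0 - cos)
    return float(np.mean(losses)), use_qp
```

Drawing one coin per batch, rather than per instance, matches the note's statement that each batch takes one path, and it leaves the QP solver completely idle on heuristic batches, which is where the per-epoch savings come from.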
- max_iter = 3
- solve_ratio
- inner_ratio
- forward(pred_cost: torch.Tensor, tight_ctrs: torch.Tensor) → torch.Tensor
Forward pass