pyepo.func.jax.cave

Cone-aligned vector estimation (CaVE) loss for binary linear programs

Attributes

CaVE

Classes

coneAlignedCosine

Cone-Aligned Vector Estimation (CaVE) loss for binary linear programs.

Module Contents

class pyepo.func.jax.cave.coneAlignedCosine(optmodel, max_iter=3, solve_ratio=1.0, inner_ratio=0.2, processes=1, reduction: pyepo.func.runtime.Reduction = 'mean')

Bases: pyepo.func.jax.abcmodule.optModule

Cone-Aligned Vector Estimation (CaVE) loss for binary linear programs.

Projects the sense-flipped predicted cost onto the polyhedral cone spanned by the normals of the constraints that are binding at the true optimal vertex (solved as a quadratic program with Clarabel), and minimizes \(1 - \cos(-\hat{\mathbf{c}}, \mathrm{proj})\). The projection is detached (treated as a constant), so gradients flow only through the cosine term.
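A minimal NumPy sketch of the computation described above, for a single instance of a minimization problem. It is an illustration under stated assumptions, not the module's implementation: `normals` holds one binding-constraint normal per row, the cone projection is solved by plain projected gradient descent on a nonnegative least-squares problem instead of the Clarabel QP the module uses, and `project_onto_cone` and `cave_loss` are hypothetical names that are not part of PyEPO.

```python
import numpy as np


def project_onto_cone(v, normals, n_iter=2000):
    """Project v onto the cone {normals.T @ lam : lam >= 0}.

    Hypothetical helper. Solves min_{lam >= 0} 0.5 * ||normals.T @ lam - v||^2
    by projected gradient descent (a stand-in for the Clarabel QP).
    normals: (k, d) array, one binding-constraint normal per row.
    """
    A = normals.T                      # (d, k): columns generate the cone
    lam = np.zeros(A.shape[1])
    # Step size 1/L, where L = ||A||_2^2 is the Lipschitz constant of the gradient
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    for _ in range(n_iter):
        grad = A.T @ (A @ lam - v)     # gradient of the least-squares objective
        lam = np.maximum(lam - step * grad, 0.0)  # project onto lam >= 0
    return A @ lam


def cave_loss(pred_cost, normals):
    """Hypothetical helper: 1 - cos(-pred_cost, proj) for a minimization problem."""
    v = -pred_cost                          # sense-flipped predicted cost
    proj = project_onto_cone(v, normals)    # plain NumPy, so no gradient flows here
    cos = v @ proj / (np.linalg.norm(v) * np.linalg.norm(proj) + 1e-12)
    return 1.0 - cos
```

For example, with `normals = np.eye(2)` the cone is the nonnegative orthant, so for `pred_cost = np.array([-1.0, 2.0])` the projection of \(-\hat{\mathbf{c}} = (1, -2)\) is \((1, 0)\) and the loss is \(1 - 1/\sqrt{5}\). When \(-\hat{\mathbf{c}}\) already lies inside the cone, the projection equals it and the loss is (numerically) zero.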

Reference: Tang & Khalil (2024) https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12

max_iter = 3
solve_ratio
inner_ratio
forward(pred_cost, tight_ctrs)

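Forward pass. Computes the CaVE loss for the predicted costs pred_cost, given the tight (binding) constraints tight_ctrs at the true optimal solutions.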
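The class documentation does not spell out the forward computation. The hypothetical JAX helper below sketches only its final step, assuming the cone projections have already been computed outside autodiff (for instance by a QP solver) and arrive as a `(batch, d)` array; that layout and the name `cave_cosine_loss` are assumptions, not PyEPO's API. It wraps the projections in `jax.lax.stop_gradient` to mirror the detached projection described in the class docstring, then averages over the batch as the default `reduction='mean'` suggests.

```python
import jax
import jax.numpy as jnp


def cave_cosine_loss(pred_cost, proj):
    """Hypothetical helper: mean over the batch of 1 - cos(-pred_cost, proj).

    pred_cost: (batch, d) predicted costs.
    proj:      (batch, d) cone projections computed outside autodiff.
    """
    proj = jax.lax.stop_gradient(proj)   # projection is a constant target
    v = -pred_cost                       # sense-flipped predicted cost
    cos = jnp.sum(v * proj, axis=-1) / (
        jnp.linalg.norm(v, axis=-1) * jnp.linalg.norm(proj, axis=-1) + 1e-12
    )
    return jnp.mean(1.0 - cos)           # mean reduction over the batch
```

Differentiating this loss with `jax.grad` gives a zero gradient with respect to `proj` and a nonzero gradient with respect to `pred_cost`, which is the behaviour "the gradient flows only through the cosine" describes.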

pyepo.func.jax.cave.CaVE