pyepo.model.mpax.compile

MPAX (JAX PDHG) compiler for the PyEPO DSL.

compiledMpaxProblem combines the generic compiledBase with optMpaxModel to turn a finalized DSL Problem into MPAX standard-form matrices (min cᵀx + ½xᵀQx s.t. Ax = b, Gx ≥ h, l ≤ x ≤ u), which are solved by the JAX first-order solver. Unlike the other backends, it overrides setObj / solve rather than using compiledBase’s numpy hooks: the cost is kept as a device tensor (via DLPack), so vmap-batched GPU solving is preserved. MPAX solves continuous LP / QP relaxations: integer / binary variables are relaxed to their bounds, and quadratic constraints are not expressible.
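To make the standard form concrete, the following is a minimal NumPy sketch that is independent of PyEPO and MPAX. It builds a toy problem, evaluates the objective cᵀx + ½xᵀQx, and checks feasibility of a candidate point. The helper names objective and is_feasible are hypothetical, and the inequality is assumed to follow the Gx ≥ h convention. The example also shows what relaxing a binary variable to its bounds means: the variable only has to lie in [0, 1], so a fractional point can be feasible.

```python
import numpy as np

# Toy data in the standard form; every name here is illustrative only.
c = np.array([1.0, 2.0])      # linear cost
Q = np.zeros((2, 2))          # Q = 0 makes this an LP; a PSD Q would give a QP
A = np.array([[1.0, 1.0]])    # equality constraint: x0 + x1 = 1
b = np.array([1.0])
G = np.array([[1.0, 0.0]])    # inequality (assumed G x >= h): x0 >= 0.25
h = np.array([0.25])
l = np.zeros(2)               # binary variables relaxed to the bounds [0, 1]
u = np.ones(2)

def objective(x):
    """Evaluate c^T x + 1/2 x^T Q x."""
    return float(c @ x + 0.5 * x @ Q @ x)

def is_feasible(x, tol=1e-8):
    """Check A x = b, G x >= h, and l <= x <= u up to a tolerance."""
    return bool(
        np.allclose(A @ x, b, atol=tol)
        and np.all(G @ x >= h - tol)
        and np.all(x >= l - tol)
        and np.all(x <= u + tol)
    )

x_frac = np.array([0.5, 0.5])  # fractional point, allowed by the relaxation
x_bad = np.array([0.0, 1.0])   # integral, but violates x0 >= 0.25
```

Here x_frac satisfies every constraint even though neither entry is 0 or 1, which is why a solution returned for a problem with integer variables may need rounding or other post-processing.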

Attributes

jax

Classes

compiledMpaxProblem

MPAX-backed (JAX LP / QP) compiled DSL problem.

Functions

compileProblem(→ compiledMpaxProblem)

Instantiate the MPAX-compiled problem.

Module Contents

pyepo.model.mpax.compile.jax = None
pyepo.model.mpax.compile.compileProblem(problem, **params) → compiledMpaxProblem

Instantiate the MPAX-compiled problem.

class pyepo.model.mpax.compile.compiledMpaxProblem(problem, params=None)

Bases: pyepo.dsl.compiled.compiledBase, pyepo.model.mpax.mpaxmodel.optMpaxModel

MPAX-backed (JAX LP / QP) compiled DSL problem.

use_sparse_matrix = False
setObj(c)

Set the objective from a predicted cost of length num_cost, scattered onto the known fixed costs.
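The scatter step can be illustrated with a short NumPy sketch. This is not the PyEPO implementation: the function scatter_cost and its arguments cost_index and fixed_cost are hypothetical names, and the sketch assumes that predicted entries overwrite the corresponding positions of the fixed-cost vector (the actual method may combine them differently).

```python
import numpy as np

def scatter_cost(pred_cost, cost_index, fixed_cost):
    """Return a full-length cost vector with predictions written at cost_index.

    Assumption: predicted values replace the fixed values at those positions.
    """
    pred_cost = np.asarray(pred_cost, dtype=float)
    cost_index = np.asarray(cost_index)
    if pred_cost.shape[0] != cost_index.shape[0]:
        raise ValueError("pred_cost and cost_index must have the same length")
    full = np.array(fixed_cost, dtype=float)  # copy, so fixed_cost is not mutated
    full[cost_index] = pred_cost
    return full

# Four variables in total; only variables 0 and 2 carry predicted costs.
fixed = np.array([0.0, 5.0, 0.0, 7.0])
full = scatter_cost([1.5, 2.5], [0, 2], fixed)
```

On a JAX array, which is immutable, the same update would be written functionally, for example with full.at[cost_index].set(pred_cost), which keeps the result on the device.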

solve()

Solve and return the full decision-vector solution (length num_vars) with its objective value.