diff --git a/doc/release_notes.rst b/doc/release_notes.rst
index 60926055..ed59b835 100644
--- a/doc/release_notes.rst
+++ b/doc/release_notes.rst
@@ -4,6 +4,8 @@ Release Notes
Upcoming Version
----------------
+* Add a unified ``SolverMetrics`` dataclass accessible via ``Model.solver_metrics`` after solving. Provides ``solver_name``, ``solve_time``, ``objective_value``, ``dual_bound``, and ``mip_gap`` in a solver-independent way. Each solver populates the fields it can; unsupported fields remain ``None``.
+
Version 0.6.3
--------------
diff --git a/examples/create-a-model.ipynb b/examples/create-a-model.ipynb
index a158e0cf..6b476ac2 100644
--- a/examples/create-a-model.ipynb
+++ b/examples/create-a-model.ipynb
@@ -30,11 +30,16 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id": "dramatic-cannon",
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2026-02-11T21:42:38.164407Z",
+ "start_time": "2026-02-11T21:42:38.162992Z"
+ }
+ },
+ "source": [],
"outputs": [],
- "source": []
+ "execution_count": null
},
{
"attachments": {},
@@ -49,15 +54,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id": "technical-conducting",
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2026-02-11T21:42:39.360058Z",
+ "start_time": "2026-02-11T21:42:38.171827Z"
+ }
+ },
"source": [
"from linopy import Model\n",
"\n",
"m = Model()"
- ]
+ ],
+ "outputs": [],
+ "execution_count": null
},
{
"attachments": {},
@@ -83,14 +93,19 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id": "protecting-power",
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2026-02-11T21:42:39.387467Z",
+ "start_time": "2026-02-11T21:42:39.384712Z"
+ }
+ },
"source": [
"x = m.add_variables(lower=0, name=\"x\")\n",
"y = m.add_variables(lower=0, name=\"y\");"
- ]
+ ],
+ "outputs": [],
+ "execution_count": null
},
{
"attachments": {},
@@ -103,13 +118,18 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id": "virtual-anxiety",
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2026-02-11T21:42:39.393709Z",
+ "start_time": "2026-02-11T21:42:39.390438Z"
+ }
+ },
"source": [
"x"
- ]
+ ],
+ "outputs": [],
+ "execution_count": null
},
{
"attachments": {},
@@ -127,13 +147,18 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id": "fbb46cad",
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2026-02-11T21:42:39.405691Z",
+ "start_time": "2026-02-11T21:42:39.396625Z"
+ }
+ },
"source": [
"3 * x + 7 * y >= 10"
- ]
+ ],
+ "outputs": [],
+ "execution_count": null
},
{
"attachments": {},
@@ -146,13 +171,18 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id": "60f41b76",
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2026-02-11T21:42:39.416325Z",
+ "start_time": "2026-02-11T21:42:39.409117Z"
+ }
+ },
"source": [
"3 * x + 7 * y - 10 >= 0"
- ]
+ ],
+ "outputs": [],
+ "execution_count": null
},
{
"attachments": {},
@@ -167,14 +197,19 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id": "hollywood-production",
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2026-02-11T21:42:39.431755Z",
+ "start_time": "2026-02-11T21:42:39.420977Z"
+ }
+ },
"source": [
"m.add_constraints(3 * x + 7 * y >= 10)\n",
"m.add_constraints(5 * x + 2 * y >= 3);"
- ]
+ ],
+ "outputs": [],
+ "execution_count": null
},
{
"attachments": {},
@@ -189,13 +224,18 @@
},
{
"cell_type": "code",
- "execution_count": null,
"id":
"overall-exhibition", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-11T21:42:39.438865Z", + "start_time": "2026-02-11T21:42:39.434328Z" + } + }, "source": [ "m.add_objective(x + 2 * y)" - ] + ], + "outputs": [], + "execution_count": null }, { "attachments": {}, @@ -210,13 +250,18 @@ }, { "cell_type": "code", - "execution_count": null, "id": "pressing-copying", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-11T21:42:39.532619Z", + "start_time": "2026-02-11T21:42:39.441886Z" + } + }, "source": [ "m.solve(solver_name=\"highs\")" - ] + ], + "outputs": [], + "execution_count": null }, { "attachments": {}, @@ -229,23 +274,67 @@ }, { "cell_type": "code", - "execution_count": null, "id": "electric-duration", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-11T21:42:39.560199Z", + "start_time": "2026-02-11T21:42:39.553844Z" + } + }, "source": [ "x.solution" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "e6d31751", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-11T21:42:39.577784Z", + "start_time": "2026-02-11T21:42:39.573362Z" + } + }, "source": [ "y.solution" - ] + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "9zgzuhvo1b8", + "source": [ + "### Solver Metrics\n", + "\n", + "After solving, you can inspect performance metrics reported by the solver via `solver_metrics`. This includes solve time, objective value, and for MIP problems, the dual bound and MIP gap (available for most solvers." + ], + "metadata": {} + }, + { + "cell_type": "code", + "id": "bdfxi7haoc", + "source": "m.solver_metrics", + "metadata": { + "ExecuteTime": { + "end_time": "2026-02-11T21:42:39.592065Z", + "start_time": "2026-02-11T21:42:39.589851Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "SolverMetrics(solver_name='highs', solve_time=0.0019101661164313555, objective_value=2.862068965517241, dual_bound=0.0, mip_gap=inf)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": null }, { "attachments": {}, diff --git a/linopy/__init__.py b/linopy/__init__.py index 3efc297a..bccd5c26 100644 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -14,7 +14,7 @@ import linopy.monkey_patch_xarray # noqa: F401 from linopy.common import align from linopy.config import options -from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL +from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL, SolverMetrics from linopy.constraints import Constraint, Constraints from linopy.expressions import LinearExpression, QuadraticExpression, merge from linopy.io import read_netcdf @@ -34,6 +34,7 @@ "OetcHandler", "QuadraticExpression", "RemoteHandler", + "SolverMetrics", "Variable", "Variables", "available_solvers", diff --git a/linopy/constants.py b/linopy/constants.py index 021a9a10..39466789 100644 --- a/linopy/constants.py +++ b/linopy/constants.py @@ -3,6 +3,7 @@ Linopy module for defining constant values used within the package. """ +import dataclasses import logging from dataclasses import dataclass, field from enum import Enum @@ -211,6 +212,46 @@ class Solution: objective: float = field(default=np.nan) +@dataclass(frozen=True) +class SolverMetrics: + """ + Unified solver performance metrics. + + All fields default to ``None``. 
Solvers populate what they can; + unsupported fields remain ``None``. Access via + :attr:`Model.solver_metrics` after calling :meth:`Model.solve`. + + Attributes + ---------- + solver_name : str or None + Name of the solver used. + solve_time : float or None + Wall-clock time spent solving (seconds). + objective_value : float or None + Objective value of the best solution found. + dual_bound : float or None + Best bound on the objective from the MIP relaxation (also known as + "best bound"). Only populated for integer programs. + mip_gap : float or None + Relative gap between the objective value and the dual bound. + Only populated for integer programs. + """ + + solver_name: str | None = None + solve_time: float | None = None + objective_value: float | None = None + dual_bound: float | None = None + mip_gap: float | None = None + + def __repr__(self) -> str: + fields = [] + for f in dataclasses.fields(self): + val = getattr(self, f.name) + if val is not None: + fields.append(f"{f.name}={val!r}") + return f"SolverMetrics({', '.join(fields)})" + + @dataclass class Result: """ @@ -220,6 +261,7 @@ class Result: status: Status solution: Solution | None = None solver_model: Any = None + metrics: SolverMetrics | None = None def __repr__(self) -> str: solver_model_string = ( @@ -232,12 +274,16 @@ def __repr__(self) -> str: ) else: solution_string = "Solution: None\n" + metrics_string = "" + if self.metrics is not None: + metrics_string = f"Solver metrics: {self.metrics}\n" return ( f"Status: {self.status.status.value}\n" f"Termination condition: {self.status.termination_condition.value}\n" + solution_string + f"Solver model: {solver_model_string}\n" - f"Solver message: {self.status.legacy_status}" + + metrics_string + + f"Solver message: {self.status.legacy_status}" ) def info(self) -> None: diff --git a/linopy/model.py b/linopy/model.py index 871945ba..015cbec0 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -41,6 +41,7 @@ LESS_EQUAL, TERM_DIM, ModelStatus, + SolverMetrics, TerminationCondition, ) from linopy.constraints import AnonymousScalarConstraint, Constraint, Constraints @@ -97,6 +98,7 @@ class Model: solver_model: Any solver_name: str + _solver_metrics: SolverMetrics | None _variables: Variables _constraints: Constraints _objective: Objective @@ -137,6 +139,7 @@ class Model: "_force_dim_names", "_auto_mask", "_solver_dir", + "_solver_metrics", "solver_model", "solver_name", "matrices", @@ -197,6 +200,28 @@ def __init__( ) self.matrices: MatrixAccessor = MatrixAccessor(self) + self._solver_metrics: SolverMetrics | None = None + + @property + def solver_metrics(self) -> SolverMetrics | None: + """ + Solver performance metrics from the last solve, or ``None`` + if the model has not been solved yet. + + Returns a :class:`~linopy.constants.SolverMetrics` instance. + Fields the solver cannot provide remain ``None``. + + Reset to ``None`` by :meth:`reset_solution`. 
+
+ Examples
+ --------
+ >>> m.solve(solver_name="highs") # doctest: +SKIP
+ >>> m.solver_metrics.solve_time # doctest: +SKIP
+ 0.003
+ >>> m.solver_metrics.objective_value # doctest: +SKIP
+ 0.0
+ """
+ return self._solver_metrics
@property
def variables(self) -> Variables:
@@ -1413,6 +1438,7 @@ def solve(
self.termination_condition = result.status.termination_condition.value
self.solver_model = result.solver_model
self.solver_name = solver_name
+ self._solver_metrics = result.metrics
if not result.status.is_ok:
return result.status.status.value, result.status.termination_condition.value
@@ -1470,6 +1496,7 @@ def _mock_solve(
self.termination_condition = TerminationCondition.optimal.value
self.solver_model = None
self.solver_name = solver_name
+ self._solver_metrics = SolverMetrics(solver_name="mock", objective_value=0.0)
for name, var in self.variables.items():
var.solution = xr.DataArray(0.0, var.coords)
@@ -1712,6 +1739,7 @@ def reset_solution(self) -> None:
"""
self.variables.reset_solution()
self.constraints.reset_dual()
+ self._solver_metrics = None
to_netcdf = to_netcdf
diff --git a/linopy/solvers.py b/linopy/solvers.py
index fe516b47..c82f4a22 100644
--- a/linopy/solvers.py
+++ b/linopy/solvers.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import contextlib
+import dataclasses
import enum
import io
import logging
@@ -14,6 +15,7 @@ import subprocess as sub
import sys
import threading
+import time
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
@@ -29,6 +31,7 @@ from linopy.constants import (
Result,
Solution,
+ SolverMetrics,
SolverStatus,
Status,
TerminationCondition,
@@ -216,6 +219,15 @@ class xpress_Namespaces: # type: ignore[no-redef]
logger = logging.getLogger(__name__)
+
+def _safe_get(func: Callable[[], Any]) -> Any:
+ """Call *func* and return its result, or None if it raises."""
+ try:
+ return func()
+ except Exception:
+ logger.debug("Failed to extract solver metric", exc_info=True)
+ return None
+
io_structure = dict(
lp_file={
"gurobi",
@@ -410,6 +422,21 @@ def solve_problem(
msg = "No problem file or model specified."
raise ValueError(msg)
+ def _extract_metrics(self, solver_model: Any, solution: Solution) -> SolverMetrics:
+ """
+ Extract solver performance metrics.
+
+ The base implementation populates ``solver_name`` and ``objective_value``.
+ Subclasses should call ``super()._extract_metrics()`` and return a copy
+ with solver-specific fields filled in (``SolverMetrics`` is frozen, so
+ use :func:`dataclasses.replace`).
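+
+ For example, an override for a solver object that exposes a hypothetical
+ ``runtime`` attribute could look like::
+
+     def _extract_metrics(self, solver_model, solution):
+         metrics = super()._extract_metrics(solver_model, solution)
+         return dataclasses.replace(
+             metrics,
+             solve_time=_safe_get(lambda: solver_model.runtime),
+         )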
+ """ + return SolverMetrics( + solver_name=self.solver_name.value, + objective_value=_safe_get( + lambda: solution.objective if not np.isnan(solution.objective) else None + ), + ) + @property def solver_name(self) -> SolverName: return SolverName[self.__class__.__name__] @@ -431,6 +458,14 @@ def __init__( ) -> None: super().__init__(**solver_options) + def _extract_metrics(self, solver_model: Any, solution: Solution) -> SolverMetrics: + metrics = super()._extract_metrics(solver_model, solution) + return dataclasses.replace( + metrics, + solve_time=_safe_get(lambda: solver_model.runtime), + mip_gap=_safe_get(lambda: solver_model.mip_gap), + ) + def solve_problem_from_model( self, model: Model, @@ -598,7 +633,9 @@ def get_solver_solution() -> Solution: runtime = float(m.group(1)) CbcModel = namedtuple("CbcModel", ["mip_gap", "runtime"]) - return Result(status, solution, CbcModel(mip_gap, runtime)) + solver_model = CbcModel(mip_gap, runtime) + metrics = self._extract_metrics(solver_model, solution) + return Result(status, solution, solver_model, metrics) class GLPK(Solver[None]): @@ -728,7 +765,10 @@ def solve_problem_from_file( if not os.path.exists(solution_fn): status = Status(SolverStatus.warning, TerminationCondition.unknown) - return Result(status, Solution()) + solution = Solution() + return Result( + status, solution, metrics=self._extract_metrics(None, solution) + ) f = open(solution_fn) @@ -768,7 +808,8 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) - return Result(status, solution) + metrics = self._extract_metrics(None, solution) + return Result(status, solution, metrics=metrics) class Highs(Solver[None]): @@ -911,6 +952,28 @@ def solve_problem_from_file( sense=read_sense_from_problem_file(problem_fn), ) + def _extract_metrics(self, solver_model: Any, solution: Solution) -> SolverMetrics: + h = solver_model + metrics = super()._extract_metrics(solver_model, solution) + + def _highs_info(key: str) -> float: + status, val = h.getInfoValue(key) + if status != highspy.HighsStatus.kOk: # pragma: no cover + msg = f"Failed to get HiGHS info: {key}" + raise RuntimeError(msg) + return val + + is_mip = _safe_get(lambda: _highs_info("mip_node_count")) not in (None, -1) + + return dataclasses.replace( + metrics, + solve_time=_safe_get(lambda: h.getRunTime()), + mip_gap=_safe_get(lambda: _highs_info("mip_gap")) if is_mip else None, + dual_bound=_safe_get(lambda: _highs_info("mip_dual_bound")) + if is_mip + else None, + ) + def _set_solver_params( self, highs_solver: highspy.Highs, @@ -1019,7 +1082,8 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) - return Result(status, solution, h) + metrics = self._extract_metrics(h, solution) + return Result(status, solution, h, metrics) class Gurobi(Solver["gurobipy.Env | dict[str, Any] | None"]): @@ -1153,6 +1217,19 @@ def solve_problem_from_file( sense=sense, ) + def _extract_metrics( + self, solver_model: Any, solution: Solution + ) -> SolverMetrics: # pragma: no cover + m = solver_model + metrics = super()._extract_metrics(solver_model, solution) + is_mip = _safe_get(lambda: m.IsMIP) == 1 + return dataclasses.replace( + metrics, + solve_time=_safe_get(lambda: m.Runtime), + dual_bound=_safe_get(lambda: m.ObjBound) if is_mip else None, + mip_gap=_safe_get(lambda: m.MIPGap) if is_mip else 
None, + ) + def _solve( self, m: gurobipy.Model, @@ -1254,7 +1331,8 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = solution = maybe_adjust_objective_sign(solution, io_api, sense) - return Result(status, solution, m) + metrics = self._extract_metrics(m, solution) + return Result(status, solution, m, metrics) class Cplex(Solver[None]): @@ -1277,6 +1355,23 @@ def __init__( ) -> None: super().__init__(**solver_options) + def _extract_metrics( + self, solver_model: Any, solution: Solution + ) -> SolverMetrics: # pragma: no cover + m = solver_model + metrics = super()._extract_metrics(solver_model, solution) + is_mip = _safe_get(lambda: m.problem_type[m.get_problem_type()] != "LP") + return dataclasses.replace( + metrics, + solve_time=_safe_get(lambda: self._solve_time), + dual_bound=_safe_get(lambda: m.solution.MIP.get_best_objective()) + if is_mip + else None, + mip_gap=_safe_get(lambda: m.solution.MIP.get_mip_relative_gap()) + if is_mip + else None, + ) + def solve_problem_from_model( self, model: Model, @@ -1366,8 +1461,10 @@ def solve_problem_from_file( is_lp = m.problem_type[m.get_problem_type()] == "LP" + _t0 = time.perf_counter() # pragma: no cover with contextlib.suppress(cplex.exceptions.errors.CplexSolverError): m.solve() + self._solve_time = time.perf_counter() - _t0 # pragma: no cover if solution_fn is not None: try: @@ -1410,7 +1507,8 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) - return Result(status, solution, m) + metrics = self._extract_metrics(m, solution) + return Result(status, solution, m, metrics) class SCIP(Solver[None]): @@ -1429,6 +1527,17 @@ def __init__( ) -> None: super().__init__(**solver_options) + def _extract_metrics(self, solver_model: Any, solution: Solution) -> SolverMetrics: + m = solver_model + metrics = super()._extract_metrics(solver_model, solution) + is_mip = getattr(self, "_is_mip", False) + return dataclasses.replace( + metrics, + solve_time=_safe_get(lambda: m.getSolvingTime()), + dual_bound=_safe_get(lambda: m.getDualbound()) if is_mip else None, + mip_gap=_safe_get(lambda: m.getGap()) if is_mip else None, + ) + def solve_problem_from_model( self, model: Model, @@ -1520,6 +1629,7 @@ def solve_problem_from_file( if warmstart_fn: logger.warning("Warmstart not implemented for SCIP") + self._is_mip = m.getNIntVars() + m.getNBinVars() > 0 m.optimize() if basis_fn: @@ -1563,7 +1673,8 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) - return Result(status, solution, m) + metrics = self._extract_metrics(m, solution) + return Result(status, solution, m, metrics) class Xpress(Solver[None]): @@ -1585,6 +1696,27 @@ def __init__( ) -> None: super().__init__(**solver_options) + def _extract_metrics( + self, solver_model: Any, solution: Solution + ) -> SolverMetrics: # pragma: no cover + m = solver_model + metrics = super()._extract_metrics(solver_model, solution) + is_mip = _safe_get(lambda: m.attributes.mipents) not in (None, 0) + + def _xpress_mip_gap() -> float | None: + obj = m.attributes.mipbestobjval + bound = m.attributes.bestbound + if obj == 0: + return 0.0 if bound == 0 else None + return abs(obj - bound) / abs(obj) + + return dataclasses.replace( + metrics, + solve_time=_safe_get(lambda: m.attributes.time), + 
dual_bound=_safe_get(lambda: m.attributes.bestbound) if is_mip else None, + mip_gap=_safe_get(_xpress_mip_gap) if is_mip else None, + ) + def solve_problem_from_model( self, model: Model, @@ -1733,7 +1865,8 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) - return Result(status, solution, m) + metrics = self._extract_metrics(m, solution) + return Result(status, solution, m, metrics) mosek_bas_re = re.compile(r" (XL|XU)\s+([^ \t]+)\s+([^ \t]+)| (LL|UL|BS)\s+([^ \t]+)") @@ -1765,6 +1898,23 @@ def __init__( ) -> None: super().__init__(**solver_options) + def _extract_metrics( + self, solver_model: Any, solution: Solution + ) -> SolverMetrics: # pragma: no cover + m = solver_model + metrics = super()._extract_metrics(solver_model, solution) + is_mip = _safe_get(lambda: m.getnumintvar()) not in (None, 0) + return dataclasses.replace( + metrics, + solve_time=_safe_get(lambda: m.getdouinf(mosek.dinfitem.optimizer_time)), + dual_bound=_safe_get(lambda: m.getdouinf(mosek.dinfitem.mio_obj_bound)) + if is_mip + else None, + mip_gap=_safe_get(lambda: m.getdouinf(mosek.dinfitem.mio_obj_rel_gap)) + if is_mip + else None, + ) + def solve_problem_from_model( self, model: Model, @@ -2075,7 +2225,8 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) - return Result(status, solution) + metrics = self._extract_metrics(m, solution) + return Result(status, solution, metrics=metrics) class COPT(Solver[None]): @@ -2214,9 +2365,10 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) + metrics = self._extract_metrics(m, solution) env_.close() - return Result(status, solution, m) + return Result(status, solution, m, metrics) class MindOpt(Solver[None]): @@ -2357,10 +2509,12 @@ def get_solver_solution() -> Solution: solution = self.safe_get_solution(status=status, func=get_solver_solution) solution = maybe_adjust_objective_sign(solution, io_api, sense) + metrics = self._extract_metrics(m, solution) + m.dispose() env_.dispose() - return Result(status, solution, m) + return Result(status, solution, m, metrics) class PIPS(Solver[None]): @@ -2609,7 +2763,8 @@ def get_solver_solution() -> Solution: solution = maybe_adjust_objective_sign(solution, io_api, sense) # see https://github.com/MIT-Lu-Lab/cuPDLPx/tree/main/python#solution-attributes - return Result(status, solution, cu_model) + metrics = self._extract_metrics(cu_model, solution) + return Result(status, solution, cu_model, metrics) def _set_solver_params(self, cu_model: cupdlpx.Model) -> None: """ diff --git a/test/test_solver_metrics.py b/test/test_solver_metrics.py new file mode 100644 index 00000000..e6bcee4e --- /dev/null +++ b/test/test_solver_metrics.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +""" +Tests for the SolverMetrics feature. 
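+
+Covers the dataclass defaults and repr, ``Result`` backward compatibility,
+``Model`` integration, and per-solver metric extraction for LP and MIP problems.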
+""" + +from __future__ import annotations + +import numpy as np +import pytest +import xarray as xr + +from linopy import Model, available_solvers +from linopy.constants import Result, Solution, SolverMetrics, Status +from linopy.solver_capabilities import SolverFeature, get_available_solvers_with_feature + +# --------------------------------------------------------------------------- +# SolverMetrics dataclass tests +# --------------------------------------------------------------------------- + + +def test_solver_metrics_defaults() -> None: + m = SolverMetrics() + assert m.solver_name is None + assert m.solve_time is None + assert m.objective_value is None + assert m.dual_bound is None + assert m.mip_gap is None + + +def test_solver_metrics_partial() -> None: + m = SolverMetrics(solver_name="highs", solve_time=1.5) + assert m.solver_name == "highs" + assert m.solve_time == 1.5 + assert m.objective_value is None + + +def test_solver_metrics_repr_only_non_none() -> None: + m = SolverMetrics(solver_name="gurobi", solve_time=2.3) + r = repr(m) + assert "solver_name='gurobi'" in r + assert "solve_time=2.3" in r + assert "objective_value" not in r + assert "dual_bound" not in r + + +def test_solver_metrics_repr_empty() -> None: + m = SolverMetrics() + assert repr(m) == "SolverMetrics()" + + +def test_solver_metrics_frozen() -> None: + m = SolverMetrics(solver_name="test") + with pytest.raises(AttributeError): + m.solver_name = "other" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# Result backward compatibility tests +# --------------------------------------------------------------------------- + + +def test_result_without_metrics() -> None: + """Result without metrics should still work (backward compatible).""" + status = Status.from_termination_condition("optimal") + result = Result(status=status, solution=Solution()) + assert result.metrics is None + # repr should not crash + repr(result) + + +def test_result_with_metrics() -> None: + status = Status.from_termination_condition("optimal") + metrics = SolverMetrics(solver_name="test", solve_time=1.0) + result = Result(status=status, solution=Solution(), metrics=metrics) + assert result.metrics is not None + assert result.metrics.solver_name == "test" + r = repr(result) + assert "Solver metrics:" in r + + +# --------------------------------------------------------------------------- +# Model integration tests +# --------------------------------------------------------------------------- + + +def test_model_metrics_none_before_solve() -> None: + m = Model() + assert m.solver_metrics is None + + +def test_model_metrics_populated_after_mock_solve() -> None: + m = Model() + x = m.add_variables( + lower=xr.DataArray(np.zeros(5), dims=["i"]), + upper=xr.DataArray(np.ones(5), dims=["i"]), + name="x", + ) + m.add_objective(x.sum()) + m.solve(mock_solve=True) + assert m.solver_metrics is not None + assert m.solver_metrics.solver_name == "mock" + assert m.solver_metrics.objective_value == 0.0 + + +def test_model_metrics_reset() -> None: + m = Model() + x = m.add_variables( + lower=xr.DataArray(np.zeros(5), dims=["i"]), + upper=xr.DataArray(np.ones(5), dims=["i"]), + name="x", + ) + m.add_objective(x.sum()) + m.solve(mock_solve=True) + assert m.solver_metrics is not None + m.reset_solution() + assert m.solver_metrics is None + + +# --------------------------------------------------------------------------- +# Solver-specific integration tests (parametrized over available solvers) +# 
--------------------------------------------------------------------------- + +# Solvers that have a tested _extract_metrics override providing solve_time etc. +_solvers_with_metrics = {"gurobi", "highs", "scip", "cplex", "xpress", "mosek"} + +direct_solvers = [ + s + for s in get_available_solvers_with_feature( + SolverFeature.DIRECT_API, available_solvers + ) + if s in _solvers_with_metrics +] +file_io_solvers = [ + s + for s in get_available_solvers_with_feature( + SolverFeature.READ_MODEL_FROM_FILE, available_solvers + ) + if s in _solvers_with_metrics +] +mip_solvers = [ + s + for s in get_available_solvers_with_feature( + SolverFeature.INTEGER_VARIABLES, available_solvers + ) + if s in _solvers_with_metrics +] + + +def _make_simple_lp() -> Model: + m = Model() + x = m.add_variables( + lower=xr.DataArray(np.zeros(3), dims=["i"]), + upper=xr.DataArray(np.ones(3), dims=["i"]), + name="x", + ) + m.add_constraints(x.sum() >= 1, name="con") + m.add_objective(x.sum()) + return m + + +def _make_simple_mip() -> Model: + m = Model() + x = m.add_variables(coords=[np.arange(3)], name="x", binary=True) + m.add_constraints(x.sum() >= 1, name="con") + m.add_objective(x.sum()) + return m + + +@pytest.mark.parametrize("solver", direct_solvers) +def test_solver_metrics_direct(solver: str) -> None: + m = _make_simple_lp() + m.solve(solver_name=solver, io_api="direct") + metrics = m.solver_metrics + assert metrics is not None + assert metrics.solver_name == solver + assert metrics.objective_value is not None + assert metrics.objective_value == pytest.approx(1.0) + assert metrics.solve_time is not None + assert metrics.solve_time >= 0 + + +@pytest.mark.parametrize("solver", file_io_solvers) +def test_solver_metrics_file_io(solver: str) -> None: + m = _make_simple_lp() + m.solve(solver_name=solver, io_api="lp") + metrics = m.solver_metrics + assert metrics is not None + assert metrics.solver_name == solver + assert metrics.objective_value is not None + assert metrics.objective_value == pytest.approx(1.0) + assert metrics.solve_time is not None + assert metrics.solve_time >= 0 + + +@pytest.mark.parametrize("solver", mip_solvers) +def test_solver_metrics_mip(solver: str) -> None: + """Solve a MIP and verify mip_gap and dual_bound are populated.""" + m = _make_simple_mip() + if solver in direct_solvers: + m.solve(solver_name=solver, io_api="direct") + else: + m.solve(solver_name=solver, io_api="lp") + metrics = m.solver_metrics + assert metrics is not None + assert metrics.solver_name == solver + assert metrics.objective_value == pytest.approx(1.0) + assert metrics.solve_time is not None + assert metrics.solve_time >= 0 + assert metrics.mip_gap is not None + assert metrics.mip_gap >= 0 + assert metrics.dual_bound is not None + assert isinstance(metrics.dual_bound, float)
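+
+
+# Illustrative extra check: the new dataclass is re-exported at package level
+# (see the linopy/__init__.py change above); relies only on that re-export.
+def test_solver_metrics_reexported_at_package_level() -> None:
+    import linopy
+
+    assert linopy.SolverMetrics is SolverMetrics
+    assert "SolverMetrics" in linopy.__all__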