diff --git a/CHANGELOG.md b/CHANGELOG.md index 74b81175..31fb421a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Moved the conditional logic to `graphix.simulator` to remove code duplication in the backends. - Solves [#428](https://github.com/TeamGraphix/graphix/issues/428). +- #438: `ComplexUnit.try_from` now uses `cmath.isclose` for float comparison and has optional parameters `rel_tol` and `abs_tol`. + ### Changed - #181, #423: Structural separation of Pauli measurements - The class `Measurement` is now abstract and has two concrete subclasses: `PauliMeasurement` and `BlochMeasurement`. - `M` commands are now parameterized by an instance `Measurement` (instead of carrying a plane and an angle). - - Conversions are explicit with `Measurement.to_bloch()` and `Measurement.infer_pauli_measurements()`. + - Conversions are explicit with `Measurement.to_bloch()` and `Measurement.infer_pauli_measurements()`. Pauli measurement inference uses `math.isclose` and has optional parameters `rel_tol` and `abs_tol`. ## [0.3.4] - 2026-02-05 diff --git a/graphix/fundamentals.py b/graphix/fundamentals.py index c4de85f7..7a26a8a8 100644 --- a/graphix/fundamentals.py +++ b/graphix/fundamentals.py @@ -2,6 +2,7 @@ from __future__ import annotations +import cmath import enum from abc import ABC, ABCMeta, abstractmethod from enum import Enum, EnumMeta @@ -164,19 +165,36 @@ class ComplexUnit(EnumReprMixin, Enum): MINUS_J = 3 @staticmethod - def try_from(value: ComplexUnit | SupportsComplexCtor) -> ComplexUnit | None: - """Return the ComplexUnit instance if the value is compatible, None otherwise.""" + def try_from( + value: ComplexUnit | SupportsComplexCtor, rel_tol: float = 1e-09, abs_tol: float = 0.0 + ) -> ComplexUnit | None: + """Return the ComplexUnit instance if the value is compatible, None otherwise. + + Parameters + ---------- + value : ComplexUnit | SupportsComplexCtor + Complex value to convert. + rel_tol : float, optional + Relative tolerance for comparing values, passed to :func:`math.isclose`. Default is ``1e-9``. + abs_tol : float, optional + Absolute tolerance for comparing values, passed to :func:`math.isclose`. Default is ``0.0``. + + Returns + ------- + ComplexUnit | None + Complex unit close to value, or ``None`` otherwise. + """ if isinstance(value, ComplexUnit): return value value = complex(value) - if value == 1: - return ComplexUnit.ONE - if value == -1: - return ComplexUnit.MINUS_ONE - if value == 1j: - return ComplexUnit.J - if value == -1j: - return ComplexUnit.MINUS_J + for reference, result in ( + (1, ComplexUnit.ONE), + (-1, ComplexUnit.MINUS_ONE), + (1j, ComplexUnit.J), + (-1j, ComplexUnit.MINUS_J), + ): + if cmath.isclose(value, reference, rel_tol=rel_tol, abs_tol=abs_tol): + return result return None @staticmethod diff --git a/graphix/parameter.py b/graphix/parameter.py index b822f58f..0261b19c 100644 --- a/graphix/parameter.py +++ b/graphix/parameter.py @@ -376,12 +376,10 @@ def subs(value: T, variable: Parameter, substitute: ExpressionOrSupportsFloat) - if not isinstance(value, Expression): return value new_value = value.subs(variable, substitute) - # On Python<=3.10, complex is not a subtype of SupportsComplex - if isinstance(new_value, (complex, SupportsComplex)): - c = complex(new_value) - if c.imag == 0.0: - return c.real - return c + if isinstance(new_value, complex) and math.isclose(new_value.imag, 0.0): + # Conversion to float, to enable the simulator to call + # real trigonometric functions to the result. + return new_value.real return new_value @@ -416,12 +414,10 @@ def xreplace(value: T, assignment: Mapping[Parameter, ExpressionOrSupportsFloat] if not isinstance(value, Expression): return value new_value = value.xreplace(assignment) - # On Python<=3.10, complex is not a subtype of SupportsComplex - if isinstance(new_value, (complex, SupportsComplex)): - c = complex(new_value) - if c.imag == 0.0: - return c.real - return c + if isinstance(new_value, complex) and math.isclose(new_value.imag, 0.0): + # Conversion to float, to enable the simulator to call + # real trigonometric functions to the result. + return new_value.real return new_value diff --git a/noxfile.py b/noxfile.py index 602e441e..2cbd79e7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -98,7 +98,7 @@ class ReverseDependency: branch: str | None = None version_constraint: VersionRange | None = None doctest_modules: bool = True - initialization: Callable[[Session], None] | None = None + initialization: Callable[[Session], bool | None] | None = None @nox.session(python=PYTHON_VERSIONS) diff --git a/requirements-dev.txt b/requirements-dev.txt index 12c6dcc2..f5ac5839 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,17 +2,17 @@ mypy==1.19.1 pre-commit # for language-agnostic hooks pyright -ruff==0.15.0 +ruff==0.15.1 # Stubs -types-networkx==3.6.1.20251220 +types-networkx==3.6.1.20260210 types-psutil types-setuptools scipy-stubs # Tests # Keep in sync with CI -nox==2025.11.12 +nox==2026.2.9 psutil pytest pytest-benchmark diff --git a/tests/test_fundamentals.py b/tests/test_fundamentals.py index 2460fad2..48adbbc9 100644 --- a/tests/test_fundamentals.py +++ b/tests/test_fundamentals.py @@ -50,32 +50,32 @@ def test_mul_int(self) -> None: def test_mul_float(self) -> None: left = Sign.PLUS * 1.0 assert isinstance(left, float) - assert left == float(Sign.PLUS) + assert left == float(Sign.PLUS) # noqa: RUF069 right = 1.0 * Sign.PLUS assert isinstance(right, float) - assert right == float(Sign.PLUS) + assert right == float(Sign.PLUS) # noqa: RUF069 left = Sign.MINUS * 1.0 assert isinstance(left, float) - assert left == float(Sign.MINUS) + assert left == float(Sign.MINUS) # noqa: RUF069 right = 1.0 * Sign.MINUS assert isinstance(right, float) - assert right == float(Sign.MINUS) + assert right == float(Sign.MINUS) # noqa: RUF069 def test_mul_complex(self) -> None: left = Sign.PLUS * complex(1) assert isinstance(left, complex) - assert left == complex(Sign.PLUS) + assert left == complex(Sign.PLUS) # noqa: RUF069 right = complex(1) * Sign.PLUS assert isinstance(right, complex) - assert right == complex(Sign.PLUS) + assert right == complex(Sign.PLUS) # noqa: RUF069 left = Sign.MINUS * complex(1) assert isinstance(left, complex) - assert left == complex(Sign.MINUS) + assert left == complex(Sign.MINUS) # noqa: RUF069 right = complex(1) * Sign.MINUS assert isinstance(right, complex) - assert right == complex(Sign.MINUS) + assert right == complex(Sign.MINUS) # noqa: RUF069 def test_int(self) -> None: # Necessary to justify `type: ignore` @@ -103,10 +103,10 @@ def test_properties(self, sign: Sign, is_imag: bool) -> None: assert ComplexUnit.from_properties(sign=sign, is_imag=is_imag).is_imag == is_imag def test_complex(self) -> None: - assert complex(ComplexUnit.ONE) == 1 - assert complex(ComplexUnit.J) == 1j - assert complex(ComplexUnit.MINUS_ONE) == -1 - assert complex(ComplexUnit.MINUS_J) == -1j + assert complex(ComplexUnit.ONE) == 1 # noqa: RUF069 + assert complex(ComplexUnit.J) == 1j # noqa: RUF069 + assert complex(ComplexUnit.MINUS_ONE) == -1 # noqa: RUF069 + assert complex(ComplexUnit.MINUS_J) == -1j # noqa: RUF069 def test_str(self) -> None: assert str(ComplexUnit.ONE) == "1" @@ -116,15 +116,15 @@ def test_str(self) -> None: @pytest.mark.parametrize(("lhs", "rhs"), itertools.product(ComplexUnit, ComplexUnit)) def test_mul_self(self, lhs: ComplexUnit, rhs: ComplexUnit) -> None: - assert complex(lhs * rhs) == complex(lhs) * complex(rhs) + assert complex(lhs * rhs) == complex(lhs) * complex(rhs) # noqa: RUF069 def test_mul_number(self) -> None: assert ComplexUnit.ONE * 1 == ComplexUnit.ONE assert 1 * ComplexUnit.ONE == ComplexUnit.ONE - assert ComplexUnit.ONE * 1.0 == ComplexUnit.ONE - assert 1.0 * ComplexUnit.ONE == ComplexUnit.ONE - assert ComplexUnit.ONE * complex(1) == ComplexUnit.ONE - assert complex(1) * ComplexUnit.ONE == ComplexUnit.ONE + assert ComplexUnit.ONE * 1.0 == ComplexUnit.ONE # noqa: RUF069 + assert 1.0 * ComplexUnit.ONE == ComplexUnit.ONE # noqa: RUF069 + assert ComplexUnit.ONE * complex(1) == ComplexUnit.ONE # noqa: RUF069 + assert complex(1) * ComplexUnit.ONE == ComplexUnit.ONE # noqa: RUF069 def test_neg(self) -> None: assert -ComplexUnit.ONE == ComplexUnit.MINUS_ONE diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 520adfcf..7970484a 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -25,7 +25,7 @@ def test_pattern_affine_operations() -> None: assert alpha + 1 + 1 == alpha + 2 assert alpha + alpha == 2 * alpha assert alpha - alpha == 0 - assert alpha / 2 == 0.5 * alpha + assert alpha / 2 == 0.5 * alpha # noqa: RUF069 assert -alpha + alpha == 0 beta = Placeholder("beta") with pytest.raises(PlaceholderOperationError): diff --git a/tests/test_pauli.py b/tests/test_pauli.py index b867cef7..67069088 100644 --- a/tests/test_pauli.py +++ b/tests/test_pauli.py @@ -71,21 +71,21 @@ def test_iterate_false(self) -> None: cmp = list(Pauli.iterate(symbol_only=False)) assert len(cmp) == 16 assert cmp[0] == Pauli.I - assert cmp[1] == 1j * Pauli.I + assert cmp[1] == 1j * Pauli.I # noqa: RUF069 assert cmp[2] == -1 * Pauli.I - assert cmp[3] == -1j * Pauli.I + assert cmp[3] == -1j * Pauli.I # noqa: RUF069 assert cmp[4] == Pauli.X - assert cmp[5] == 1j * Pauli.X + assert cmp[5] == 1j * Pauli.X # noqa: RUF069 assert cmp[6] == -1 * Pauli.X - assert cmp[7] == -1j * Pauli.X + assert cmp[7] == -1j * Pauli.X # noqa: RUF069 assert cmp[8] == Pauli.Y - assert cmp[9] == 1j * Pauli.Y + assert cmp[9] == 1j * Pauli.Y # noqa: RUF069 assert cmp[10] == -1 * Pauli.Y - assert cmp[11] == -1j * Pauli.Y + assert cmp[11] == -1j * Pauli.Y # noqa: RUF069 assert cmp[12] == Pauli.Z - assert cmp[13] == 1j * Pauli.Z + assert cmp[13] == 1j * Pauli.Z # noqa: RUF069 assert cmp[14] == -1 * Pauli.Z - assert cmp[15] == -1j * Pauli.Z + assert cmp[15] == -1j * Pauli.Z # noqa: RUF069 def test_iter_meta(self) -> None: it = Pauli.iterate(symbol_only=False)