diff --git a/CHANGELOG.md b/CHANGELOG.md index 3210ebb3..57c90440 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- **Noise Model API**: Event-based noise injection system for Stim circuit compilation + - Added `NoiseModel` base class for custom noise behavior with `on_prepare`, `on_entangle`, `on_measure`, and `on_idle` hooks + - Added typed `NoiseOp` dataclasses: `PauliChannel1`, `PauliChannel2`, `HeraldedPauliChannel1`, `HeraldedErase`, `RawStimOp`, `MeasurementFlip` + - Added event dataclasses: `PrepareEvent`, `EntangleEvent`, `MeasureEvent`, `IdleEvent` with `NodeInfo` and `Coordinate` + - Added `NoisePlacement` enum for controlling noise insertion timing (AUTO, BEFORE, AFTER) + - Added `noise_op_to_stim()` conversion function + - Added `depolarize1_probs()` and `depolarize2_probs()` utility functions for depolarizing noise + +- **Built-in Noise Models**: Ready-to-use noise model implementations + - Added `DepolarizingNoiseModel` for single and two-qubit depolarizing noise + - Added `MeasurementFlipNoiseModel` for measurement bit-flip errors using Stim's built-in `MX(p)` syntax + +- **Stim Compiler Enhancement**: Extended `stim_compile()` to accept `noise_models` parameter + - Support for multiple noise models via `Sequence[NoiseModel]` + - Added `tick_duration` parameter for idle noise calculations + - Automatic measurement record tracking for heralded noise operations + +- **Documentation**: Added comprehensive Sphinx documentation for the noise model module + +### Changed + +- **Stim Compiler**: Refactored internal structure to support noise model integration + ## [0.2.1] - 2026-01-16 ### Added diff --git a/docs/source/noise_model.rst b/docs/source/noise_model.rst new file mode 100644 index 00000000..ea1bc015 --- /dev/null +++ b/docs/source/noise_model.rst @@ -0,0 +1,7 @@ +Noise Model +=========== + +.. automodule:: graphqomb.noise_model + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/references.rst b/docs/source/references.rst index 5bc98498..984e0783 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -21,4 +21,5 @@ Module reference qompiler scheduler stim_compiler + noise_model visualizer diff --git a/graphqomb/noise_model.py b/graphqomb/noise_model.py new file mode 100644 index 00000000..d6c3d6fa --- /dev/null +++ b/graphqomb/noise_model.py @@ -0,0 +1,943 @@ +r"""Noise model interface for Stim circuit compilation. + +This module provides: + +- `NoisePlacement`: Enum for noise placement. +- `Coordinate`: N-dimensional coordinate dataclass. +- `NodeInfo`: Node identifier with optional coordinate. +- `PrepareEvent`, `EntangleEvent`, `MeasureEvent`, `IdleEvent`: Event dataclasses. +- `NoiseEvent`: Union type of all event types. +- `PauliChannel1`, `PauliChannel2`, `HeraldedPauliChannel1`, `HeraldedErase`, `RawStimOp`, + `MeasurementFlip`: NoiseOp types. +- `NoiseOp`: Union type of all noise operation types. +- `default_noise_placement`: Global default placement policy for AUTO operations. +- `NoiseModel`: Base class for noise models. +- `DepolarizingNoiseModel`, `MeasurementFlipNoiseModel`: Built-in noise models. +- `noise_op_to_stim`: Conversion function. +- `depolarize1_probs`: Utility to create single-qubit depolarizing probabilities. +- `depolarize2_probs`: Utility to create 2-qubit depolarizing probabilities. +- :data:`PAULI_CHANNEL_2_ORDER`: Constant for Pauli channel order. + +Examples +-------- +Create a simple depolarizing noise model: + +>>> from graphqomb.noise_model import ( +... NoiseModel, +... PrepareEvent, +... EntangleEvent, +... PauliChannel1, +... PauliChannel2, +... depolarize1_probs, +... depolarize2_probs, +... ) +>>> +>>> class DepolarizingNoise(NoiseModel): +... def __init__(self, p1: float, p2: float) -> None: +... self.p1 = p1 # Single-qubit depolarizing probability +... self.p2 = p2 # Two-qubit depolarizing probability +... +... def on_prepare(self, event: PrepareEvent) -> list[PauliChannel1]: +... return [PauliChannel1(**depolarize1_probs(self.p1), targets=[event.node.id])] +... +... def on_entangle(self, event: EntangleEvent) -> list[PauliChannel2]: +... return [PauliChannel2(probabilities=depolarize2_probs(self.p2), targets=[(event.node0.id, event.node1.id)])] + +Use with stim_compile: + +>>> from graphqomb.stim_compiler import stim_compile +>>> # pattern = ... # your compiled pattern +>>> # stim_str = stim_compile(pattern, noise_models=[DepolarizingNoise(0.001, 0.01)]) + +Use heralded noise that adds measurement records: + +>>> from graphqomb.noise_model import NoiseModel, MeasureEvent, HeraldedPauliChannel1 +>>> +>>> class HeraldedMeasurementNoise(NoiseModel): +... def on_measure(self, event: MeasureEvent) -> list[HeraldedPauliChannel1]: +... # Heralded erasure with 10% probability +... return [HeraldedPauliChannel1(pi=0.1, px=0.0, py=0.0, pz=0.0, targets=[event.node.id])] + +Notes +----- +- **Placement control**: Each `NoiseOp` has a ``placement`` attribute. + ``AUTO`` defers to :func:`default_noise_placement`, while + ``BEFORE``/``AFTER`` force insertion side. + +- **Record delta**: Heralded instructions (`HeraldedPauliChannel1`, + `HeraldedErase`) add measurement records. The compiler automatically + tracks these to compute correct detector indices. + +- **Coordinate access**: Events provide `NodeInfo` objects with optional + coordinates, useful for position-dependent noise models. + +See Also +-------- +stim_compile : The main compilation function that accepts a NoiseModel. +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from graphqomb.common import Axis + + +PAULI_CHANNEL_2_ORDER: tuple[str, ...] = ( + "IX", + "IY", + "IZ", + "XI", + "XX", + "XY", + "XZ", + "YI", + "YX", + "YY", + "YZ", + "ZI", + "ZX", + "ZY", + "ZZ", +) + + +def _validate_probability(name: str, value: float) -> float: + """Validate a probability value and return it as float. + + Parameters + ---------- + name : `str` + Human-readable probability name used in error messages. + value : `float` + Probability value to validate. + + Returns + ------- + `float` + The validated probability value. + + Raises + ------ + ValueError + If the probability is outside the range ``[0, 1]``. + """ + p = float(value) + if not 0.0 <= p <= 1.0: + msg = f"{name} must be within [0, 1], got {value!r}" + raise ValueError(msg) + return p + + +def _validate_probability_sum(name: str, probabilities: Sequence[float], *, atol: float = 1e-12) -> None: + r"""Validate that probabilities sum to at most 1 within tolerance. + + Parameters + ---------- + name : `str` + Human-readable name used in error messages. + probabilities : `collections.abc.Sequence`\[`float`\] + Probability values to validate. + atol : `float`, optional + Absolute tolerance for sum comparison, by default ``1e-12``. + + Raises + ------ + ValueError + If the total probability exceeds ``1 + atol``. + """ + total = float(sum(probabilities)) + if total > 1.0 + atol: + msg = f"{name} probabilities must sum to <= 1, got {total}" + raise ValueError(msg) + + +def depolarize1_probs(p: float) -> dict[str, float]: + r"""Create probability dict for single-qubit depolarizing channel. + + Parameters + ---------- + p : `float` + Total depolarizing probability. + + Returns + ------- + `dict`\[`str`, `float`\] + Mapping with keys ``px``, ``py``, ``pz`` each set to ``p/3``. + + Examples + -------- + >>> probs = depolarize1_probs(0.03) + >>> probs["px"] + 0.01 + >>> probs["py"] + 0.01 + """ + p = _validate_probability("depolarize1_probs.p", p) + p_each = p / 3 + return {"px": p_each, "py": p_each, "pz": p_each} + + +def depolarize2_probs(p: float) -> dict[str, float]: + r"""Create probability dict for 2-qubit depolarizing channel. + + Parameters + ---------- + p : `float` + Total depolarizing probability. + + Returns + ------- + `dict`\[`str`, `float`\] + Mapping from Pauli pair to probability ``p/15``. + + Examples + -------- + >>> probs = depolarize2_probs(0.15) + >>> probs["ZZ"] + 0.01 + >>> len(probs) + 15 + """ + p = _validate_probability("depolarize2_probs.p", p) + p_each = p / 15 + return dict.fromkeys(PAULI_CHANNEL_2_ORDER, p_each) + + +class NoisePlacement(Enum): + """Where to insert noise relative to the main operation.""" + + AUTO = auto() + BEFORE = auto() + AFTER = auto() + + +@dataclass(frozen=True) +class Coordinate: + r"""N-dimensional coordinate for a node. + + Parameters + ---------- + values : `tuple`\[`float`, ...\] + The coordinate values as a tuple of floats. + + Examples + -------- + >>> coord = Coordinate((1.0, 2.0, 3.0)) + >>> coord.xy + (1.0, 2.0) + >>> coord.xyz + (1.0, 2.0, 3.0) + """ + + values: tuple[float, ...] + + @property + def xy(self) -> tuple[float, float] | None: + """Return the first two dimensions as (x, y), or None if fewer than 2 dimensions.""" + if len(self.values) < 2: # noqa: PLR2004 + return None + return (self.values[0], self.values[1]) + + @property + def xyz(self) -> tuple[float, float, float] | None: + """Return the first three dimensions as (x, y, z), or None if fewer than 3 dimensions.""" + if len(self.values) < 3: # noqa: PLR2004 + return None + return (self.values[0], self.values[1], self.values[2]) + + +@dataclass(frozen=True) +class NodeInfo: + """Node identifier with optional coordinate. + + Parameters + ---------- + id : `int` + The unique node index in the pattern. + coord : `Coordinate` | `None` + The spatial coordinate of the node, if available. + """ + + id: int + coord: Coordinate | None + + +@dataclass(frozen=True) +class PrepareEvent: + """Event emitted when a qubit is prepared (N command). + + Parameters + ---------- + time : `int` + The current tick (time step) in the pattern execution. + node : `NodeInfo` + Information about the node being prepared. + is_input : `bool` + Whether this node is an input node of the pattern. + Input nodes may require different noise treatment. + """ + + time: int + node: NodeInfo + is_input: bool + + +@dataclass(frozen=True) +class EntangleEvent: + r"""Event emitted when two qubits are entangled (E command / CZ gate). + + Parameters + ---------- + time : `int` + The current tick (time step) in the pattern execution. + node0 : `NodeInfo` + Information about the first node in the entanglement. + node1 : `NodeInfo` + Information about the second node in the entanglement. + edge : `tuple`\[`int`, `int`\] + The edge as ``(min_node_id, max_node_id)``. + """ + + time: int + node0: NodeInfo + node1: NodeInfo + edge: tuple[int, int] + + +@dataclass(frozen=True) +class MeasureEvent: + """Event emitted when a qubit is measured (M command). + + Parameters + ---------- + time : `int` + The current tick (time step) in the pattern execution. + node : `NodeInfo` + Information about the node being measured. + axis : `Axis` + The measurement axis (X, Y, or Z). + """ + + time: int + node: NodeInfo + axis: Axis + + +@dataclass(frozen=True) +class IdleEvent: + r"""Event emitted for qubits that are idle during a TICK. + + Parameters + ---------- + time : `int` + The current tick (time step) in the pattern execution. + nodes : `collections.abc.Sequence`\[`NodeInfo`\] + Information about all nodes that are idle during this tick. + duration : `float` + The duration of the idle period (from ``tick_duration`` parameter). + """ + + time: int + nodes: Sequence[NodeInfo] + duration: float + + +NoiseEvent = PrepareEvent | EntangleEvent | MeasureEvent | IdleEvent +"""Union type of all noise event types.""" + + +def default_noise_placement(event: NoiseEvent) -> NoisePlacement: + """Return the global default placement for AUTO noise operations. + + Measurement noise is inserted before measurement operations. Noise for all + other events is inserted after the corresponding operation. + + Parameters + ---------- + event : `NoiseEvent` + The event for which to determine the default placement. + + Returns + ------- + `NoisePlacement` + ``BEFORE`` for measurement events, ``AFTER`` for all others. + """ + if isinstance(event, MeasureEvent): + return NoisePlacement.BEFORE + return NoisePlacement.AFTER + + +@dataclass(frozen=True) +class PauliChannel1: + r"""Single-qubit Pauli channel noise operation. + + Applies independent X, Y, Z errors with given probabilities. + Corresponds to Stim's ``PAULI_CHANNEL_1`` instruction. + + Parameters + ---------- + px : `float` + Probability of X error. + py : `float` + Probability of Y error. + pz : `float` + Probability of Z error. + targets : `collections.abc.Sequence`\[`int`\] + Target qubit indices. + placement : `NoisePlacement` + Whether to insert before or after the main operation. + ``AUTO`` defers to :func:`default_noise_placement`. + + Examples + -------- + >>> op = PauliChannel1(px=0.01, py=0.01, pz=0.01, targets=[0, 1]) + >>> noise_op_to_stim(op) + ('PAULI_CHANNEL_1(0.01,0.01,0.01) 0 1', 0) + """ + + px: float + py: float + pz: float + targets: Sequence[int] + placement: NoisePlacement = NoisePlacement.AUTO + + def __post_init__(self) -> None: + object.__setattr__(self, "targets", tuple(self.targets)) + + +@dataclass(frozen=True) +class PauliChannel2: + r"""Two-qubit Pauli channel noise operation. + + Applies correlated two-qubit Pauli errors. + Corresponds to Stim's ``PAULI_CHANNEL_2`` instruction. + + Parameters + ---------- + probabilities : `collections.abc.Sequence`\[`float`\] | `collections.abc.Mapping`\[`str`, `float`\] + Either a sequence of 15 probabilities in the order + (IX, IY, IZ, XI, XX, XY, XZ, YI, YX, YY, YZ, ZI, ZX, ZY, ZZ), + or a mapping from Pauli string keys to probabilities. + Missing keys default to 0. + targets : `collections.abc.Sequence`\[`tuple`\[`int`, `int`\]\] + Target qubit pairs as ``[(q0, q1), ...]``. + placement : `NoisePlacement` + Whether to insert before or after the main operation. + ``AUTO`` defers to :func:`default_noise_placement`. + + Examples + -------- + Using a mapping (recommended for sparse errors): + + >>> op = PauliChannel2(probabilities={"ZZ": 0.01}, targets=[(0, 1)]) + >>> text, delta = noise_op_to_stim(op) + >>> "PAULI_CHANNEL_2" in text + True + + Using a full probability sequence: + + >>> probs = [0.0] * 14 + [0.01] # Only ZZ error + >>> op = PauliChannel2(probabilities=probs, targets=[(2, 3)]) + """ + + probabilities: Sequence[float] | Mapping[str, float] + targets: Sequence[tuple[int, int]] + placement: NoisePlacement = NoisePlacement.AUTO + + def __post_init__(self) -> None: + object.__setattr__(self, "targets", tuple(tuple(pair) for pair in self.targets)) + + +@dataclass(frozen=True) +class HeraldedPauliChannel1: + r"""Heralded single-qubit Pauli channel noise operation. + + Similar to `PauliChannel1` but produces a herald measurement record + indicating whether an error occurred. The herald outcome is 1 if any + error occurred (including identity with probability ``pi``). + Corresponds to Stim's ``HERALDED_PAULI_CHANNEL_1`` instruction. + + Parameters + ---------- + pi : `float` + Probability of heralded identity (no error but flagged). + px : `float` + Probability of heralded X error. + py : `float` + Probability of heralded Y error. + pz : `float` + Probability of heralded Z error. + targets : `collections.abc.Sequence`\[`int`\] + Target qubit indices. + placement : `NoisePlacement` + Whether to insert before or after the main operation. + ``AUTO`` defers to :func:`default_noise_placement`. + + Notes + ----- + This instruction adds one measurement record per target qubit. + The compiler automatically tracks this when computing detector indices. + + Examples + -------- + >>> op = HeraldedPauliChannel1(pi=0.0, px=0.01, py=0.0, pz=0.0, targets=[5]) + >>> text, delta = noise_op_to_stim(op) + >>> text + 'HERALDED_PAULI_CHANNEL_1(0.0,0.01,0.0,0.0) 5' + >>> delta # One record added per target + 1 + """ + + pi: float + px: float + py: float + pz: float + targets: Sequence[int] + placement: NoisePlacement = NoisePlacement.AUTO + + def __post_init__(self) -> None: + object.__setattr__(self, "targets", tuple(self.targets)) + + +@dataclass(frozen=True) +class HeraldedErase: + r"""Heralded erasure noise operation. + + Models photon loss or erasure errors with a herald signal. + Corresponds to Stim's ``HERALDED_ERASE`` instruction. + + Parameters + ---------- + p : `float` + Probability of erasure. + targets : `collections.abc.Sequence`\[`int`\] + Target qubit indices. + placement : `NoisePlacement` + Whether to insert before or after the main operation. + ``AUTO`` defers to :func:`default_noise_placement`. + + Notes + ----- + This instruction adds one measurement record per target qubit. + The compiler automatically tracks this when computing detector indices. + + Examples + -------- + >>> op = HeraldedErase(p=0.05, targets=[0, 1, 2]) + >>> text, delta = noise_op_to_stim(op) + >>> text + 'HERALDED_ERASE(0.05) 0 1 2' + >>> delta # One record added per target + 3 + """ + + p: float + targets: Sequence[int] + placement: NoisePlacement = NoisePlacement.AUTO + + def __post_init__(self) -> None: + object.__setattr__(self, "targets", tuple(self.targets)) + + +@dataclass(frozen=True) +class RawStimOp: + """Raw Stim instruction for advanced use cases. + + Use this when the typed noise operations don't cover your use case. + The text is inserted directly into the Stim circuit. + + Parameters + ---------- + text : `str` + A single Stim instruction line (without trailing newline). + record_delta : `int` + The number of measurement records added by this instruction. + Most noise instructions do not add records (default 0). + placement : `NoisePlacement` + Whether to insert before or after the main operation. + ``AUTO`` defers to :func:`default_noise_placement`. + + Examples + -------- + >>> op = RawStimOp("X_ERROR(0.001) 0 1 2") + >>> noise_op_to_stim(op) + ('X_ERROR(0.001) 0 1 2', 0) + + With custom record delta for measurement-like instructions: + + >>> op = RawStimOp("MR 5", record_delta=1) + >>> noise_op_to_stim(op) + ('MR 5', 1) + """ + + text: str + record_delta: int = 0 + placement: NoisePlacement = NoisePlacement.AUTO + + def __post_init__(self) -> None: + if "\n" in self.text or "\r" in self.text: + msg = "RawStimOp.text must be a single Stim instruction line without newlines" + raise ValueError(msg) + if self.record_delta < 0: + msg = f"RawStimOp.record_delta must be non-negative, got {self.record_delta}" + raise ValueError(msg) + expected_delta = _infer_raw_record_delta(self.text) + if expected_delta is not None and self.record_delta != expected_delta: + msg = ( + f"RawStimOp.record_delta mismatch for instruction {self.text!r}: " + f"expected {expected_delta}, got {self.record_delta}" + ) + raise ValueError(msg) + + +@dataclass(frozen=True) +class MeasurementFlip: + """Measurement flip error applied to measurement instruction. + + Unlike other NoiseOp types that insert separate instructions, + this modifies the measurement instruction itself to use Stim's + built-in measurement error probability: MX(p) instead of MX. + + Parameters + ---------- + p : `float` + Probability of measurement result flip. + target : `int` + Target qubit index (must match the measurement target). + placement : `NoisePlacement` + Placement attribute for compatibility (ignored, as this modifies + the measurement instruction itself). + """ + + p: float + target: int + placement: NoisePlacement = NoisePlacement.AUTO + + +NoiseOp = PauliChannel1 | PauliChannel2 | HeraldedPauliChannel1 | HeraldedErase | RawStimOp | MeasurementFlip +"""Union type of all noise operation types.""" + + +class NoiseModel: + """Base class for custom noise injection during Stim compilation. + + Subclass this to define custom noise behavior by overriding one or more + of the event handler methods. Each method receives an event object with + context about the current operation and returns noise operations to inject. + + Examples + -------- + >>> class SimpleNoise(NoiseModel): + ... def on_prepare(self, event: PrepareEvent) -> list[PauliChannel1]: + ... # Add depolarizing noise after preparation + ... p = 0.001 / 3 + ... return [PauliChannel1(px=p, py=p, pz=p, targets=[event.node.id])] + ... + ... def on_measure(self, event: MeasureEvent) -> list[PauliChannel1]: + ... # Add bit-flip noise before measurement + ... return [ + ... PauliChannel1(px=0.01, py=0.0, pz=0.0, targets=[event.node.id], placement=NoisePlacement.BEFORE) + ... ] + + See Also + -------- + stim_compile : The main compilation function that accepts a NoiseModel. + """ + + def on_prepare(self, event: PrepareEvent) -> Sequence[NoiseOp]: # noqa: ARG002, PLR6301 + r"""Return noise operations to inject at qubit preparation. + + Parameters + ---------- + event : `PrepareEvent` + Context about the preparation operation. + + Returns + ------- + `collections.abc.Sequence`\[`NoiseOp`\] + Zero or more noise operations to inject. + """ + return [] + + def on_entangle(self, event: EntangleEvent) -> Sequence[NoiseOp]: # noqa: ARG002, PLR6301 + r"""Return noise operations to inject at entanglement. + + Parameters + ---------- + event : `EntangleEvent` + Context about the entanglement operation. + + Returns + ------- + `collections.abc.Sequence`\[`NoiseOp`\] + Zero or more noise operations to inject. + """ + return [] + + def on_measure(self, event: MeasureEvent) -> Sequence[NoiseOp]: # noqa: ARG002, PLR6301 + r"""Return noise operations to inject at measurement. + + Parameters + ---------- + event : `MeasureEvent` + Context about the measurement operation. + + Returns + ------- + `collections.abc.Sequence`\[`NoiseOp`\] + Zero or more noise operations to inject. + """ + return [] + + def on_idle(self, event: IdleEvent) -> Sequence[NoiseOp]: # noqa: ARG002, PLR6301 + r"""Return noise operations to inject during idle periods. + + Parameters + ---------- + event : `IdleEvent` + Context about the idle period. + + Returns + ------- + `collections.abc.Sequence`\[`NoiseOp`\] + Zero or more noise operations to inject. + """ + return [] + + +def noise_op_to_stim(op: NoiseOp) -> tuple[str, int]: # noqa: PLR0911, C901 + r"""Convert a NoiseOp into a Stim instruction line and record delta. + + Parameters + ---------- + op : `NoiseOp` + The noise operation to convert. + + Returns + ------- + `tuple`\[`str`, `int`\] + A tuple of ``(stim_instruction, record_delta)`` where + ``stim_instruction`` is a single line of Stim code and + ``record_delta`` is the number of measurement records added. + + Raises + ------ + TypeError + If ``op`` is not a recognized NoiseOp type. + + Examples + -------- + >>> op = PauliChannel1(px=0.01, py=0.02, pz=0.03, targets=[0]) + >>> noise_op_to_stim(op) + ('PAULI_CHANNEL_1(0.01,0.02,0.03) 0', 0) + """ + if isinstance(op, RawStimOp): + return op.text, op.record_delta + + if isinstance(op, PauliChannel1): + if not op.targets: + return "", 0 + px = _validate_probability("PauliChannel1.px", op.px) + py = _validate_probability("PauliChannel1.py", op.py) + pz = _validate_probability("PauliChannel1.pz", op.pz) + _validate_probability_sum("PauliChannel1", (px, py, pz)) + targets = " ".join(str(t) for t in op.targets) + return f"PAULI_CHANNEL_1({px},{py},{pz}) {targets}", 0 + + if isinstance(op, PauliChannel2): + if not op.targets: + return "", 0 + args = _pauli_channel_2_args(op.probabilities) + flat_targets = _flatten_pairs(op.targets) + targets_str = " ".join(str(t) for t in flat_targets) + args_str = ",".join(str(v) for v in args) + return f"PAULI_CHANNEL_2({args_str}) {targets_str}", 0 + + if isinstance(op, HeraldedPauliChannel1): + if not op.targets: + return "", 0 + pi = _validate_probability("HeraldedPauliChannel1.pi", op.pi) + px = _validate_probability("HeraldedPauliChannel1.px", op.px) + py = _validate_probability("HeraldedPauliChannel1.py", op.py) + pz = _validate_probability("HeraldedPauliChannel1.pz", op.pz) + _validate_probability_sum("HeraldedPauliChannel1", (pi, px, py, pz)) + targets = " ".join(str(t) for t in op.targets) + return ( + f"HERALDED_PAULI_CHANNEL_1({pi},{px},{py},{pz}) {targets}", + len(op.targets), + ) + + if isinstance(op, HeraldedErase): + if not op.targets: + return "", 0 + p = _validate_probability("HeraldedErase.p", op.p) + targets = " ".join(str(t) for t in op.targets) + return f"HERALDED_ERASE({p}) {targets}", len(op.targets) + + if isinstance(op, MeasurementFlip): + _validate_probability("MeasurementFlip.p", op.p) + # MeasurementFlip is handled specially in the compiler by modifying + # the measurement instruction. It should not be emitted as a separate op. + return "", 0 + + msg = f"Unsupported noise op type: {type(op)!r}" + raise TypeError(msg) + + +def _pauli_channel_2_args(probabilities: Sequence[float] | Mapping[str, float]) -> tuple[float, ...]: + if isinstance(probabilities, Mapping): + unknown = set(probabilities) - set(PAULI_CHANNEL_2_ORDER) + if unknown: + msg = f"Unknown PAULI_CHANNEL_2 keys: {sorted(unknown)}" + raise ValueError(msg) + values = tuple(float(probabilities.get(key, 0.0)) for key in PAULI_CHANNEL_2_ORDER) + for key, value in zip(PAULI_CHANNEL_2_ORDER, values, strict=True): + _validate_probability(f"PauliChannel2.probabilities[{key}]", value) + _validate_probability_sum("PauliChannel2", values) + return values + values = tuple(float(v) for v in probabilities) + if len(values) != len(PAULI_CHANNEL_2_ORDER): + msg = f"PAULI_CHANNEL_2 expects {len(PAULI_CHANNEL_2_ORDER)} probabilities, got {len(values)}" + raise ValueError(msg) + for index, value in enumerate(values): + _validate_probability(f"PauliChannel2.probabilities[{index}]", value) + _validate_probability_sum("PauliChannel2", values) + return values + + +def _flatten_pairs(pairs: Sequence[tuple[int, int]]) -> tuple[int, ...]: + flat: list[int] = [] + for pair in pairs: + if len(pair) != 2: # noqa: PLR2004 + msg = f"PAULI_CHANNEL_2 targets must be pairs, got: {pair!r}" + raise ValueError(msg) + flat.extend(pair) + return tuple(flat) + + +_PER_TARGET_RECORD_DELTA_INSTRUCTIONS: frozenset[str] = frozenset( + { + "M", + "MX", + "MY", + "MZ", + "MR", + "MRX", + "MRY", + "MRZ", + "HERALDED_ERASE", + "HERALDED_PAULI_CHANNEL_1", + } +) + + +def _infer_raw_record_delta(text: str) -> int | None: + """Infer record delta from a raw instruction when the rule is unambiguous. + + Returns + ------- + int | None + Number of records produced if it can be inferred, otherwise None. + """ + stripped = text.strip() + if not stripped: + return 0 + parts = stripped.split() + instruction = parts[0].split("(", 1)[0] + if instruction in _PER_TARGET_RECORD_DELTA_INSTRUCTIONS: + return len(parts) - 1 + return None + + +# ---- Built-in NoiseModel implementations ---- + + +class DepolarizingNoiseModel(NoiseModel): + """Depolarizing noise after single and two-qubit gates. + + This model adds depolarizing noise after qubit preparation (RX) and + entanglement (CZ) operations. + + Parameters + ---------- + p1 : `float` + Single-qubit depolarizing probability (after RX preparation). + p2 : `float` | `None` + Two-qubit depolarizing probability (after CZ). + If None, defaults to p1. + + Examples + -------- + >>> from graphqomb.noise_model import DepolarizingNoiseModel + >>> model = DepolarizingNoiseModel(p1=0.001, p2=0.01) + >>> # Use with stim_compile: + >>> # stim_compile(pattern, noise_models=[model]) + """ + + def __init__(self, p1: float, p2: float | None = None) -> None: + self._p1 = p1 + self._p2 = p2 if p2 is not None else p1 + + def on_prepare(self, event: PrepareEvent) -> Sequence[NoiseOp]: + r"""Add single-qubit depolarizing noise after preparation. + + Returns + ------- + `collections.abc.Sequence`\[`NoiseOp`\] + A tuple containing DEPOLARIZE1 instruction, or empty if p1 <= 0. + """ + if self._p1 <= 0: + return () + return (RawStimOp(f"DEPOLARIZE1({self._p1}) {event.node.id}"),) + + def on_entangle(self, event: EntangleEvent) -> Sequence[NoiseOp]: + r"""Add two-qubit depolarizing noise after entanglement. + + Returns + ------- + `collections.abc.Sequence`\[`NoiseOp`\] + A tuple containing DEPOLARIZE2 instruction, or empty if p2 <= 0. + """ + if self._p2 <= 0: + return () + return (RawStimOp(f"DEPOLARIZE2({self._p2}) {event.node0.id} {event.node1.id}"),) + + +class MeasurementFlipNoiseModel(NoiseModel): + """Measurement bit-flip noise using Stim's built-in measurement error. + + This model produces MX(p), MY(p), MZ(p) instead of MX, MY, MZ, + which adds measurement flip error with probability p. + + Parameters + ---------- + p : `float` + Probability of measurement result flip. + + Examples + -------- + >>> from graphqomb.noise_model import MeasurementFlipNoiseModel + >>> model = MeasurementFlipNoiseModel(p=0.001) + >>> # Use with stim_compile: + >>> # stim_compile(pattern, noise_models=[model]) + """ + + def __init__(self, p: float) -> None: + self._p = p + + def on_measure(self, event: MeasureEvent) -> Sequence[NoiseOp]: + r"""Add measurement flip error. + + Returns + ------- + `collections.abc.Sequence`\[`NoiseOp`\] + A tuple containing MeasurementFlip operation, or empty if p <= 0. + """ + if self._p <= 0: + return () + return (MeasurementFlip(p=self._p, target=event.node.id),) diff --git a/graphqomb/stim_compiler.py b/graphqomb/stim_compiler.py index b1eceab1..363ddaf9 100644 --- a/graphqomb/stim_compiler.py +++ b/graphqomb/stim_compiler.py @@ -7,205 +7,312 @@ from __future__ import annotations +import math from io import StringIO from typing import TYPE_CHECKING +from warnings import warn -import typing_extensions - -from graphqomb.command import TICK, E, M, N +from graphqomb.command import TICK, E, M, N, X, Z from graphqomb.common import Axis, MeasBasis, determine_pauli_axis +from graphqomb.noise_model import ( + Coordinate, + DepolarizingNoiseModel, + EntangleEvent, + IdleEvent, + MeasureEvent, + MeasurementFlip, + MeasurementFlipNoiseModel, + NodeInfo, + NoiseModel, + NoiseOp, + NoisePlacement, + PrepareEvent, + default_noise_placement, + noise_op_to_stim, +) if TYPE_CHECKING: - from collections.abc import Collection, Iterable, Mapping, Sequence + from collections.abc import Callable, Collection, Iterable, Mapping, Sequence from graphqomb.pattern import Pattern - from graphqomb.pauli_frame import PauliFrame -def _emit_qubit_coords( - stim_io: StringIO, - node: int, - coordinate: tuple[float, ...] | None, -) -> None: - r"""Emit QUBIT_COORDS instruction if coordinate is available. +class _StimCompiler: + def __init__( + self, + pattern: Pattern, + *, + emit_qubit_coords: bool, + noise_models: Sequence[NoiseModel], + tick_duration: float, + ) -> None: + self._pattern = pattern + self._pframe = pattern.pauli_frame + self._coord_lookup = pattern.coordinates + self._emit_qubit_coords = emit_qubit_coords + self._noise_models = noise_models + self._tick_duration = tick_duration + self._stim_io = StringIO() + self._meas_order: dict[int, int] = {} + self._rec_index = 0 + self._alive_nodes: set[int] = set(pattern.input_node_indices) + self._touched_nodes: set[int] = set() + self._tick = 0 + + def compile(self, logical_observables: Mapping[int, Collection[int]] | None) -> str: + self._emit_input_nodes() + self._process_commands() + total = self._rec_index + self._emit_detectors(total) + if logical_observables is not None: + self._emit_observables(logical_observables, total) + return self._stim_io.getvalue().strip() + + def _emit_detectors(self, total_measurements: int) -> None: + for checks in self._pframe.detector_groups(): + targets = [f"rec[{self._meas_order[c] - total_measurements}]" for c in checks] + self._stim_io.write(f"DETECTOR {' '.join(targets)}\n") + + def _emit_observables(self, logical_observables: Mapping[int, Collection[int]], total_measurements: int) -> None: + for log_idx, obs in logical_observables.items(): + group = self._pframe.logical_observables_group(obs) + targets = [f"rec[{self._meas_order[n] - total_measurements}]" for n in group] + self._stim_io.write(f"OBSERVABLE_INCLUDE({log_idx}) {' '.join(targets)}\n") + + def _emit_input_nodes(self) -> None: + coordinates = self._pattern.input_coordinates if self._emit_qubit_coords else None + for node in self._pattern.input_node_indices: + coord = coordinates.get(node) if coordinates else None + self._process_prepare(node, coord, is_input=True) + + def _process_commands(self) -> None: + for cmd in self._pattern: + if isinstance(cmd, N): + self._process_prepare(cmd.node, cmd.coordinate, is_input=False) + elif isinstance(cmd, E): + self._handle_entangle(cmd.nodes) + elif isinstance(cmd, M): + self._handle_measure(cmd.node, cmd.meas_basis) + elif isinstance(cmd, TICK): + self._handle_tick() + elif isinstance(cmd, (X, Z)): + cmd_name = type(cmd).__name__ + msg = ( + f"Unsupported command for stim compilation: {cmd_name}. X/Z correction commands are not supported." + ) + raise NotImplementedError(msg) + else: + msg = f"Unsupported command for stim compilation: {type(cmd).__name__}" + raise TypeError(msg) + + def _process_prepare(self, node: int, coordinate: tuple[float, ...] | None, *, is_input: bool) -> None: + event = PrepareEvent(time=self._tick, node=self._node_info(node), is_input=is_input) + ops = self._collect_noise_ops_from_models(lambda m: m.on_prepare(event)) + default_placement = default_noise_placement(event) + self._rec_index += self._emit_noise_ops(ops, NoisePlacement.BEFORE, default_placement) + + coord = coordinate if self._emit_qubit_coords else None + if coord is not None: + self._stim_io.write(f"QUBIT_COORDS({', '.join(str(c) for c in coord)}) {node}\n") + self._stim_io.write(f"RX {node}\n") + + self._rec_index += self._emit_noise_ops(ops, NoisePlacement.AFTER, default_placement) + self._alive_nodes.add(node) + self._touched_nodes.add(node) + + def _handle_entangle(self, nodes: tuple[int, int]) -> None: + n0, n1 = nodes + edge: tuple[int, int] = (n0, n1) if n0 < n1 else (n1, n0) + event = EntangleEvent(time=self._tick, node0=self._node_info(n0), node1=self._node_info(n1), edge=edge) + ops = self._collect_noise_ops_from_models(lambda m: m.on_entangle(event)) + default_placement = default_noise_placement(event) + self._rec_index += self._emit_noise_ops(ops, NoisePlacement.BEFORE, default_placement) + + self._stim_io.write(f"CZ {n0} {n1}\n") + self._touched_nodes.update(nodes) + self._rec_index += self._emit_noise_ops(ops, NoisePlacement.AFTER, default_placement) + + def _handle_measure(self, node: int, meas_basis: MeasBasis) -> None: + axis = determine_pauli_axis(meas_basis) + if axis is None: + msg = f"Unsupported measurement basis: {meas_basis.plane, meas_basis.angle}" + raise ValueError(msg) + event = MeasureEvent(time=self._tick, node=self._node_info(node), axis=axis) + ops = self._collect_noise_ops_from_models(lambda m: m.on_measure(event)) + + # Separate MeasurementFlip from other noise ops + meas_flip_probs: list[float] = [] + other_ops: list[NoiseOp] = [] + for op in ops: + if isinstance(op, MeasurementFlip): + if op.target != node: + msg = ( + f"MeasurementFlip target mismatch: measurement on node {node}, " + f"but flip targets node {op.target}" + ) + raise ValueError(msg) + meas_flip_probs.append(op.p) + else: + other_ops.append(op) + if not meas_flip_probs: + meas_flip_p = 0.0 + elif len(meas_flip_probs) == 1: + meas_flip_p = meas_flip_probs[0] + else: + meas_flip_p = 1.0 - math.prod(1.0 - p for p in meas_flip_probs) + meas_flip_p = min(max(meas_flip_p, 0.0), 1.0) + + default_placement = default_noise_placement(event) + self._rec_index += self._emit_noise_ops(other_ops, NoisePlacement.BEFORE, default_placement) + + # Emit measurement with optional flip probability + meas_instr = {Axis.X: "MX", Axis.Y: "MY", Axis.Z: "MZ"}[axis] + if meas_flip_p > 0.0: + self._stim_io.write(f"{meas_instr}({meas_flip_p}) {node}\n") + else: + self._stim_io.write(f"{meas_instr} {node}\n") + + self._meas_order[node] = self._rec_index + self._rec_index += 1 + self._alive_nodes.discard(node) + self._touched_nodes.add(node) + self._rec_index += self._emit_noise_ops(other_ops, NoisePlacement.AFTER, default_placement) + + def _handle_tick(self) -> None: + idle_nodes = sorted(self._alive_nodes - self._touched_nodes) + if idle_nodes and self._noise_models: + event = IdleEvent( + time=self._tick, + nodes=tuple(self._node_info(node) for node in idle_nodes), + duration=self._tick_duration, + ) + ops = self._collect_noise_ops_from_models(lambda m: m.on_idle(event)) + default_placement = default_noise_placement(event) + else: + ops = () + default_placement = NoisePlacement.AFTER + self._rec_index += self._emit_noise_ops(ops, NoisePlacement.BEFORE, default_placement) + self._stim_io.write("TICK\n") + self._rec_index += self._emit_noise_ops(ops, NoisePlacement.AFTER, default_placement) + self._touched_nodes.clear() + self._tick += 1 + + def _node_info(self, node: int) -> NodeInfo: + coord_raw = self._coord_lookup.get(node) + coord = Coordinate(tuple(coord_raw)) if coord_raw is not None else None + return NodeInfo(id=node, coord=coord) + + def _collect_noise_ops_from_models(self, get_ops: Callable[[NoiseModel], Iterable[NoiseOp]]) -> tuple[NoiseOp, ...]: + ops: list[NoiseOp] = [] + for model in self._noise_models: + ops.extend(get_ops(model)) + return tuple(ops) + + def _emit_noise_ops( + self, ops: Iterable[NoiseOp], placement: NoisePlacement, default_placement: NoisePlacement + ) -> int: + record_delta = 0 + for op in ops: + op_placement = op.placement + if op_placement is NoisePlacement.AUTO: + op_placement = default_placement + if op_placement is not placement: + continue + text, delta = noise_op_to_stim(op) + if text: + self._stim_io.write(f"{text}\n") + record_delta += delta + return record_delta + + +def _validate_probability_parameter(name: str, value: float) -> None: + """Validate that a probability parameter is within [0, 1]. Parameters ---------- - stim_io : `StringIO` - The output stream to write to. - node : `int` - The qubit index. - coordinate : `tuple`\[`float`, ...\] | `None` - The coordinate tuple (2D or 3D), or None if no coordinate. - """ - if coordinate is not None: - coords_str = ", ".join(str(c) for c in coordinate) - stim_io.write(f"QUBIT_COORDS({coords_str}) {node}\n") - - -def _prepare_nodes( - stim_io: StringIO, - nodes: int | Iterable[int], - p_depol_after_clifford: float, - coordinates: Mapping[int, tuple[float, ...]] | None = None, - emit_qubit_coords: bool = True, -) -> None: - r"""Prepare nodes in |+> state. + name : `str` + Parameter name used in error messages. + value : `float` + Probability value to validate. - This function handles both single nodes (N command) and multiple nodes - (input nodes initialization). - - Parameters - ---------- - stim_io : `StringIO` - The output stream to write to. - nodes : `int` | `collections.abc.Iterable`\[`int`\] - A single node index or an iterable of node indices to prepare. - p_depol_after_clifford : `float` - The probability of depolarization after Clifford gates. - coordinates : `collections.abc.Mapping`\[`int`, `tuple`\[`float`, ...\]\] | `None`, optional - Coordinates for nodes, by default None. - emit_qubit_coords : `bool`, optional - Whether to emit QUBIT_COORDS instructions, by default True. + Raises + ------ + ValueError + If ``value`` is outside the inclusive range [0, 1]. """ - if isinstance(nodes, int): - nodes = [nodes] - for node in nodes: - coord = coordinates.get(node) if coordinates else None - if emit_qubit_coords: - _emit_qubit_coords(stim_io, node, coord) - stim_io.write(f"RX {node}\n") - if p_depol_after_clifford > 0.0: - stim_io.write(f"DEPOLARIZE1({p_depol_after_clifford}) {node}\n") + if not 0.0 <= value <= 1.0: + msg = f"{name} must be within [0, 1], got {value}" + raise ValueError(msg) -def _entangle_nodes( - stim_io: StringIO, - nodes: tuple[int, int], - p_depol_after_clifford: float, -) -> None: - r"""Entangle two nodes with CZ gate (E command). +def _normalize_noise_models( + noise_models: Sequence[NoiseModel] | None, + p_depol_after_clifford: float | None, + p_before_meas_flip: float | None, +) -> tuple[NoiseModel, ...]: + r"""Normalize legacy noise parameters and new noise model API. Parameters ---------- - stim_io : `StringIO` - The output stream to write to. - nodes : `tuple`\[`int`, `int`\] - The pair of nodes to entangle. - p_depol_after_clifford : `float` - The probability of depolarization after Clifford gates. - """ - q1, q2 = nodes - stim_io.write(f"CZ {q1} {q2}\n") - if p_depol_after_clifford > 0.0: - stim_io.write(f"DEPOLARIZE2({p_depol_after_clifford}) {q1} {q2}\n") - + noise_models : `collections.abc.Sequence`\[`NoiseModel`\] | `None` + New noise model API input. + p_depol_after_clifford : `float` | `None` + Legacy depolarizing noise parameter. + p_before_meas_flip : `float` | `None` + Legacy measurement flip noise parameter. -def _measure_node( - stim_io: StringIO, - meas_basis: MeasBasis, - node: int, - p_before_meas_flip: float, -) -> None: - r"""Measure a node in the specified basis (M command). - - Parameters - ---------- - stim_io : `StringIO` - The output stream to write to. - meas_basis : `MeasBasis` - The measurement basis. - node : `int` - The node to measure. - p_before_meas_flip : `float` - The probability of flipping a measurement result before measurement. + Returns + ------- + `tuple`\[`NoiseModel`, ...\] + Normalized noise model sequence. Raises ------ ValueError - If an unsupported measurement basis is encountered. + If legacy parameters are invalid or mixed with ``noise_models``. """ - axis = determine_pauli_axis(meas_basis) - if axis is None: - msg = f"Unsupported measurement basis: {meas_basis.plane, meas_basis.angle}" - raise ValueError(msg) - - if axis == Axis.X: - if p_before_meas_flip > 0.0: - stim_io.write(f"Z_ERROR({p_before_meas_flip}) {node}\n") - stim_io.write(f"MX {node}\n") - elif axis == Axis.Y: - if p_before_meas_flip > 0.0: - stim_io.write(f"X_ERROR({p_before_meas_flip}) {node}\n") - stim_io.write(f"Z_ERROR({p_before_meas_flip}) {node}\n") - stim_io.write(f"MY {node}\n") - elif axis == Axis.Z: - if p_before_meas_flip > 0.0: - stim_io.write(f"X_ERROR({p_before_meas_flip}) {node}\n") - stim_io.write(f"MZ {node}\n") - else: - typing_extensions.assert_never(axis) - + used_legacy: list[str] = [] + legacy_models: list[NoiseModel] = [] -def _add_detectors( - stim_io: StringIO, - check_groups: Sequence[Collection[int]], - meas_order: Mapping[int, int], - total_measurements: int, -) -> None: - r"""Add detector declarations to the circuit. + if p_depol_after_clifford is not None: + _validate_probability_parameter("p_depol_after_clifford", p_depol_after_clifford) + used_legacy.append("p_depol_after_clifford") + if p_depol_after_clifford > 0.0: + legacy_models.append(DepolarizingNoiseModel(p1=p_depol_after_clifford, p2=p_depol_after_clifford)) - Parameters - ---------- - stim_io : `StringIO` - The output stream to write to. - check_groups : `collections.abc.Sequence`\[`collections.abc.Collection`\[`int`\]\] - The parity check groups for detectors. - meas_order : `collections.abc.Mapping`\[`int`, `int`\] - The measurement order lookup dict mapping node to measurement index. - total_measurements : `int` - The total number of measurements. - """ - for checks in check_groups: - targets = [f"rec[{meas_order[check] - total_measurements}]" for check in checks] - stim_io.write(f"DETECTOR {' '.join(targets)}\n") + if p_before_meas_flip is not None: + _validate_probability_parameter("p_before_meas_flip", p_before_meas_flip) + used_legacy.append("p_before_meas_flip") + if p_before_meas_flip > 0.0: + legacy_models.append(MeasurementFlipNoiseModel(p=p_before_meas_flip)) + if noise_models is not None and used_legacy: + legacy_args = ", ".join(used_legacy) + msg = f"{legacy_args} cannot be used together with noise_models." + raise ValueError(msg) -def _add_observables( - stim_io: StringIO, - logical_observables: Mapping[int, Collection[int]], - pframe: PauliFrame, - meas_order: Mapping[int, int], - total_measurements: int, -) -> None: - r"""Add logical observable declarations to the circuit. + if used_legacy: + warn( + "p_depol_after_clifford and p_before_meas_flip are deprecated in 0.3.0 and will be removed in 0.4.0. " + "Use noise_models with DepolarizingNoiseModel and MeasurementFlipNoiseModel instead.", + DeprecationWarning, + stacklevel=3, + ) - Parameters - ---------- - stim_io : `StringIO` - The output stream to write to. - logical_observables : `collections.abc.Mapping`\[`int`, `collections.abc.Collection`\[`int`\]\] - A mapping from logical observable index to a collection of node indices. - pframe : `PauliFrame` - The Pauli frame object. - meas_order : `collections.abc.Mapping`\[`int`, `int`\] - The measurement order lookup dict mapping node to measurement index. - total_measurements : `int` - The total number of measurements. - """ - for log_idx, obs in logical_observables.items(): - logical_observables_group = pframe.logical_observables_group(obs) - targets = [f"rec[{meas_order[node] - total_measurements}]" for node in logical_observables_group] - stim_io.write(f"OBSERVABLE_INCLUDE({log_idx}) {' '.join(targets)}\n") + if noise_models is not None: + return tuple(noise_models) + return tuple(legacy_models) -def stim_compile( +def stim_compile( # noqa: PLR0913 pattern: Pattern, logical_observables: Mapping[int, Collection[int]] | None = None, *, - p_depol_after_clifford: float = 0.0, - p_before_meas_flip: float = 0.0, + p_depol_after_clifford: float | None = None, + p_before_meas_flip: float | None = None, emit_qubit_coords: bool = True, + noise_models: Sequence[NoiseModel] | None = None, + tick_duration: float = 1.0, ) -> str: r"""Compile a pattern to stim format. @@ -215,13 +322,21 @@ def stim_compile( The pattern to compile. logical_observables : `collections.abc.Mapping`\[`int`, `collections.abc.Collection`\[`int`\]\], optional A mapping from logical observable index to a collection of node indices, by default None. - p_depol_after_clifford : `float`, optional - The probability of depolarization after a Clifford gate, by default 0.0. - p_before_meas_flip : `float`, optional - The probability of flipping a measurement result before measurement, by default 0.0. + p_depol_after_clifford : `float` | `None`, optional + Legacy depolarizing probability after Clifford gates. Deprecated in 0.3.0. + Use ``noise_models=[DepolarizingNoiseModel(p1=..., p2=...)]`` instead. + p_before_meas_flip : `float` | `None`, optional + Legacy measurement bit-flip probability. Deprecated in 0.3.0. + Use ``noise_models=[MeasurementFlipNoiseModel(p=...)]`` instead. emit_qubit_coords : `bool`, optional Whether to emit QUBIT_COORDS instructions for nodes with coordinates, by default True. + noise_models : `collections.abc.Sequence`\[`NoiseModel`\] | `None`, optional + Custom noise models for injecting Stim noise instructions, by default None. + Use `DepolarizingNoiseModel` for gate noise and `MeasurementFlipNoiseModel` + for measurement errors. + tick_duration : `float`, optional + Duration associated with each TICK for idle noise, by default 1.0. Returns ------- @@ -230,54 +345,40 @@ def stim_compile( Notes ----- + Deprecated parameters ``p_depol_after_clifford`` and ``p_before_meas_flip`` emit + a `DeprecationWarning` and will be removed in 0.4.0. + Legacy parameters cannot be mixed with ``noise_models``. Stim only supports Clifford gates, therefore this compiler only supports Pauli measurements (X, Y, Z basis) which correspond to Clifford operations. Non-Pauli measurements will raise a ValueError. + Patterns containing X or Z correction commands will raise a NotImplementedError. + + Examples + -------- + Basic compilation without noise: + + >>> # stim_str = stim_compile(pattern) + + With depolarizing and measurement flip noise: + + >>> from graphqomb.noise_model import DepolarizingNoiseModel, MeasurementFlipNoiseModel + >>> # stim_str = stim_compile( + >>> # pattern, + >>> # noise_models=[ + >>> # DepolarizingNoiseModel(p1=0.001, p2=0.01), + >>> # MeasurementFlipNoiseModel(p=0.001) + >>> # ] + >>> # ) """ - stim_io = StringIO() - pframe = pattern.pauli_frame - - # Build measurement order lookup dict - meas_order: dict[int, int] = {} - meas_idx = 0 - for cmd in pattern: - if isinstance(cmd, M): - meas_order[cmd.node] = meas_idx - meas_idx += 1 - total_measurements = meas_idx - - # Initialize input nodes (with coordinates if available) - _prepare_nodes( - stim_io, - pattern.input_node_indices, - p_depol_after_clifford, - coordinates=pattern.input_coordinates if emit_qubit_coords else None, + normalized_noise_models = _normalize_noise_models( + noise_models=noise_models, + p_depol_after_clifford=p_depol_after_clifford, + p_before_meas_flip=p_before_meas_flip, + ) + compiler = _StimCompiler( + pattern, emit_qubit_coords=emit_qubit_coords, + noise_models=normalized_noise_models, + tick_duration=tick_duration, ) - - # Process pattern commands - for cmd in pattern: - if isinstance(cmd, N): - _prepare_nodes( - stim_io, - cmd.node, - p_depol_after_clifford, - coordinates={cmd.node: cmd.coordinate} if cmd.coordinate else None, - emit_qubit_coords=emit_qubit_coords, - ) - elif isinstance(cmd, E): - _entangle_nodes(stim_io, cmd.nodes, p_depol_after_clifford) - elif isinstance(cmd, M): - _measure_node(stim_io, cmd.meas_basis, cmd.node, p_before_meas_flip) - elif isinstance(cmd, TICK): - stim_io.write("TICK\n") - - # Add detectors - check_groups = pframe.detector_groups() - _add_detectors(stim_io, check_groups, meas_order, total_measurements) - - # Add logical observables - if logical_observables is not None: - _add_observables(stim_io, logical_observables, pframe, meas_order, total_measurements) - - return stim_io.getvalue().strip() + return compiler.compile(logical_observables) diff --git a/pyproject.toml b/pyproject.toml index ea35c112..cf71c0e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,11 +115,14 @@ docstring-code-format = true "S101", # `assert` detected "SLF001", # private method "PLC2701", # private method - "PLR2004", # magic value in test(should be removed) + "PLR2004", # magic value in test + "PLR6301", # method could be static "D100", "D103", "D104", "D400", + "D417", # return not documented (numpy style) + "DOC201", # return not documented (pydoclint) ] "examples/*.py" = [ "T201", # print diff --git a/tests/test_circuit.py b/tests/test_circuit.py index b5252153..29afeb86 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -301,10 +301,10 @@ class MockCircuit(BaseCircuit): def num_qubits(self) -> int: return 1 - def instructions(self) -> list[Gate]: # noqa: PLR6301 + def instructions(self) -> list[Gate]: return [X(qubit=0)] - def unit_instructions(self) -> list[UnitGate]: # noqa: PLR6301 + def unit_instructions(self) -> list[UnitGate]: # Return a non-UnitGate object to trigger error return [X(qubit=0)] # type: ignore[list-item] diff --git a/tests/test_noise_model.py b/tests/test_noise_model.py new file mode 100644 index 00000000..0cd6d588 --- /dev/null +++ b/tests/test_noise_model.py @@ -0,0 +1,736 @@ +"""Tests for noise_model module.""" + +from __future__ import annotations + +import pytest + +from graphqomb.common import Axis +from graphqomb.noise_model import ( + PAULI_CHANNEL_2_ORDER, + Coordinate, + DepolarizingNoiseModel, + EntangleEvent, + HeraldedErase, + HeraldedPauliChannel1, + IdleEvent, + MeasureEvent, + MeasurementFlip, + MeasurementFlipNoiseModel, + NodeInfo, + NoiseModel, + NoisePlacement, + PauliChannel1, + PauliChannel2, + PrepareEvent, + RawStimOp, + default_noise_placement, + depolarize1_probs, + depolarize2_probs, + noise_op_to_stim, +) + +# ---- Coordinate Tests ---- + + +class TestCoordinate: + """Tests for Coordinate dataclass.""" + + def test_xy_with_2d(self) -> None: + """Test xy property with 2D coordinates.""" + coord = Coordinate((1.0, 2.0)) + assert coord.xy == (1.0, 2.0) + + def test_xy_with_3d(self) -> None: + """Test xy property with 3D coordinates.""" + coord = Coordinate((1.0, 2.0, 3.0)) + assert coord.xy == (1.0, 2.0) + + def test_xy_with_1d(self) -> None: + """Test xy property with 1D coordinates returns None.""" + coord = Coordinate((1.0,)) + assert coord.xy is None + + def test_xyz_with_3d(self) -> None: + """Test xyz property with 3D coordinates.""" + coord = Coordinate((1.0, 2.0, 3.0)) + assert coord.xyz == (1.0, 2.0, 3.0) + + def test_xyz_with_2d(self) -> None: + """Test xyz property with 2D coordinates returns None.""" + coord = Coordinate((1.0, 2.0)) + assert coord.xyz is None + + def test_xyz_with_4d(self) -> None: + """Test xyz property with 4D coordinates.""" + coord = Coordinate((1.0, 2.0, 3.0, 4.0)) + assert coord.xyz == (1.0, 2.0, 3.0) + + +# ---- NodeInfo Tests ---- + + +class TestNodeInfo: + """Tests for NodeInfo dataclass.""" + + def test_with_coordinate(self) -> None: + """Test NodeInfo with coordinate.""" + coord = Coordinate((1.0, 2.0)) + info = NodeInfo(id=5, coord=coord) + assert info.id == 5 + assert info.coord is not None + assert info.coord.xy == (1.0, 2.0) + + def test_without_coordinate(self) -> None: + """Test NodeInfo without coordinate.""" + info = NodeInfo(id=3, coord=None) + assert info.id == 3 + assert info.coord is None + + +# ---- Event Tests ---- + + +class TestPrepareEvent: + """Tests for PrepareEvent dataclass.""" + + def test_basic(self) -> None: + """Test basic PrepareEvent creation.""" + node = NodeInfo(id=0, coord=None) + event = PrepareEvent(time=0, node=node, is_input=True) + assert event.time == 0 + assert event.node.id == 0 + assert event.is_input is True + + +class TestEntangleEvent: + """Tests for EntangleEvent dataclass.""" + + def test_basic(self) -> None: + """Test basic EntangleEvent creation.""" + node0 = NodeInfo(id=0, coord=None) + node1 = NodeInfo(id=1, coord=None) + event = EntangleEvent(time=1, node0=node0, node1=node1, edge=(0, 1)) + assert event.time == 1 + assert event.node0.id == 0 + assert event.node1.id == 1 + assert event.edge == (0, 1) + + +class TestMeasureEvent: + """Tests for MeasureEvent dataclass.""" + + def test_basic(self) -> None: + """Test basic MeasureEvent creation.""" + node = NodeInfo(id=2, coord=None) + event = MeasureEvent(time=2, node=node, axis=Axis.X) + assert event.time == 2 + assert event.node.id == 2 + assert event.axis == Axis.X + + +class TestIdleEvent: + """Tests for IdleEvent dataclass.""" + + def test_basic(self) -> None: + """Test basic IdleEvent creation.""" + nodes = [NodeInfo(id=i, coord=None) for i in range(3)] + event = IdleEvent(time=1, nodes=nodes, duration=1.0) + assert event.time == 1 + assert len(event.nodes) == 3 + assert event.duration == 1.0 + + +# ---- NoiseOp Tests ---- + + +class TestPauliChannel1: + """Tests for PauliChannel1 noise operation.""" + + def test_basic(self) -> None: + """Test basic PauliChannel1 creation and conversion.""" + op = PauliChannel1(px=0.01, py=0.02, pz=0.03, targets=[0, 1]) + text, delta = noise_op_to_stim(op) + assert text == "PAULI_CHANNEL_1(0.01,0.02,0.03) 0 1" + assert delta == 0 + + def test_single_target(self) -> None: + """Test PauliChannel1 with single target.""" + op = PauliChannel1(px=0.1, py=0.0, pz=0.0, targets=[5]) + text, delta = noise_op_to_stim(op) + assert text == "PAULI_CHANNEL_1(0.1,0.0,0.0) 5" + assert delta == 0 + + def test_empty_targets(self) -> None: + """Test PauliChannel1 with empty targets returns empty string.""" + op = PauliChannel1(px=0.01, py=0.01, pz=0.01, targets=[]) + text, delta = noise_op_to_stim(op) + assert not text + assert delta == 0 + + def test_placement_before(self) -> None: + """Test PauliChannel1 with BEFORE placement.""" + op = PauliChannel1(px=0.01, py=0.0, pz=0.0, targets=[0], placement=NoisePlacement.BEFORE) + assert op.placement == NoisePlacement.BEFORE + + def test_targets_converted_to_tuple(self) -> None: + """Test that targets list is converted to tuple.""" + op = PauliChannel1(px=0.01, py=0.0, pz=0.0, targets=[0, 1, 2]) + assert isinstance(op.targets, tuple) + assert op.targets == (0, 1, 2) + + def test_negative_probability_raises(self) -> None: + """Test PauliChannel1 with negative probability raises ValueError.""" + op = PauliChannel1(px=-0.01, py=0.0, pz=0.0, targets=[0]) + with pytest.raises(ValueError, match=r"PauliChannel1\.px must be within \[0, 1\]"): + noise_op_to_stim(op) + + def test_probability_sum_greater_than_one_raises(self) -> None: + """Test PauliChannel1 with probabilities summing to >1 raises ValueError.""" + op = PauliChannel1(px=0.5, py=0.4, pz=0.2, targets=[0]) + with pytest.raises(ValueError, match=r"PauliChannel1 probabilities must sum to <= 1"): + noise_op_to_stim(op) + + +class TestPauliChannel2: + """Tests for PauliChannel2 noise operation.""" + + def test_with_mapping(self) -> None: + """Test PauliChannel2 with mapping probabilities.""" + op = PauliChannel2(probabilities={"ZZ": 0.01, "XX": 0.005}, targets=[(0, 1)]) + text, delta = noise_op_to_stim(op) + assert "PAULI_CHANNEL_2" in text + assert "0 1" in text + assert delta == 0 + + def test_with_sequence(self) -> None: + """Test PauliChannel2 with sequence probabilities.""" + probs = [0.0] * 15 + probs[14] = 0.01 # ZZ + op = PauliChannel2(probabilities=probs, targets=[(2, 3)]) + text, delta = noise_op_to_stim(op) + assert "PAULI_CHANNEL_2" in text + assert "2 3" in text + assert delta == 0 + + def test_multiple_pairs(self) -> None: + """Test PauliChannel2 with multiple target pairs.""" + op = PauliChannel2(probabilities={"ZZ": 0.01}, targets=[(0, 1), (2, 3)]) + text, delta = noise_op_to_stim(op) + assert "0 1 2 3" in text + assert delta == 0 + + def test_empty_targets(self) -> None: + """Test PauliChannel2 with empty targets returns empty string.""" + op = PauliChannel2(probabilities={"ZZ": 0.01}, targets=[]) + text, delta = noise_op_to_stim(op) + assert not text + assert delta == 0 + + def test_unknown_key_raises(self) -> None: + """Test PauliChannel2 with unknown key raises ValueError.""" + op = PauliChannel2(probabilities={"ZZZ": 0.01}, targets=[(0, 1)]) + with pytest.raises(ValueError, match="Unknown PAULI_CHANNEL_2 keys"): + noise_op_to_stim(op) + + def test_wrong_sequence_length_raises(self) -> None: + """Test PauliChannel2 with wrong sequence length raises ValueError.""" + op = PauliChannel2(probabilities=[0.01] * 10, targets=[(0, 1)]) + with pytest.raises(ValueError, match="PAULI_CHANNEL_2 expects 15 probabilities"): + noise_op_to_stim(op) + + def test_targets_converted_to_tuple(self) -> None: + """Test that targets list is converted to tuple of tuples.""" + op = PauliChannel2(probabilities={"ZZ": 0.01}, targets=[(0, 1), (2, 3)]) + assert isinstance(op.targets, tuple) + assert all(isinstance(pair, tuple) for pair in op.targets) + + def test_pauli_channel_2_order_has_15_elements(self) -> None: + """Test that PAULI_CHANNEL_2_ORDER has exactly 15 elements.""" + assert len(PAULI_CHANNEL_2_ORDER) == 15 + + def test_probability_out_of_range_raises(self) -> None: + """Test PauliChannel2 with out-of-range probability raises ValueError.""" + op = PauliChannel2(probabilities={"ZZ": 1.1}, targets=[(0, 1)]) + with pytest.raises(ValueError, match=r"PauliChannel2\.probabilities\[ZZ\] must be within \[0, 1\]"): + noise_op_to_stim(op) + + def test_probability_sum_greater_than_one_raises(self) -> None: + """Test PauliChannel2 with probabilities summing to >1 raises ValueError.""" + op = PauliChannel2(probabilities={"ZZ": 0.8, "XX": 0.3}, targets=[(0, 1)]) + with pytest.raises(ValueError, match=r"PauliChannel2 probabilities must sum to <= 1"): + noise_op_to_stim(op) + + +class TestDepolarize1Probs: + """Tests for depolarize1_probs utility function.""" + + def test_returns_3_elements(self) -> None: + """Test that depolarize1_probs returns px, py, pz.""" + probs = depolarize1_probs(0.03) + assert len(probs) == 3 + assert set(probs.keys()) == {"px", "py", "pz"} + + def test_each_probability_is_p_over_3(self) -> None: + """Test that each probability is p/3.""" + p = 0.03 + probs = depolarize1_probs(p) + expected = p / 3 + for key, prob in probs.items(): + assert prob == expected, f"Expected {expected} for {key}, got {prob}" + + def test_can_be_used_with_pauli_channel_1(self) -> None: + """Test that depolarize1_probs works with PauliChannel1.""" + probs = depolarize1_probs(0.03) + op = PauliChannel1(px=probs["px"], py=probs["py"], pz=probs["pz"], targets=[0]) + text, delta = noise_op_to_stim(op) + assert "PAULI_CHANNEL_1" in text + assert delta == 0 + + def test_invalid_probability_raises(self) -> None: + """Test that depolarize1_probs validates probability range.""" + with pytest.raises(ValueError, match=r"depolarize1_probs\.p must be within \[0, 1\]"): + depolarize1_probs(-0.1) + + +class TestDepolarize2Probs: + """Tests for depolarize2_probs utility function.""" + + def test_returns_15_elements(self) -> None: + """Test that depolarize2_probs returns 15 Pauli pairs.""" + probs = depolarize2_probs(0.15) + assert len(probs) == 15 + + def test_each_probability_is_p_over_15(self) -> None: + """Test that each probability is p/15.""" + p = 0.15 + probs = depolarize2_probs(p) + expected = p / 15 + for pauli, prob in probs.items(): + assert prob == expected, f"Expected {expected} for {pauli}, got {prob}" + + def test_contains_all_pauli_pairs(self) -> None: + """Test that all 15 Pauli pairs are present.""" + probs = depolarize2_probs(0.1) + for pauli in PAULI_CHANNEL_2_ORDER: + assert pauli in probs + + def test_can_be_used_with_pauli_channel_2(self) -> None: + """Test that depolarize2_probs works with PauliChannel2.""" + probs = depolarize2_probs(0.15) + op = PauliChannel2(probabilities=probs, targets=[(0, 1)]) + text, delta = noise_op_to_stim(op) + assert "PAULI_CHANNEL_2" in text + assert delta == 0 + + def test_invalid_probability_raises(self) -> None: + """Test that depolarize2_probs validates probability range.""" + with pytest.raises(ValueError, match=r"depolarize2_probs\.p must be within \[0, 1\]"): + depolarize2_probs(1.2) + + +class TestHeraldedPauliChannel1: + """Tests for HeraldedPauliChannel1 noise operation.""" + + def test_basic(self) -> None: + """Test basic HeraldedPauliChannel1 creation and conversion.""" + op = HeraldedPauliChannel1(pi=0.0, px=0.01, py=0.0, pz=0.0, targets=[5]) + text, delta = noise_op_to_stim(op) + assert text == "HERALDED_PAULI_CHANNEL_1(0.0,0.01,0.0,0.0) 5" + assert delta == 1 + + def test_multiple_targets(self) -> None: + """Test HeraldedPauliChannel1 with multiple targets.""" + op = HeraldedPauliChannel1(pi=0.1, px=0.0, py=0.0, pz=0.0, targets=[0, 1, 2]) + text, delta = noise_op_to_stim(op) + assert text == "HERALDED_PAULI_CHANNEL_1(0.1,0.0,0.0,0.0) 0 1 2" + assert delta == 3 + + def test_empty_targets(self) -> None: + """Test HeraldedPauliChannel1 with empty targets returns empty string.""" + op = HeraldedPauliChannel1(pi=0.0, px=0.01, py=0.0, pz=0.0, targets=[]) + text, delta = noise_op_to_stim(op) + assert not text + assert delta == 0 + + def test_targets_converted_to_tuple(self) -> None: + """Test that targets list is converted to tuple.""" + op = HeraldedPauliChannel1(pi=0.0, px=0.01, py=0.0, pz=0.0, targets=[0, 1]) + assert isinstance(op.targets, tuple) + + def test_probability_sum_greater_than_one_raises(self) -> None: + """Test HeraldedPauliChannel1 with probabilities summing to >1 raises ValueError.""" + op = HeraldedPauliChannel1(pi=0.4, px=0.4, py=0.3, pz=0.0, targets=[0]) + with pytest.raises(ValueError, match=r"HeraldedPauliChannel1 probabilities must sum to <= 1"): + noise_op_to_stim(op) + + +class TestHeraldedErase: + """Tests for HeraldedErase noise operation.""" + + def test_basic(self) -> None: + """Test basic HeraldedErase creation and conversion.""" + op = HeraldedErase(p=0.05, targets=[0]) + text, delta = noise_op_to_stim(op) + assert text == "HERALDED_ERASE(0.05) 0" + assert delta == 1 + + def test_multiple_targets(self) -> None: + """Test HeraldedErase with multiple targets.""" + op = HeraldedErase(p=0.1, targets=[0, 1, 2]) + text, delta = noise_op_to_stim(op) + assert text == "HERALDED_ERASE(0.1) 0 1 2" + assert delta == 3 + + def test_empty_targets(self) -> None: + """Test HeraldedErase with empty targets returns empty string.""" + op = HeraldedErase(p=0.05, targets=[]) + text, delta = noise_op_to_stim(op) + assert not text + assert delta == 0 + + def test_targets_converted_to_tuple(self) -> None: + """Test that targets list is converted to tuple.""" + op = HeraldedErase(p=0.05, targets=[0, 1, 2]) + assert isinstance(op.targets, tuple) + + def test_probability_out_of_range_raises(self) -> None: + """Test HeraldedErase with out-of-range probability raises ValueError.""" + op = HeraldedErase(p=1.1, targets=[0]) + with pytest.raises(ValueError, match=r"HeraldedErase\.p must be within \[0, 1\]"): + noise_op_to_stim(op) + + +class TestRawStimOp: + """Tests for RawStimOp noise operation.""" + + def test_basic(self) -> None: + """Test basic RawStimOp creation and conversion.""" + op = RawStimOp(text="X_ERROR(0.001) 0 1 2") + text, delta = noise_op_to_stim(op) + assert text == "X_ERROR(0.001) 0 1 2" + assert delta == 0 + + def test_with_record_delta(self) -> None: + """Test RawStimOp with custom record delta.""" + op = RawStimOp(text="MR 5", record_delta=1) + text, delta = noise_op_to_stim(op) + assert text == "MR 5" + assert delta == 1 + + def test_empty_text(self) -> None: + """Test RawStimOp with empty text.""" + op = RawStimOp(text="") + text, delta = noise_op_to_stim(op) + assert not text + assert delta == 0 + + def test_placement_before(self) -> None: + """Test RawStimOp with BEFORE placement.""" + op = RawStimOp(text="Z_ERROR(0.01) 0", placement=NoisePlacement.BEFORE) + assert op.placement == NoisePlacement.BEFORE + + def test_multiline_text_raises(self) -> None: + """Test RawStimOp rejects multiline text.""" + with pytest.raises(ValueError, match="single Stim instruction line"): + RawStimOp(text="M 0\nM 1", record_delta=0) + + def test_negative_record_delta_raises(self) -> None: + """Test RawStimOp rejects negative record_delta.""" + with pytest.raises(ValueError, match="must be non-negative"): + RawStimOp(text="X_ERROR(0.01) 0", record_delta=-1) + + def test_record_delta_mismatch_raises(self) -> None: + """Test RawStimOp validates record_delta consistency for inferable instructions.""" + with pytest.raises(ValueError, match="record_delta mismatch"): + RawStimOp(text="MR 5", record_delta=0) + + +# ---- NoiseModel Tests ---- + + +class TestNoiseModel: + """Tests for NoiseModel base class.""" + + def test_default_on_prepare_returns_empty(self) -> None: + """Test that default on_prepare returns empty list.""" + model = NoiseModel() + node = NodeInfo(id=0, coord=None) + event = PrepareEvent(time=0, node=node, is_input=False) + result = list(model.on_prepare(event)) + assert result == [] + + def test_default_on_entangle_returns_empty(self) -> None: + """Test that default on_entangle returns empty list.""" + model = NoiseModel() + node0 = NodeInfo(id=0, coord=None) + node1 = NodeInfo(id=1, coord=None) + event = EntangleEvent(time=0, node0=node0, node1=node1, edge=(0, 1)) + result = list(model.on_entangle(event)) + assert result == [] + + def test_default_on_measure_returns_empty(self) -> None: + """Test that default on_measure returns empty list.""" + model = NoiseModel() + node = NodeInfo(id=0, coord=None) + event = MeasureEvent(time=0, node=node, axis=Axis.X) + result = list(model.on_measure(event)) + assert result == [] + + def test_default_on_idle_returns_empty(self) -> None: + """Test that default on_idle returns empty list.""" + model = NoiseModel() + nodes = [NodeInfo(id=i, coord=None) for i in range(2)] + event = IdleEvent(time=0, nodes=nodes, duration=1.0) + result = list(model.on_idle(event)) + assert result == [] + + def test_default_noise_placement_for_measure_is_before(self) -> None: + """Test that default_noise_placement returns BEFORE for MeasureEvent.""" + node = NodeInfo(id=0, coord=None) + event = MeasureEvent(time=0, node=node, axis=Axis.X) + assert default_noise_placement(event) == NoisePlacement.BEFORE + + def test_default_noise_placement_for_prepare_is_after(self) -> None: + """Test that default_noise_placement returns AFTER for PrepareEvent.""" + node = NodeInfo(id=0, coord=None) + event = PrepareEvent(time=0, node=node, is_input=False) + assert default_noise_placement(event) == NoisePlacement.AFTER + + def test_default_noise_placement_for_entangle_is_after(self) -> None: + """Test that default_noise_placement returns AFTER for EntangleEvent.""" + node0 = NodeInfo(id=0, coord=None) + node1 = NodeInfo(id=1, coord=None) + event = EntangleEvent(time=0, node0=node0, node1=node1, edge=(0, 1)) + assert default_noise_placement(event) == NoisePlacement.AFTER + + def test_default_noise_placement_for_idle_is_after(self) -> None: + """Test that default_noise_placement returns AFTER for IdleEvent.""" + nodes = [NodeInfo(id=i, coord=None) for i in range(2)] + event = IdleEvent(time=0, nodes=nodes, duration=1.0) + assert default_noise_placement(event) == NoisePlacement.AFTER + + +class TestNoisePlacementAuto: + """Tests for AUTO placement behavior.""" + + def test_auto_is_default_for_pauli_channel_1(self) -> None: + """Test that AUTO is the default placement for PauliChannel1.""" + op = PauliChannel1(px=0.01, py=0.0, pz=0.0, targets=[0]) + assert op.placement == NoisePlacement.AUTO + + def test_auto_is_default_for_pauli_channel_2(self) -> None: + """Test that AUTO is the default placement for PauliChannel2.""" + op = PauliChannel2(probabilities={"ZZ": 0.01}, targets=[(0, 1)]) + assert op.placement == NoisePlacement.AUTO + + def test_auto_is_default_for_heralded_pauli_channel_1(self) -> None: + """Test that AUTO is the default placement for HeraldedPauliChannel1.""" + op = HeraldedPauliChannel1(pi=0.0, px=0.01, py=0.0, pz=0.0, targets=[0]) + assert op.placement == NoisePlacement.AUTO + + def test_auto_is_default_for_heralded_erase(self) -> None: + """Test that AUTO is the default placement for HeraldedErase.""" + op = HeraldedErase(p=0.01, targets=[0]) + assert op.placement == NoisePlacement.AUTO + + def test_auto_is_default_for_raw_stim_op(self) -> None: + """Test that AUTO is the default placement for RawStimOp.""" + op = RawStimOp(text="X_ERROR(0.01) 0") + assert op.placement == NoisePlacement.AUTO + + +class _CustomNoiseModel(NoiseModel): + """Test noise model that adds noise on all events.""" + + def __init__(self, p: float) -> None: + self.p = p + + def on_prepare(self, event: PrepareEvent) -> list[PauliChannel1]: + return [PauliChannel1(px=self.p, py=0.0, pz=0.0, targets=[event.node.id])] + + def on_entangle(self, event: EntangleEvent) -> list[PauliChannel2]: + return [PauliChannel2(probabilities={"ZZ": self.p}, targets=[(event.node0.id, event.node1.id)])] + + def on_measure(self, event: MeasureEvent) -> list[HeraldedPauliChannel1]: + return [HeraldedPauliChannel1(pi=0.0, px=self.p, py=0.0, pz=0.0, targets=[event.node.id])] + + def on_idle(self, event: IdleEvent) -> list[PauliChannel1]: + p = self.p * event.duration + targets = [n.id for n in event.nodes] + return [PauliChannel1(px=p, py=p, pz=p, targets=targets)] + + +class TestCustomNoiseModel: + """Tests for custom NoiseModel subclass.""" + + def test_on_prepare(self) -> None: + """Test custom on_prepare implementation.""" + model = _CustomNoiseModel(p=0.01) + node = NodeInfo(id=5, coord=None) + event = PrepareEvent(time=0, node=node, is_input=False) + ops = list(model.on_prepare(event)) + assert len(ops) == 1 + assert isinstance(ops[0], PauliChannel1) + assert ops[0].px == 0.01 + assert 5 in ops[0].targets + + def test_on_entangle(self) -> None: + """Test custom on_entangle implementation.""" + model = _CustomNoiseModel(p=0.02) + node0 = NodeInfo(id=0, coord=None) + node1 = NodeInfo(id=1, coord=None) + event = EntangleEvent(time=1, node0=node0, node1=node1, edge=(0, 1)) + ops = list(model.on_entangle(event)) + assert len(ops) == 1 + assert isinstance(ops[0], PauliChannel2) + + def test_on_measure(self) -> None: + """Test custom on_measure implementation.""" + model = _CustomNoiseModel(p=0.03) + node = NodeInfo(id=2, coord=None) + event = MeasureEvent(time=2, node=node, axis=Axis.Z) + ops = list(model.on_measure(event)) + assert len(ops) == 1 + assert isinstance(ops[0], HeraldedPauliChannel1) + + def test_on_idle(self) -> None: + """Test custom on_idle implementation.""" + model = _CustomNoiseModel(p=0.001) + nodes = [NodeInfo(id=i, coord=None) for i in range(3)] + event = IdleEvent(time=1, nodes=nodes, duration=2.0) + ops = list(model.on_idle(event)) + assert len(ops) == 1 + assert isinstance(ops[0], PauliChannel1) + assert ops[0].px == 0.002 # p * duration + + +# ---- MeasurementFlip Tests ---- + + +class TestMeasurementFlip: + """Tests for MeasurementFlip noise operation.""" + + def test_basic(self) -> None: + """Test basic MeasurementFlip creation.""" + op = MeasurementFlip(p=0.01, target=5) + assert op.p == 0.01 + assert op.target == 5 + assert op.placement == NoisePlacement.AUTO + + def test_to_stim_returns_empty(self) -> None: + """Test that noise_op_to_stim returns empty string for MeasurementFlip. + + MeasurementFlip is handled specially by modifying the measurement + instruction itself, so it should not emit a separate instruction. + """ + op = MeasurementFlip(p=0.01, target=0) + text, delta = noise_op_to_stim(op) + assert not text + assert delta == 0 + + def test_invalid_probability_raises(self) -> None: + """Test MeasurementFlip validates probability range.""" + op = MeasurementFlip(p=-0.1, target=0) + with pytest.raises(ValueError, match=r"MeasurementFlip\.p must be within \[0, 1\]"): + noise_op_to_stim(op) + + +# ---- DepolarizingNoiseModel Tests ---- + + +class TestDepolarizingNoiseModel: + """Tests for DepolarizingNoiseModel built-in noise model.""" + + def test_on_prepare_emits_depolarize1(self) -> None: + """Test that on_prepare returns DEPOLARIZE1 instruction.""" + model = DepolarizingNoiseModel(p1=0.01) + node = NodeInfo(id=5, coord=None) + event = PrepareEvent(time=0, node=node, is_input=False) + ops = list(model.on_prepare(event)) + assert len(ops) == 1 + text, _ = noise_op_to_stim(ops[0]) + assert text == "DEPOLARIZE1(0.01) 5" + + def test_on_entangle_emits_depolarize2(self) -> None: + """Test that on_entangle returns DEPOLARIZE2 instruction.""" + model = DepolarizingNoiseModel(p1=0.01) + node0 = NodeInfo(id=0, coord=None) + node1 = NodeInfo(id=1, coord=None) + event = EntangleEvent(time=1, node0=node0, node1=node1, edge=(0, 1)) + ops = list(model.on_entangle(event)) + assert len(ops) == 1 + text, _ = noise_op_to_stim(ops[0]) + assert text == "DEPOLARIZE2(0.01) 0 1" + + def test_p2_defaults_to_p1(self) -> None: + """Test that p2 defaults to p1 when not specified.""" + model = DepolarizingNoiseModel(p1=0.02) + node0 = NodeInfo(id=2, coord=None) + node1 = NodeInfo(id=3, coord=None) + event = EntangleEvent(time=1, node0=node0, node1=node1, edge=(2, 3)) + ops = list(model.on_entangle(event)) + text, _ = noise_op_to_stim(ops[0]) + assert "DEPOLARIZE2(0.02)" in text + + def test_different_p1_and_p2(self) -> None: + """Test DepolarizingNoiseModel with different p1 and p2.""" + model = DepolarizingNoiseModel(p1=0.001, p2=0.01) + # Check prepare uses p1 + node = NodeInfo(id=0, coord=None) + prepare_event = PrepareEvent(time=0, node=node, is_input=False) + ops = list(model.on_prepare(prepare_event)) + text, _ = noise_op_to_stim(ops[0]) + assert "DEPOLARIZE1(0.001)" in text + + # Check entangle uses p2 + node0 = NodeInfo(id=0, coord=None) + node1 = NodeInfo(id=1, coord=None) + entangle_event = EntangleEvent(time=1, node0=node0, node1=node1, edge=(0, 1)) + ops = list(model.on_entangle(entangle_event)) + text, _ = noise_op_to_stim(ops[0]) + assert "DEPOLARIZE2(0.01)" in text + + def test_zero_probability_returns_empty(self) -> None: + """Test that zero probability returns empty sequence.""" + model = DepolarizingNoiseModel(p1=0.0) + node = NodeInfo(id=0, coord=None) + event = PrepareEvent(time=0, node=node, is_input=False) + ops = list(model.on_prepare(event)) + assert len(ops) == 0 + + +# ---- MeasurementFlipNoiseModel Tests ---- + + +class TestMeasurementFlipNoiseModel: + """Tests for MeasurementFlipNoiseModel built-in noise model.""" + + def test_on_measure_returns_measurement_flip(self) -> None: + """Test that on_measure returns MeasurementFlip operation.""" + model = MeasurementFlipNoiseModel(p=0.01) + node = NodeInfo(id=5, coord=None) + event = MeasureEvent(time=0, node=node, axis=Axis.X) + ops = list(model.on_measure(event)) + assert len(ops) == 1 + assert isinstance(ops[0], MeasurementFlip) + assert ops[0].p == 0.01 + assert ops[0].target == 5 + + def test_zero_probability_returns_empty(self) -> None: + """Test that zero probability returns empty sequence.""" + model = MeasurementFlipNoiseModel(p=0.0) + node = NodeInfo(id=0, coord=None) + event = MeasureEvent(time=0, node=node, axis=Axis.Z) + ops = list(model.on_measure(event)) + assert len(ops) == 0 + + def test_different_axes(self) -> None: + """Test MeasurementFlipNoiseModel works with all measurement axes.""" + model = MeasurementFlipNoiseModel(p=0.005) + for axis in [Axis.X, Axis.Y, Axis.Z]: + node = NodeInfo(id=0, coord=None) + event = MeasureEvent(time=0, node=node, axis=axis) + ops = list(model.on_measure(event)) + assert len(ops) == 1 + assert isinstance(ops[0], MeasurementFlip) + assert ops[0].p == 0.005 diff --git a/tests/test_stim_compiler.py b/tests/test_stim_compiler.py index 7c3d783f..74bd5f8e 100644 --- a/tests/test_stim_compiler.py +++ b/tests/test_stim_compiler.py @@ -10,6 +10,14 @@ from graphqomb.command import TICK, E from graphqomb.common import Axis, AxisMeasBasis, Plane, PlannerMeasBasis, Sign from graphqomb.graphstate import GraphState +from graphqomb.noise_model import ( + DepolarizingNoiseModel, + HeraldedPauliChannel1, + MeasureEvent, + MeasurementFlip, + MeasurementFlipNoiseModel, + NoiseModel, +) from graphqomb.qompiler import qompile from graphqomb.schedule_solver import ScheduleConfig, Strategy from graphqomb.scheduler import Scheduler @@ -42,6 +50,7 @@ def create_simple_pattern_x_measurement() -> tuple[Pattern, int, int]: # X measurement: XY plane with angle 0 graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) graph.assign_meas_basis(meas_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) xflow = {in_node: {meas_node}, meas_node: {out_node}} pattern = qompile(graph, xflow) @@ -72,6 +81,7 @@ def create_simple_pattern_y_measurement() -> tuple[Pattern, int, int]: # Y measurement: XY plane with angle π/2 graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, math.pi / 2)) graph.assign_meas_basis(meas_node, PlannerMeasBasis(Plane.XY, math.pi / 2)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, math.pi / 2)) xflow = {in_node: {meas_node}, meas_node: {out_node}} pattern = qompile(graph, xflow) @@ -102,6 +112,7 @@ def create_simple_pattern_z_measurement() -> tuple[Pattern, int, int]: # Z measurement: XZ plane with angle 0 graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XZ, 0.0)) graph.assign_meas_basis(meas_node, PlannerMeasBasis(Plane.XZ, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XZ, 0.0)) xflow = {in_node: {meas_node}, meas_node: {out_node}} pattern = qompile(graph, xflow) @@ -156,57 +167,105 @@ def test_stim_compile_z_measurement() -> None: def test_stim_compile_with_depolarization() -> None: - """Test that depolarization error is correctly inserted.""" + """Test that depolarization error is correctly inserted using DepolarizingNoiseModel.""" pattern, _, _ = create_simple_pattern_x_measurement() - stim_str = stim_compile(pattern, p_depol_after_clifford=0.01) + stim_str = stim_compile(pattern, noise_models=[DepolarizingNoiseModel(p1=0.01)]) # Check DEPOLARIZE instructions are present assert "DEPOLARIZE1(0.01)" in stim_str assert "DEPOLARIZE2(0.01)" in stim_str +def test_stim_compile_with_legacy_depolarization_parameter() -> None: + """Legacy depolarization parameter should still work with a deprecation warning.""" + pattern, _, _ = create_simple_pattern_x_measurement() + + with pytest.warns(DeprecationWarning, match="p_depol_after_clifford"): + stim_str = stim_compile(pattern, p_depol_after_clifford=0.01) + + assert "DEPOLARIZE1(0.01)" in stim_str + assert "DEPOLARIZE2(0.01)" in stim_str + + def test_stim_compile_with_measurement_errors_x() -> None: - """Test that X measurement errors are correctly inserted.""" + """Test that X measurement errors are correctly inserted using MeasurementFlipNoiseModel.""" + pattern, _, _ = create_simple_pattern_x_measurement() + + stim_str = stim_compile(pattern, noise_models=[MeasurementFlipNoiseModel(p=0.01)]) + + # For X measurement, error probability is attached to MX instruction + assert "MX(0.01)" in stim_str + + +def test_stim_compile_with_legacy_measurement_flip_parameter() -> None: + """Legacy measurement flip parameter should still work with a deprecation warning.""" pattern, _, _ = create_simple_pattern_x_measurement() - stim_str = stim_compile(pattern, p_before_meas_flip=0.01) + with pytest.warns(DeprecationWarning, match="p_before_meas_flip"): + stim_str = stim_compile(pattern, p_before_meas_flip=0.01) - # For X measurement, Z_ERROR should be inserted before MX - assert "Z_ERROR(0.01)" in stim_str - lines = stim_str.split("\n") - for i, line in enumerate(lines): - if "Z_ERROR(0.01)" in line and i + 1 < len(lines): - # Next non-empty line should be MX - next_line = lines[i + 1] - assert "MX" in next_line + assert "MX(0.01)" in stim_str def test_stim_compile_with_measurement_errors_y() -> None: - """Test that Y measurement errors are correctly inserted.""" + """Test that Y measurement errors are correctly inserted using MeasurementFlipNoiseModel.""" pattern, _, _ = create_simple_pattern_y_measurement() - stim_str = stim_compile(pattern, p_before_meas_flip=0.01) + stim_str = stim_compile(pattern, noise_models=[MeasurementFlipNoiseModel(p=0.01)]) - # For Y measurement, both X_ERROR and Z_ERROR should be inserted before MY - assert "X_ERROR(0.01)" in stim_str - assert "Z_ERROR(0.01)" in stim_str + # For Y measurement, error probability is attached to MY instruction + assert "MY(0.01)" in stim_str def test_stim_compile_with_measurement_errors_z() -> None: - """Test that Z measurement errors are correctly inserted.""" + """Test that Z measurement errors are correctly inserted using MeasurementFlipNoiseModel.""" pattern, _, _ = create_simple_pattern_z_measurement() - stim_str = stim_compile(pattern, p_before_meas_flip=0.01) + stim_str = stim_compile(pattern, noise_models=[MeasurementFlipNoiseModel(p=0.01)]) + + # For Z measurement, error probability is attached to MZ instruction + assert "MZ(0.01)" in stim_str + + +def test_stim_compile_combines_measurement_flip_probabilities() -> None: + """Multiple MeasurementFlip models should combine as independent events.""" + pattern, _, _ = create_simple_pattern_x_measurement() + + stim_str = stim_compile( + pattern, + noise_models=[ + MeasurementFlipNoiseModel(p=0.1), + MeasurementFlipNoiseModel(p=0.2), + ], + ) + + expected = 1.0 - (1.0 - 0.1) * (1.0 - 0.2) + mx_lines = [line for line in stim_str.splitlines() if line.startswith("MX(")] + assert mx_lines + for line in mx_lines: + prob = float(line.split("(", 1)[1].split(")", 1)[0]) + assert math.isclose(prob, expected) + - # For Z measurement, X_ERROR should be inserted before MZ - assert "X_ERROR(0.01)" in stim_str - lines = stim_str.split("\n") - for i, line in enumerate(lines): - if "X_ERROR(0.01)" in line and i + 1 < len(lines): - # Next non-empty line should be MZ - next_line = lines[i + 1] - assert "MZ" in next_line +def test_stim_compile_rejects_mixed_legacy_and_noise_models() -> None: + """Legacy noise parameters cannot be used together with noise_models.""" + pattern, _, _ = create_simple_pattern_x_measurement() + + with pytest.raises(ValueError, match="cannot be used together with noise_models"): + stim_compile( + pattern, + noise_models=[DepolarizingNoiseModel(p1=0.01)], + p_depol_after_clifford=0.01, + ) + + +def test_stim_compile_validates_legacy_probability_parameters() -> None: + """Legacy probability parameters should be validated before compilation.""" + pattern, _, _ = create_simple_pattern_x_measurement() + + with pytest.raises(ValueError, match="must be within \\[0, 1\\]"): + stim_compile(pattern, p_before_meas_flip=1.1) def test_stim_compile_with_detectors() -> None: @@ -225,6 +284,7 @@ def test_stim_compile_with_detectors() -> None: graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) graph.assign_meas_basis(meas_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) xflow = {in_node: {meas_node}, meas_node: {out_node}} # Add parity check groups @@ -239,6 +299,93 @@ def test_stim_compile_with_detectors() -> None: # This is valid behavior for certain graph configurations +class _HeraldedNoise(NoiseModel): + """Test noise model that adds heralded Pauli channel on measurements.""" + + def on_measure(self, event: MeasureEvent) -> list[HeraldedPauliChannel1]: + return [HeraldedPauliChannel1(0.0, 0.0, 0.0, 0.1, targets=[event.node.id])] + + +class _MismatchedMeasurementFlipNoise(NoiseModel): + """Test noise model with intentionally mismatched MeasurementFlip target.""" + + def on_measure(self, event: MeasureEvent) -> list[MeasurementFlip]: + return [MeasurementFlip(p=0.1, target=event.node.id + 999)] + + +def _parse_stim_measurements(stim_str: str) -> tuple[dict[int, int], int]: + """Parse stim string to extract measurement order and total record count.""" + rec_index = 0 + actual_meas_order: dict[int, int] = {} + for raw_line in stim_str.splitlines(): + stripped = raw_line.strip() + if not stripped: + continue + opcode = stripped.split()[0].split("(", 1)[0] + if opcode == "HERALDED_PAULI_CHANNEL_1": + targets = stripped.split(")", 1)[1].strip().split() + rec_index += len(targets) + elif opcode in {"MX", "MY", "MZ"}: + node = int(stripped.split()[1]) + actual_meas_order[node] = rec_index + rec_index += 1 + return actual_meas_order, rec_index + + +def _normalize_detector(line: str) -> str: + """Normalize detector line by sorting targets.""" + parts = line.strip().split() + if len(parts) <= 1: + return "DETECTOR" + targets = sorted(parts[1:]) + return f"DETECTOR {' '.join(targets)}" + + +def test_stim_compile_with_heralded_noise_updates_detectors() -> None: + """Heralded noise should shift rec indices used by detectors.""" + graph = GraphState() + in_node = graph.add_physical_node() + meas_node = graph.add_physical_node() + out_node = graph.add_physical_node() + + q_idx = 0 + graph.register_input(in_node, q_idx) + graph.register_output(out_node, q_idx) + + graph.add_physical_edge(in_node, meas_node) + graph.add_physical_edge(meas_node, out_node) + + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(meas_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) + + xflow = {in_node: {meas_node}, meas_node: {out_node}} + parity_check_group = [{in_node}] + pattern = qompile(graph, xflow, parity_check_group=parity_check_group) + + stim_str = stim_compile(pattern, noise_models=[_HeraldedNoise()]) + + actual_meas_order, total_measurements = _parse_stim_measurements(stim_str) + + check_groups = pattern.pauli_frame.detector_groups() + expected_detectors = { + _normalize_detector( + f"DETECTOR {' '.join(f'rec[{actual_meas_order[check] - total_measurements}]' for check in checks)}" + ) + for checks in check_groups + } + actual_detectors = {_normalize_detector(line) for line in stim_str.splitlines() if line.startswith("DETECTOR")} + assert expected_detectors == actual_detectors + + +def test_stim_compile_rejects_mismatched_measurement_flip_target() -> None: + """MeasurementFlip target must match the current measurement node.""" + pattern, _, _ = create_simple_pattern_x_measurement() + + with pytest.raises(ValueError, match="MeasurementFlip target mismatch"): + stim_compile(pattern, noise_models=[_MismatchedMeasurementFlipNoise()]) + + def test_stim_compile_with_logical_observables() -> None: """Test OBSERVABLE_INCLUDE generation.""" pattern, meas_node, _ = create_simple_pattern_x_measurement() @@ -271,6 +418,7 @@ def test_stim_compile_unsupported_basis() -> None: # Non-Pauli measurement: XY plane with arbitrary angle graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.1)) graph.assign_meas_basis(meas_node, PlannerMeasBasis(Plane.XY, 0.1)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.1)) xflow = {in_node: {meas_node}, meas_node: {out_node}} pattern = qompile(graph, xflow) @@ -280,6 +428,26 @@ def test_stim_compile_unsupported_basis() -> None: stim_compile(pattern) +def test_stim_compile_unsupported_output_corrections() -> None: + """Test that X/Z correction commands in a pattern raise NotImplementedError.""" + graph = GraphState() + in_node = graph.add_physical_node() + out_node = graph.add_physical_node() + + q_idx = 0 + graph.register_input(in_node, q_idx) + graph.register_output(out_node, q_idx) + + graph.add_physical_edge(in_node, out_node) + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + + xflow = {in_node: {out_node}} + pattern = qompile(graph, xflow) + + with pytest.raises(NotImplementedError, match="X/Z correction commands are not supported"): + stim_compile(pattern) + + def test_stim_compile_empty_pattern() -> None: """Test compilation of minimal pattern.""" graph = GraphState() @@ -292,6 +460,7 @@ def test_stim_compile_empty_pattern() -> None: graph.add_physical_edge(in_node, out_node) graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) xflow = {in_node: {out_node}} pattern = qompile(graph, xflow) @@ -320,6 +489,7 @@ def test_stim_compile_axis_meas_basis() -> None: # Use AxisMeasBasis instead of PlannerMeasBasis graph.assign_meas_basis(in_node, AxisMeasBasis(Axis.X, Sign.PLUS)) graph.assign_meas_basis(meas_node, AxisMeasBasis(Axis.Y, Sign.PLUS)) + graph.assign_meas_basis(out_node, AxisMeasBasis(Axis.X, Sign.PLUS)) xflow = {in_node: {meas_node}, meas_node: {out_node}} pattern = qompile(graph, xflow) @@ -346,6 +516,7 @@ def test_stim_compile_with_tick_commands() -> None: graph.assign_meas_basis(node0, PlannerMeasBasis(Plane.XY, 0.0)) graph.assign_meas_basis(node1, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(node2, PlannerMeasBasis(Plane.XY, 0.0)) flow = {node0: {node1}, node1: {node2}} scheduler = Scheduler(graph, flow) @@ -435,6 +606,7 @@ def test_stim_compile_respects_manual_entangle_time() -> None: graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) graph.assign_meas_basis(mid_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) scheduler = Scheduler(graph, {in_node: {mid_node}, mid_node: {out_node}}) @@ -479,6 +651,7 @@ def test_stim_compile_with_coordinates() -> None: graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) graph.assign_meas_basis(mid_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) pattern = qompile(graph, {in_node: {mid_node}, mid_node: {out_node}}) stim_str = stim_compile(pattern) @@ -500,6 +673,7 @@ def test_stim_compile_with_3d_coordinates() -> None: graph.add_physical_edge(in_node, out_node) graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) pattern = qompile(graph, {in_node: {out_node}}) stim_str = stim_compile(pattern) @@ -519,6 +693,7 @@ def test_stim_compile_without_coordinates() -> None: graph.add_physical_edge(in_node, out_node) graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) pattern = qompile(graph, {in_node: {out_node}}) stim_str = stim_compile(pattern, emit_qubit_coords=False) @@ -541,6 +716,7 @@ def test_pattern_coordinates_property() -> None: graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) graph.assign_meas_basis(mid_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) pattern = qompile(graph, {in_node: {mid_node}, mid_node: {out_node}})