From fc0601b7dcff1f408ac557ab37f5ffe3306a4c23 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Sun, 15 Feb 2026 10:24:46 +0100 Subject: [PATCH 1/2] Fix #324: use `match` for better readability and typing This commit replaces `if`/`elif`/`else` chains with `match` when the conditions perform equality comparisons on the same subject. These changes improve readibility and avoid `mypy` shortcomings related to type narrowing. Before this commit, PLR1714 suggested rewriting code such as: ```python if cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: ... ``` into: ```python if cmd.kind in {CommandKind.X, CommandKind.Z}: ... ``` However, `mypy` does not currently infer that `cmd` is an instance of `X` or `Z` within the `if` block when set membership is used. With this commit, the condition is expressed using an or-pattern, and type checking behaves as expected: ```python match cmd.kind: case CommandKind.X | CommandKind.Z: ... ``` --- graphix/clifford.py | 38 +++-- graphix/flow/core.py | 102 ++++++------ graphix/graphsim.py | 37 +++-- graphix/optimization.py | 191 ++++++++++----------- graphix/pattern.py | 266 +++++++++++++++--------------- graphix/pretty_print.py | 147 +++++++++-------- graphix/qasm3_exporter.py | 92 +++++------ graphix/sim/tensornet.py | 134 +++++++-------- graphix/simulator.py | 50 +++--- graphix/transpiler.py | 337 +++++++++++++++++++------------------- 10 files changed, 707 insertions(+), 687 deletions(-) diff --git a/graphix/clifford.py b/graphix/clifford.py index 9c500292..83bfa165 100644 --- a/graphix/clifford.py +++ b/graphix/clifford.py @@ -129,14 +129,15 @@ def measure(self, pauli: Pauli) -> Pauli: if pauli.symbol == I: return copy.deepcopy(pauli) table = CLIFFORD_MEASURE[self.value] - if pauli.symbol == Axis.X: - symbol, sign = table.x - elif pauli.symbol == Axis.Y: - symbol, sign = table.y - elif pauli.symbol == Axis.Z: - symbol, sign = table.z - else: - typing_extensions.assert_never(pauli.symbol) + match pauli.symbol: + case Axis.X: + symbol, sign = table.x + case Axis.Y: + symbol, sign = table.y + case Axis.Z: + symbol, sign = table.z + case _: + typing_extensions.assert_never(pauli.symbol) return pauli.unit * Pauli(symbol, ComplexUnit.from_properties(sign=sign)) def commute_domains(self, domains: Domains) -> Domains: @@ -151,16 +152,17 @@ def commute_domains(self, domains: Domains) -> Domains: s_domain = domains.s_domain.copy() t_domain = domains.t_domain.copy() for gate in self.hsz: - if gate == Clifford.I: - pass - elif gate == Clifford.H: - t_domain, s_domain = s_domain, t_domain - elif gate == Clifford.S: - t_domain ^= s_domain - elif gate == Clifford.Z: - pass - else: # pragma: no cover - raise RuntimeError(f"{gate} should be either I, H, S or Z.") + match gate: + case Clifford.I: + pass + case Clifford.H: + t_domain, s_domain = s_domain, t_domain + case Clifford.S: + t_domain ^= s_domain + case Clifford.Z: + pass + case _: # pragma: no cover + raise RuntimeError(f"{gate} should be either I, H, S or Z.") return Domains(s_domain, t_domain) diff --git a/graphix/flow/core.py b/graphix/flow/core.py index f4b1f632..ff650e09 100644 --- a/graphix/flow/core.py +++ b/graphix/flow/core.py @@ -600,39 +600,39 @@ def check_well_formed(self) -> None: correction_set=correction_set, past_and_present_nodes=past_and_present_nodes_y_meas, ) - - if meas == Plane.XY: - if not (node not in correction_set and node in odd_neighbors): - raise FlowPropositionError( - FlowPropositionErrorReason.P4, node=node, correction_set=correction_set - ) - elif meas == Plane.XZ: - if not (node in correction_set and node in odd_neighbors): - raise FlowPropositionError( - FlowPropositionErrorReason.P5, node=node, correction_set=correction_set - ) - elif meas == Plane.YZ: - if not (node in correction_set and node not in odd_neighbors): - raise FlowPropositionError( - FlowPropositionErrorReason.P6, node=node, correction_set=correction_set - ) - elif meas == Axis.X: - if node not in odd_neighbors: - raise FlowPropositionError( - FlowPropositionErrorReason.P7, node=node, correction_set=correction_set - ) - elif meas == Axis.Z: - if node not in correction_set: - raise FlowPropositionError( - FlowPropositionErrorReason.P8, node=node, correction_set=correction_set - ) - elif meas == Axis.Y: - if node not in closed_odd_neighbors: - raise FlowPropositionError( - FlowPropositionErrorReason.P9, node=node, correction_set=correction_set - ) - else: - assert_never(meas) + match meas: + case Plane.XY: + if not (node not in correction_set and node in odd_neighbors): + raise FlowPropositionError( + FlowPropositionErrorReason.P4, node=node, correction_set=correction_set + ) + case Plane.XZ: + if not (node in correction_set and node in odd_neighbors): + raise FlowPropositionError( + FlowPropositionErrorReason.P5, node=node, correction_set=correction_set + ) + case Plane.YZ: + if not (node in correction_set and node not in odd_neighbors): + raise FlowPropositionError( + FlowPropositionErrorReason.P6, node=node, correction_set=correction_set + ) + case Axis.X: + if node not in odd_neighbors: + raise FlowPropositionError( + FlowPropositionErrorReason.P7, node=node, correction_set=correction_set + ) + case Axis.Z: + if node not in correction_set: + raise FlowPropositionError( + FlowPropositionErrorReason.P8, node=node, correction_set=correction_set + ) + case Axis.Y: + if node not in closed_odd_neighbors: + raise FlowPropositionError( + FlowPropositionErrorReason.P9, node=node, correction_set=correction_set + ) + case _: + assert_never(meas) layer_idx -= 1 @@ -870,24 +870,24 @@ def check_well_formed(self) -> None: ) plane = self.node_measurement_label(node) - - if plane == Plane.XY: - if not (node not in correction_set and node in odd_neighbors): - raise FlowPropositionError( - FlowPropositionErrorReason.G3, node=node, correction_set=correction_set - ) - elif plane == Plane.XZ: - if not (node in correction_set and node in odd_neighbors): - raise FlowPropositionError( - FlowPropositionErrorReason.G4, node=node, correction_set=correction_set - ) - elif plane == Plane.YZ: - if not (node in correction_set and node not in odd_neighbors): - raise FlowPropositionError( - FlowPropositionErrorReason.G5, node=node, correction_set=correction_set - ) - else: - assert_never(plane) + match plane: + case Plane.XY: + if not (node not in correction_set and node in odd_neighbors): + raise FlowPropositionError( + FlowPropositionErrorReason.G3, node=node, correction_set=correction_set + ) + case Plane.XZ: + if not (node in correction_set and node in odd_neighbors): + raise FlowPropositionError( + FlowPropositionErrorReason.G4, node=node, correction_set=correction_set + ) + case Plane.YZ: + if not (node in correction_set and node not in odd_neighbors): + raise FlowPropositionError( + FlowPropositionErrorReason.G5, node=node, correction_set=correction_set + ) + case _: + assert_never(plane) layer_idx -= 1 diff --git a/graphix/graphsim.py b/graphix/graphsim.py index 186392c7..6bfcf1ed 100644 --- a/graphix/graphsim.py +++ b/graphix/graphsim.py @@ -89,16 +89,16 @@ def add_nodes_from( # pyright: ignore[reportIncompatibleMethodOverride] for k, v_ in mp.items(): dst = self.nodes[u] v = bool(v_) - # Need to use literal inside brackets - if k == "sign": - dst["sign"] = v - elif k == "hollow": - dst["hollow"] = v - elif k == "loop": - dst["loop"] = v - else: - msg = "Invalid node attribute." - raise ValueError(msg) + match k: + case "sign": + dst["sign"] = v + case "hollow": + dst["hollow"] = v + case "loop": + dst["loop"] = v + case _: + msg = "Invalid node attribute." + raise ValueError(msg) @typing_extensions.override def add_node( @@ -130,14 +130,15 @@ def apply_vops(self, vops: Mapping[int, Clifford]) -> None: """ for node, vop in vops.items(): for lc in reversed(vop.hsz): - if lc == Clifford.Z: - self.z(node) - elif lc == Clifford.H: - self.h(node) - elif lc == Clifford.S: - self.s(node) - else: - raise RuntimeError + match lc: + case Clifford.Z: + self.z(node) + case Clifford.H: + self.h(node) + case Clifford.S: + self.s(node) + case _: + raise RuntimeError def extract_vops(self) -> dict[int, Clifford]: """Apply local Clifford operators to the graph state from a dictionary. diff --git a/graphix/optimization.py b/graphix/optimization.py index b4dbf667..241000a4 100644 --- a/graphix/optimization.py +++ b/graphix/optimization.py @@ -178,53 +178,53 @@ def from_pattern(cls, pattern: Pattern) -> Self: pattern.check_runnability() for cmd in pattern: - if cmd.kind == CommandKind.N: - n_list.append(cmd) - elif cmd.kind == CommandKind.E: - for side in (0, 1): - i, j = cmd.nodes[side], cmd.nodes[1 - side] - if clifford_gate := c_dict.get(i): - _commute_clifford(clifford_gate, c_dict, i, j) - if s_domain_opt := x_dict.get(i): - _add_correction_domain(z_dict, j, s_domain_opt) - edge = frozenset(cmd.nodes) - e_set.symmetric_difference_update((edge,)) - elif cmd.kind == CommandKind.M: - new_cmd = None - if clifford_gate := c_dict.pop(cmd.node, None): - new_cmd = cmd.clifford(clifford_gate) - if t_domain_opt := z_dict.pop(cmd.node, None): + match cmd.kind: + case CommandKind.N: + n_list.append(cmd) + case CommandKind.E: + for side in (0, 1): + i, j = cmd.nodes[side], cmd.nodes[1 - side] + if clifford_gate := c_dict.get(i): + _commute_clifford(clifford_gate, c_dict, i, j) + if s_domain_opt := x_dict.get(i): + _add_correction_domain(z_dict, j, s_domain_opt) + edge = frozenset(cmd.nodes) + e_set.symmetric_difference_update((edge,)) + case CommandKind.M: + new_cmd = None + if clifford_gate := c_dict.pop(cmd.node, None): + new_cmd = cmd.clifford(clifford_gate) + if t_domain_opt := z_dict.pop(cmd.node, None): + if new_cmd is None: + new_cmd = copy(cmd) + # The original domain should not be mutated + new_cmd.t_domain = new_cmd.t_domain ^ t_domain_opt # noqa: PLR6104 + if s_domain_opt := x_dict.pop(cmd.node, None): + if new_cmd is None: + new_cmd = copy(cmd) + # The original domain should not be mutated + new_cmd.s_domain = new_cmd.s_domain ^ s_domain_opt # noqa: PLR6104 if new_cmd is None: - new_cmd = copy(cmd) - # The original domain should not be mutated - new_cmd.t_domain = new_cmd.t_domain ^ t_domain_opt # noqa: PLR6104 - if s_domain_opt := x_dict.pop(cmd.node, None): - if new_cmd is None: - new_cmd = copy(cmd) - # The original domain should not be mutated - new_cmd.s_domain = new_cmd.s_domain ^ s_domain_opt # noqa: PLR6104 - if new_cmd is None: - m_list.append(cmd) - else: - m_list.append(new_cmd) - # Use of `==` here for mypy - elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 - if cmd.kind == CommandKind.X: - s_domain = cmd.domain - t_domain = set() - else: - s_domain = set() - t_domain = cmd.domain - domains = c_dict.get(cmd.node, Clifford.I).commute_domains(Domains(s_domain, t_domain)) - if domains.t_domain: - _add_correction_domain(z_dict, cmd.node, domains.t_domain) - if domains.s_domain: - _add_correction_domain(x_dict, cmd.node, domains.s_domain) - elif cmd.kind == CommandKind.C: - # Each pattern command is applied by left multiplication: if a clifford `C` - # has been already applied to a node, applying a clifford `C'` to the same - # node is equivalent to apply `C'C` to a fresh node. - c_dict[cmd.node] = cmd.clifford @ c_dict.get(cmd.node, Clifford.I) + m_list.append(cmd) + else: + m_list.append(new_cmd) + case CommandKind.X | CommandKind.Z: + if cmd.kind == CommandKind.X: + s_domain = cmd.domain + t_domain = set() + else: + s_domain = set() + t_domain = cmd.domain + domains = c_dict.get(cmd.node, Clifford.I).commute_domains(Domains(s_domain, t_domain)) + if domains.t_domain: + _add_correction_domain(z_dict, cmd.node, domains.t_domain) + if domains.s_domain: + _add_correction_domain(x_dict, cmd.node, domains.s_domain) + case CommandKind.C: + # Each pattern command is applied by left multiplication: if a clifford `C` + # has been already applied to a node, applying a clifford `C'` to the same + # node is equivalent to apply `C'C` to a fresh node. + c_dict[cmd.node] = cmd.clifford @ c_dict.get(cmd.node, Clifford.I) return cls( pattern.input_nodes, pattern.output_nodes, pattern.results, n_list, e_set, m_list, c_dict, z_dict, x_dict ) @@ -666,43 +666,43 @@ def incorporate_pauli_results(pattern: Pattern) -> Pattern: """Return an equivalent pattern where results from Pauli presimulation are integrated in corrections.""" result = graphix.pattern.Pattern(input_nodes=pattern.input_nodes) for cmd in pattern: - if cmd.kind == CommandKind.M: - s = _incorporate_pauli_results_in_domain(pattern.results, cmd.s_domain) - t = _incorporate_pauli_results_in_domain(pattern.results, cmd.t_domain) - if s or t: - if s: - apply_x, new_s_domain = s + match cmd.kind: + case CommandKind.M: + s = _incorporate_pauli_results_in_domain(pattern.results, cmd.s_domain) + t = _incorporate_pauli_results_in_domain(pattern.results, cmd.t_domain) + if s or t: + if s: + apply_x, new_s_domain = s + else: + apply_x = False + new_s_domain = cmd.s_domain + if t: + apply_z, new_t_domain = t + else: + apply_z = False + new_t_domain = cmd.t_domain + new_cmd = command.M(cmd.node, cmd.plane, cmd.angle, new_s_domain, new_t_domain) + if apply_x: + new_cmd = new_cmd.clifford(Clifford.X) + if apply_z: + new_cmd = new_cmd.clifford(Clifford.Z) + result.add(new_cmd) else: - apply_x = False - new_s_domain = cmd.s_domain - if t: - apply_z, new_t_domain = t + result.add(cmd) + case CommandKind.X | CommandKind.Z: + signal = _incorporate_pauli_results_in_domain(pattern.results, cmd.domain) + if signal: + apply_c, new_domain = signal + if new_domain: + cmd_cstr = command.X if cmd.kind == CommandKind.X else command.Z + result.add(cmd_cstr(cmd.node, new_domain)) + if apply_c: + c = Clifford.X if cmd.kind == CommandKind.X else Clifford.Z + result.add(command.C(cmd.node, c)) else: - apply_z = False - new_t_domain = cmd.t_domain - new_cmd = command.M(cmd.node, cmd.plane, cmd.angle, new_s_domain, new_t_domain) - if apply_x: - new_cmd = new_cmd.clifford(Clifford.X) - if apply_z: - new_cmd = new_cmd.clifford(Clifford.Z) - result.add(new_cmd) - else: + result.add(cmd) + case _: result.add(cmd) - # Use == for mypy - elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 - signal = _incorporate_pauli_results_in_domain(pattern.results, cmd.domain) - if signal: - apply_c, new_domain = signal - if new_domain: - cmd_cstr = command.X if cmd.kind == CommandKind.X else command.Z - result.add(cmd_cstr(cmd.node, new_domain)) - if apply_c: - c = Clifford.X if cmd.kind == CommandKind.X else Clifford.Z - result.add(command.C(cmd.node, c)) - else: - result.add(cmd) - else: - result.add(cmd) result.reorder_output_nodes(pattern.output_nodes) return result @@ -740,21 +740,22 @@ def decompose_domain(cmd: Callable[[int, set[int]], command.Command], node: int, return True for cmd in pattern: - if cmd.kind == CommandKind.M: - replaced_s_domain = decompose_domain(command.X, cmd.node, cmd.s_domain) - replaced_t_domain = decompose_domain(command.Z, cmd.node, cmd.t_domain) - if replaced_s_domain or replaced_t_domain: - new_s_domain = set() if replaced_s_domain else cmd.s_domain - new_t_domain = set() if replaced_t_domain else cmd.t_domain - new_cmd = dataclasses.replace(cmd, s_domain=new_s_domain, t_domain=new_t_domain) - new_pattern.add(new_cmd) - continue - elif cmd.kind == CommandKind.X: - if decompose_domain(command.X, cmd.node, cmd.domain): - continue - elif cmd.kind == CommandKind.Z: - if decompose_domain(command.Z, cmd.node, cmd.domain): - continue + match cmd.kind: + case CommandKind.M: + replaced_s_domain = decompose_domain(command.X, cmd.node, cmd.s_domain) + replaced_t_domain = decompose_domain(command.Z, cmd.node, cmd.t_domain) + if replaced_s_domain or replaced_t_domain: + new_s_domain = set() if replaced_s_domain else cmd.s_domain + new_t_domain = set() if replaced_t_domain else cmd.t_domain + new_cmd = dataclasses.replace(cmd, s_domain=new_s_domain, t_domain=new_t_domain) + new_pattern.add(new_cmd) + continue + case CommandKind.X: + if decompose_domain(command.X, cmd.node, cmd.domain): + continue + case CommandKind.Z: + if decompose_domain(command.Z, cmd.node, cmd.domain): + continue new_pattern.add(cmd) new_pattern.reorder_output_nodes(pattern.output_nodes) diff --git a/graphix/pattern.py b/graphix/pattern.py index f52313a6..51c8b7ed 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -129,12 +129,13 @@ def add(self, cmd: Command) -> None: cmd : :class:`graphix.command.Command` MBQC command. """ - if cmd.kind == CommandKind.N: - self.__n_node += 1 - self.__output_nodes.append(cmd.node) - elif cmd.kind == CommandKind.M: - if cmd.node in self.__output_nodes: - self.__output_nodes.remove(cmd.node) + match cmd.kind: + case CommandKind.N: + self.__n_node += 1 + self.__output_nodes.append(cmd.node) + case CommandKind.M: + if cmd.node in self.__output_nodes: + self.__output_nodes.remove(cmd.node) self.__seq.append(cmd) def extend(self, *cmds: Command | Iterable[Command]) -> None: @@ -525,43 +526,44 @@ def expand_domain(domain: set[command.Node]) -> None: domain ^= signal_dict[node] for i, cmd in enumerate(self): - if cmd.kind == CommandKind.M: - s_domain = set(cmd.s_domain) - t_domain = set(cmd.t_domain) - expand_domain(s_domain) - expand_domain(t_domain) - plane = cmd.plane - if plane == Plane.XY: - # M^{XY,α} X^s Z^t = M^{XY,(-1)^s·α+tπ} - # = S^t M^{XY,(-1)^s·α} - # = S^t M^{XY,α} X^s - if t_domain: - signal_dict[cmd.node] = t_domain - t_domain = set() - elif plane == Plane.XZ: - # M^{XZ,α} X^s Z^t = M^{XZ,(-1)^t((-1)^s·α+sπ)} - # = M^{XZ,(-1)^{s+t}·α+(-1)^t·sπ} - # = M^{XZ,(-1)^{s+t}·α+sπ} (since (-1)^t·π ≡ π (mod 2π)) - # = S^s M^{XZ,(-1)^{s+t}·α} - # = S^s M^{XZ,α} Z^{s+t} - if s_domain: - signal_dict[cmd.node] = s_domain - t_domain ^= s_domain - s_domain = set() - elif plane == Plane.YZ and s_domain: - # M^{YZ,α} X^s Z^t = M^{YZ,(-1)^t·α+sπ)} - # = S^s M^{YZ,(-1)^t·α} - # = S^s M^{YZ,α} Z^t - signal_dict[cmd.node] = s_domain - s_domain = set() - if s_domain != cmd.s_domain or t_domain != cmd.t_domain: - self.__seq[i] = dataclasses.replace(cmd, s_domain=s_domain, t_domain=t_domain) - # Use of `==` here for mypy - elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 - domain = set(cmd.domain) - expand_domain(domain) - if domain != cmd.domain: - self.__seq[i] = dataclasses.replace(cmd, domain=domain) + match cmd.kind: + case CommandKind.M: + s_domain = set(cmd.s_domain) + t_domain = set(cmd.t_domain) + expand_domain(s_domain) + expand_domain(t_domain) + plane = cmd.plane + match plane: + case Plane.XY: + # M^{XY,α} X^s Z^t = M^{XY,(-1)^s·α+tπ} + # = S^t M^{XY,(-1)^s·α} + # = S^t M^{XY,α} X^s + if t_domain: + signal_dict[cmd.node] = t_domain + t_domain = set() + case Plane.XZ: + # M^{XZ,α} X^s Z^t = M^{XZ,(-1)^t((-1)^s·α+sπ)} + # = M^{XZ,(-1)^{s+t}·α+(-1)^t·sπ} + # = M^{XZ,(-1)^{s+t}·α+sπ} (since (-1)^t·π ≡ π (mod 2π)) + # = S^s M^{XZ,(-1)^{s+t}·α} + # = S^s M^{XZ,α} Z^{s+t} + if s_domain: + signal_dict[cmd.node] = s_domain + t_domain ^= s_domain + s_domain = set() + case Plane.YZ if s_domain: + # M^{YZ,α} X^s Z^t = M^{YZ,(-1)^t·α+sπ)} + # = S^s M^{YZ,(-1)^t·α} + # = S^s M^{YZ,α} Z^t + signal_dict[cmd.node] = s_domain + s_domain = set() + if s_domain != cmd.s_domain or t_domain != cmd.t_domain: + self.__seq[i] = dataclasses.replace(cmd, s_domain=s_domain, t_domain=t_domain) + case CommandKind.X | CommandKind.Z: + domain = set(cmd.domain) + expand_domain(domain) + if domain != cmd.domain: + self.__seq[i] = dataclasses.replace(cmd, domain=domain) return signal_dict def _find_op_to_be_moved(self, op: CommandKind, rev: bool = False, skipnum: int = 0) -> int | None: @@ -865,11 +867,11 @@ def _extract_dependency(self) -> dict[int, set[int]]: nodes = self.extract_nodes() dependency: dict[int, set[int]] = {i: set() for i in nodes} for cmd in self.__seq: - if cmd.kind == CommandKind.M: - dependency[cmd.node] |= cmd.s_domain | cmd.t_domain - # Use of `==` here for mypy - elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 - dependency[cmd.node] |= cmd.domain + match cmd.kind: + case CommandKind.M: + dependency[cmd.node] |= cmd.s_domain | cmd.t_domain + case CommandKind.X | CommandKind.Z: + dependency[cmd.node] |= cmd.domain return dependency @staticmethod @@ -1101,14 +1103,15 @@ def extract_graph(self) -> nx.Graph[int]: graph: nx.Graph[int] = nx.Graph() graph.add_nodes_from(self.input_nodes) for cmd in self.__seq: - if cmd.kind == CommandKind.N: - graph.add_node(cmd.node) - elif cmd.kind == CommandKind.E: - u, v = cmd.nodes - if graph.has_edge(u, v): - graph.remove_edge(u, v) - else: - graph.add_edge(u, v) + match cmd.kind: + case CommandKind.N: + graph.add_node(cmd.node) + case CommandKind.E: + u, v = cmd.nodes + if graph.has_edge(u, v): + graph.remove_edge(u, v) + else: + graph.add_edge(u, v) return graph def extract_nodes(self) -> set[int]: @@ -1151,19 +1154,20 @@ def extract_opengraph(self) -> OpenGraph[Measurement]: measurements: dict[int, Measurement] = {} for cmd in self.__seq: - if cmd.kind == CommandKind.N: - if cmd.state != BasicStates.PLUS: - raise PatternError( - f"Open graph extraction requires N commands to represent a |+⟩ state. Error found in {cmd}." - ) - nodes.add(cmd.node) - elif cmd.kind == CommandKind.E: - u, v = cmd.nodes - if u > v: - u, v = v, u - edges.symmetric_difference_update({(u, v)}) - elif cmd.kind == CommandKind.M: - measurements[cmd.node] = Measurement(cmd.angle, cmd.plane) + match cmd.kind: + case CommandKind.N: + if cmd.state != BasicStates.PLUS: + raise PatternError( + f"Open graph extraction requires N commands to represent a |+⟩ state. Error found in {cmd}." + ) + nodes.add(cmd.node) + case CommandKind.E: + u, v = cmd.nodes + if u > v: + u, v = v, u + edges.symmetric_difference_update({(u, v)}) + case CommandKind.M: + measurements[cmd.node] = Measurement(cmd.angle, cmd.plane) graph = nx.Graph(edges) graph.add_nodes_from(nodes) @@ -1282,10 +1286,11 @@ def max_space(self) -> int: nodes = len(self.input_nodes) max_nodes = nodes for cmd in self.__seq: - if cmd.kind == CommandKind.N: - nodes += 1 - elif cmd.kind == CommandKind.M: - nodes -= 1 + match cmd.kind: + case CommandKind.N: + nodes += 1 + case CommandKind.M: + nodes -= 1 max_nodes = max(nodes, max_nodes) return max_nodes @@ -1300,12 +1305,13 @@ def space_list(self) -> list[int]: nodes = 0 n_list = [] for cmd in self.__seq: - if cmd.kind == CommandKind.N: - nodes += 1 - n_list.append(nodes) - elif cmd.kind == CommandKind.M: - nodes -= 1 - n_list.append(nodes) + match cmd.kind: + case CommandKind.N: + nodes += 1 + n_list.append(nodes) + case CommandKind.M: + nodes -= 1 + n_list.append(nodes) return n_list @overload @@ -1561,39 +1567,39 @@ def check_measured(cmd: Command, node: int) -> None: raise RunnabilityError(cmd, node, RunnabilityErrorReason.NotYetMeasured) for cmd in self: - if cmd.kind == CommandKind.N: - if cmd.node in active: - raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.AlreadyActive) - if cmd.node in measured: - raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.AlreadyMeasured) - active.add(cmd.node) - elif cmd.kind == CommandKind.E: - n0, n1 = cmd.nodes - check_active(cmd, n0) - check_active(cmd, n1) - elif cmd.kind == CommandKind.M: - check_active(cmd, cmd.node) - if isinstance(cmd, command.M): - # `cmd.s_domain` and `cmd.t_domain` are only - # defined if the command is an actual `M` command, - # which may not be the case if the method is - # called with a pattern constructed with another - # implementation of `BaseM` (for instance, a blind - # pattern from Veriphix). - for domain in cmd.s_domain, cmd.t_domain: - if cmd.node in domain: - raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.DomainSelfLoop) - for node in domain: - check_measured(cmd, node) - active.remove(cmd.node) - measured.add(cmd.node) - # Use of `==` here for mypy - elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 - check_active(cmd, cmd.node) - for node in cmd.domain: - check_measured(cmd, node) - elif cmd.kind == CommandKind.C: - check_active(cmd, cmd.node) + match cmd.kind: + case CommandKind.N: + if cmd.node in active: + raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.AlreadyActive) + if cmd.node in measured: + raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.AlreadyMeasured) + active.add(cmd.node) + case CommandKind.E: + n0, n1 = cmd.nodes + check_active(cmd, n0) + check_active(cmd, n1) + case CommandKind.M: + check_active(cmd, cmd.node) + if isinstance(cmd, command.M): + # `cmd.s_domain` and `cmd.t_domain` are only + # defined if the command is an actual `M` command, + # which may not be the case if the method is + # called with a pattern constructed with another + # implementation of `BaseM` (for instance, a blind + # pattern from Veriphix). + for domain in cmd.s_domain, cmd.t_domain: + if cmd.node in domain: + raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.DomainSelfLoop) + for node in domain: + check_measured(cmd, node) + active.remove(cmd.node) + measured.add(cmd.node) + case CommandKind.X | CommandKind.Z: + check_active(cmd, cmd.node) + for node in cmd.domain: + check_measured(cmd, node) + case CommandKind.C: + check_active(cmd, cmd.node) class PatternError(Exception): @@ -1685,15 +1691,16 @@ def measure_pauli(pattern: Pattern, *, ignore_pauli_with_deps: bool = False) -> # extract signals for adaptive angle. s_signal = 0 t_signal = 0 - if measurement_basis.axis == Axis.X: # X measurement is not affected by s_signal - t_signal = sum(results[j] for j in pattern_cmd.t_domain) - elif measurement_basis.axis == Axis.Y: - s_signal = sum(results[j] for j in pattern_cmd.s_domain) - t_signal = sum(results[j] for j in pattern_cmd.t_domain) - elif measurement_basis.axis == Axis.Z: # Z measurement is not affected by t_signal - s_signal = sum(results[j] for j in pattern_cmd.s_domain) - else: - assert_never(measurement_basis.axis) + match measurement_basis.axis: + case Axis.X: # X measurement is not affected by s_signal + t_signal = sum(results[j] for j in pattern_cmd.t_domain) + case Axis.Y: + s_signal = sum(results[j] for j in pattern_cmd.s_domain) + t_signal = sum(results[j] for j in pattern_cmd.t_domain) + case Axis.Z: # Z measurement is not affected by t_signal + s_signal = sum(results[j] for j in pattern_cmd.s_domain) + case _: + assert_never(measurement_basis.axis) if int(s_signal % 2) == 1: # equivalent to X byproduct graph_state.h(pattern_cmd.node) @@ -1702,14 +1709,15 @@ def measure_pauli(pattern: Pattern, *, ignore_pauli_with_deps: bool = False) -> if int(t_signal % 2) == 1: # equivalent to Z byproduct graph_state.z(pattern_cmd.node) basis = measurement_basis - if basis.axis == Axis.X: - measure = graph_state.measure_x - elif basis.axis == Axis.Y: - measure = graph_state.measure_y - elif basis.axis == Axis.Z: - measure = graph_state.measure_z - else: - assert_never(basis.axis) + match basis.axis: + case Axis.X: + measure = graph_state.measure_x + case Axis.Y: + measure = graph_state.measure_y + case Axis.Z: + measure = graph_state.measure_z + case _: + assert_never(basis.axis) if basis.sign == Sign.PLUS: results[pattern_cmd.node] = measure(pattern_cmd.node, choice=0) else: diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index f6d14736..b3cd7734 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -125,82 +125,85 @@ def command_to_str(cmd: command.Command, output: OutputFormat) -> str: The expected format. """ out = [cmd.kind.name] - - if cmd.kind == command.CommandKind.E: - u, v = cmd.nodes - if output == OutputFormat.LaTeX: - out.append(f"_{{{u},{v}}}") - elif output == OutputFormat.Unicode: - u_subscripts = str(u).translate(SUBSCRIPTS) - v_subscripts = str(v).translate(SUBSCRIPTS) - out.append(f"{u_subscripts}₋{v_subscripts}") - else: - out.append(f"({u},{v})") - elif cmd.kind == command.CommandKind.T: - pass - else: - # All other commands have a field `node` to print, together - # with some other arguments and/or domains. - arguments = [] - if cmd.kind == command.CommandKind.M: - if cmd.plane != Plane.XY: - arguments.append(cmd.plane.name) - # We use `SupportsFloat` since `isinstance(cmd.angle, float)` - # is `False` if `cmd.angle` is an integer. - if isinstance(cmd.angle, SupportsFloat): - angle = float(cmd.angle) - if not math.isclose(angle, 0.0): - arguments.append(angle_to_str(angle, output)) - else: - # If the angle is a symbolic expression, we can only delegate the printing - # TODO: We should have a mean to specify the format - arguments.append(str(cmd.angle * math.pi)) - elif cmd.kind == command.CommandKind.C: - arguments.append(str(cmd.clifford)) - # Use of `==` here for mypy - command_domain = ( - cmd.domain - if cmd.kind == command.CommandKind.X # noqa: PLR1714 - or cmd.kind == command.CommandKind.Z - or cmd.kind == command.CommandKind.S - else None - ) - if output == OutputFormat.LaTeX: - out.append(f"_{{{cmd.node}}}") - if arguments: - out.append(f"^{{{','.join(arguments)}}}") - elif output == OutputFormat.Unicode: - node_subscripts = str(cmd.node).translate(SUBSCRIPTS) - out.append(f"{node_subscripts}") - if arguments: - out.append(f"({','.join(arguments)})") - else: - arguments = [str(cmd.node), *arguments] + match cmd.kind: + case command.CommandKind.E: + u, v = cmd.nodes + match output: + case OutputFormat.LaTeX: + out.append(f"_{{{u},{v}}}") + case OutputFormat.Unicode: + u_subscripts = str(u).translate(SUBSCRIPTS) + v_subscripts = str(v).translate(SUBSCRIPTS) + out.append(f"{u_subscripts}₋{v_subscripts}") + case _: + out.append(f"({u},{v})") + case command.CommandKind.T: + pass + case _: + # All other commands have a field `node` to print, together + # with some other arguments and/or domains. + arguments = [] + match cmd.kind: + case command.CommandKind.M: + if cmd.plane != Plane.XY: + arguments.append(cmd.plane.name) + # We use `SupportsFloat` since `isinstance(cmd.angle, float)` + # is `False` if `cmd.angle` is an integer. + if isinstance(cmd.angle, SupportsFloat): + angle = float(cmd.angle) + if not math.isclose(angle, 0.0): + arguments.append(angle_to_str(angle, output)) + else: + # If the angle is a symbolic expression, we can only delegate the printing + # TODO: We should have a mean to specify the format + arguments.append(str(cmd.angle * math.pi)) + case command.CommandKind.C: + arguments.append(str(cmd.clifford)) + # Use of `==` here for mypy + command_domain = ( + cmd.domain + if cmd.kind == command.CommandKind.X # noqa: PLR1714 + or cmd.kind == command.CommandKind.Z + or cmd.kind == command.CommandKind.S + else None + ) + match output: + case OutputFormat.LaTeX: + out.append(f"_{{{cmd.node}}}") + if arguments: + out.append(f"^{{{','.join(arguments)}}}") + case OutputFormat.Unicode: + node_subscripts = str(cmd.node).translate(SUBSCRIPTS) + out.append(f"{node_subscripts}") + if arguments: + out.append(f"({','.join(arguments)})") + case _: + arguments = [str(cmd.node), *arguments] + if command_domain: + arguments.append(domain_to_str(command_domain)) + command_domain = None + out.append(f"({','.join(arguments)})") + if cmd.kind == command.CommandKind.M and (cmd.s_domain or cmd.t_domain): + out = ["[", *out, "]"] + if cmd.t_domain: + if output == OutputFormat.LaTeX: + t_domain_str = f"{{}}_{{{','.join(str(node) for node in cmd.t_domain)}}}" + elif output == OutputFormat.Unicode: + t_domain_subscripts = [str(node).translate(SUBSCRIPTS) for node in cmd.t_domain] + t_domain_str = "₊".join(t_domain_subscripts) + else: + t_domain_str = f"{{{','.join(str(node) for node in cmd.t_domain)}}}" + out = [t_domain_str, *out] + command_domain = cmd.s_domain if command_domain: - arguments.append(domain_to_str(command_domain)) - command_domain = None - out.append(f"({','.join(arguments)})") - if cmd.kind == command.CommandKind.M and (cmd.s_domain or cmd.t_domain): - out = ["[", *out, "]"] - if cmd.t_domain: if output == OutputFormat.LaTeX: - t_domain_str = f"{{}}_{{{','.join(str(node) for node in cmd.t_domain)}}}" + domain_str = f"^{{{','.join(str(node) for node in command_domain)}}}" elif output == OutputFormat.Unicode: - t_domain_subscripts = [str(node).translate(SUBSCRIPTS) for node in cmd.t_domain] - t_domain_str = "₊".join(t_domain_subscripts) + domain_superscripts = [str(node).translate(SUPERSCRIPTS) for node in command_domain] + domain_str = "⁺".join(domain_superscripts) else: - t_domain_str = f"{{{','.join(str(node) for node in cmd.t_domain)}}}" - out = [t_domain_str, *out] - command_domain = cmd.s_domain - if command_domain: - if output == OutputFormat.LaTeX: - domain_str = f"^{{{','.join(str(node) for node in command_domain)}}}" - elif output == OutputFormat.Unicode: - domain_superscripts = [str(node).translate(SUPERSCRIPTS) for node in command_domain] - domain_str = "⁺".join(domain_superscripts) - else: - domain_str = f"{{{','.join(str(node) for node in command_domain)}}}" - out.append(domain_str) + domain_str = f"{{{','.join(str(node) for node in command_domain)}}}" + out.append(domain_str) return f"{''.join(out)}" diff --git a/graphix/qasm3_exporter.py b/graphix/qasm3_exporter.py index 71efd101..b23303c9 100644 --- a/graphix/qasm3_exporter.py +++ b/graphix/qasm3_exporter.py @@ -184,61 +184,57 @@ def command_to_qasm3_lines(cmd: Command) -> Iterator[str]: """ yield f"// {cmd}\n" - if cmd.kind == CommandKind.N: - yield f"qubit q{cmd.node};\n" - yield from state_to_qasm3_lines(cmd.node, cmd.state) - - elif cmd.kind == CommandKind.E: - n0, n1 = cmd.nodes - yield f"cz q{n0}, q{n1};\n" - - elif cmd.kind == CommandKind.M: - yield from domain_to_qasm3_lines(cmd.s_domain, f"x q{cmd.node}") - yield from domain_to_qasm3_lines(cmd.t_domain, f"z q{cmd.node}") - if cmd.plane == Plane.XY: - yield f"h q{cmd.node};\n" - if cmd.angle != 0: + match cmd.kind: + case CommandKind.N: + yield f"qubit q{cmd.node};\n" + yield from state_to_qasm3_lines(cmd.node, cmd.state) + case CommandKind.E: + n0, n1 = cmd.nodes + yield f"cz q{n0}, q{n1};\n" + case CommandKind.M: + yield from domain_to_qasm3_lines(cmd.s_domain, f"x q{cmd.node}") + yield from domain_to_qasm3_lines(cmd.t_domain, f"z q{cmd.node}") if cmd.plane == Plane.XY: - gate = "rx" - angle = -cmd.angle - elif cmd.plane == Plane.XZ: - gate = "ry" - angle = -cmd.angle - elif cmd.plane == Plane.YZ: - gate = "rx" - angle = cmd.angle - else: - assert_never(cmd.plane) - rad_angle = angle_to_qasm3(angle) - yield f"{gate}({rad_angle}) q{cmd.node};\n" - yield f"bit c{cmd.node};\n" - yield f"c{cmd.node} = measure q{cmd.node};\n" - - elif cmd.kind == CommandKind.X: - yield from domain_to_qasm3_lines(cmd.domain, f"x q{cmd.node}") - - elif cmd.kind == CommandKind.Z: - yield from domain_to_qasm3_lines(cmd.domain, f"z q{cmd.node}") - - elif cmd.kind == CommandKind.C: - for op in cmd.clifford.qasm3: - yield str(op) + " q" + str(cmd.node) + ";\n" - - else: - raise ValueError(f"invalid command {cmd}") + yield f"h q{cmd.node};\n" + if cmd.angle != 0: + if cmd.plane == Plane.XY: + gate = "rx" + angle = -cmd.angle + elif cmd.plane == Plane.XZ: + gate = "ry" + angle = -cmd.angle + elif cmd.plane == Plane.YZ: + gate = "rx" + angle = cmd.angle + else: + assert_never(cmd.plane) + rad_angle = angle_to_qasm3(angle) + yield f"{gate}({rad_angle}) q{cmd.node};\n" + yield f"bit c{cmd.node};\n" + yield f"c{cmd.node} = measure q{cmd.node};\n" + case CommandKind.X: + yield from domain_to_qasm3_lines(cmd.domain, f"x q{cmd.node}") + case CommandKind.Z: + yield from domain_to_qasm3_lines(cmd.domain, f"z q{cmd.node}") + case CommandKind.C: + for op in cmd.clifford.qasm3: + yield str(op) + " q" + str(cmd.node) + ";\n" + case _: + raise ValueError(f"invalid command {cmd}") yield "\n" def state_to_qasm3_lines(node: int, state: State) -> Iterator[str]: """Convert initial state into OpenQASM 3.0 statement.""" - if state == BasicStates.ZERO: - yield f"// qubit {node} prepared in |0⟩: do nothing\n" - elif state == BasicStates.PLUS: - yield f"// qubit {node} prepared in |+⟩\n" - yield f"h q{node};\n" - else: - raise ValueError("QASM3 conversion only supports |0⟩ or |+⟩ initial states.") + match state: + case BasicStates.ZERO: + yield f"// qubit {node} prepared in |0⟩: do nothing\n" + case BasicStates.PLUS: + yield f"// qubit {node} prepared in |+⟩\n" + yield f"h q{node};\n" + case _: + raise ValueError("QASM3 conversion only supports |0⟩ or |+⟩ initial states.") def domain_to_qasm3_lines(domain: Iterable[int], cmd: str) -> Iterator[str]: diff --git a/graphix/sim/tensornet.py b/graphix/sim/tensornet.py index e86f5913..6b20a37b 100644 --- a/graphix/sim/tensornet.py +++ b/graphix/sim/tensornet.py @@ -110,26 +110,27 @@ def add_qubit(self, index: int, state: PrepareState = "plus") -> None: """ ind = gen_str() tag = str(index) - if state == "plus": - vec = BasicStates.PLUS.to_statevector() - elif state == "minus": - vec = BasicStates.MINUS.to_statevector() - elif state == "zero": - vec = BasicStates.ZERO.to_statevector() - elif state == "one": - vec = BasicStates.ONE.to_statevector() - elif state == "iplus": - vec = BasicStates.PLUS_I.to_statevector() - elif state == "iminus": - vec = BasicStates.MINUS_I.to_statevector() - else: - if isinstance(state, str): - raise TypeError(f"Unknown state: {state}") - if state.shape != (2,): - raise ValueError("state must be 2-element np.ndarray") - if not np.isclose(np.linalg.norm(state), 1): - raise ValueError("state must be normalized") - vec = state + match state: + case "plus": + vec = BasicStates.PLUS.to_statevector() + case "minus": + vec = BasicStates.MINUS.to_statevector() + case "zero": + vec = BasicStates.ZERO.to_statevector() + case "one": + vec = BasicStates.ONE.to_statevector() + case "iplus": + vec = BasicStates.PLUS_I.to_statevector() + case "iminus": + vec = BasicStates.MINUS_I.to_statevector() + case _: + if isinstance(state, str): + raise TypeError(f"Unknown state: {state}") + if state.shape != (2,): + raise ValueError("state must be 2-element np.ndarray") + if not np.isclose(np.linalg.norm(state), 1): + raise ValueError("state must be normalized") + vec = state tsr = Tensor(vec, [ind], [tag, "Open"]) self.add_tensor(tsr) self._dangling[tag] = ind @@ -621,20 +622,21 @@ def __init__( raise NotImplementedError(msg) if branch_selector is None: branch_selector = RandomBranchSelector() - if graph_prep in {"parallel", "sequential"}: - pass - elif graph_prep == "opt": - graph_prep = "parallel" - warnings.warn( - f"graph preparation strategy '{graph_prep}' is deprecated and will be replaced by 'parallel'", - stacklevel=1, - ) - elif graph_prep == "auto": - max_degree = pattern.compute_max_degree() - # "parallel" does not support non standard pattern - graph_prep = "sequential" if max_degree > 5 or not pattern.is_standard() else "parallel" - else: - raise ValueError(f"Invalid graph preparation strategy: {graph_prep}") + match graph_prep: + case "parallel" | "sequential": + pass + case "opt": + graph_prep = "parallel" + warnings.warn( + f"graph preparation strategy '{graph_prep}' is deprecated and will be replaced by 'parallel'", + stacklevel=1, + ) + case "auto": + max_degree = pattern.compute_max_degree() + # "parallel" does not support non standard pattern + graph_prep = "sequential" if max_degree > 5 or not pattern.is_standard() else "parallel" + case _: + raise ValueError(f"Invalid graph preparation strategy: {graph_prep}") results = deepcopy(pattern.results) if graph_prep == "parallel": if not pattern.is_standard(): @@ -690,10 +692,11 @@ def add_nodes(self, nodes: Sequence[int], data: Data = BasicStates.PLUS) -> None raise NotImplementedError( "TensorNetworkBackend currently only supports |+> input state (see https://github.com/TeamGraphix/graphix/issues/167)." ) - if self.graph_prep == "sequential": - self.state.add_qubits(nodes) - elif self.graph_prep == "opt": - pass + match self.graph_prep: + case "sequential": + self.state.add_qubits(nodes) + case "opt": + pass @override def entangle_nodes(self, edge: tuple[int, int]) -> None: @@ -704,33 +707,34 @@ def entangle_nodes(self, edge: tuple[int, int]) -> None: edge : tuple of int edge specifies two target nodes of the CZ gate. """ - if self.graph_prep == "sequential": - old_inds = [self.state._dangling[str(node)] for node in edge] - tids = self.state._get_tids_from_inds(old_inds, which="any") - tensors = [self.state.tensor_map[tid] for tid in tids] - new_inds = [gen_str() for _ in range(3)] - - # retag dummy indices - for i in range(2): - tensors[i].retag({"Open": "Close"}, inplace=True) - self.state._dangling[str(edge[i])] = new_inds[i] - cz_tn = TensorNetwork( - [ - qtn.Tensor( - self._decomposed_cz[0], - [new_inds[0], old_inds[0], new_inds[2]], - [str(edge[0]), "CZ", "Open"], - ), - qtn.Tensor( - self._decomposed_cz[1], - [new_inds[2], new_inds[1], old_inds[1]], - [str(edge[1]), "CZ", "Open"], - ), - ] - ) - self.state.add_tensor_network(cz_tn) - elif self.graph_prep == "opt": - pass + match self.graph_prep: + case "sequential": + old_inds = [self.state._dangling[str(node)] for node in edge] + tids = self.state._get_tids_from_inds(old_inds, which="any") + tensors = [self.state.tensor_map[tid] for tid in tids] + new_inds = [gen_str() for _ in range(3)] + + # retag dummy indices + for i in range(2): + tensors[i].retag({"Open": "Close"}, inplace=True) + self.state._dangling[str(edge[i])] = new_inds[i] + cz_tn = TensorNetwork( + [ + qtn.Tensor( + self._decomposed_cz[0], + [new_inds[0], old_inds[0], new_inds[2]], + [str(edge[0]), "CZ", "Open"], + ), + qtn.Tensor( + self._decomposed_cz[1], + [new_inds[2], new_inds[1], old_inds[1]], + [str(edge[1]), "CZ", "Open"], + ), + ] + ) + self.state.add_tensor_network(cz_tn) + case "opt": + pass @override def measure(self, node: int, measurement: Measurement, rng: Generator | None = None) -> Outcome: diff --git a/graphix/simulator.py b/graphix/simulator.py index 380b12b5..de83ab9c 100644 --- a/graphix/simulator.py +++ b/graphix/simulator.py @@ -316,31 +316,31 @@ def run(self, input_state: Data = BasicStates.PLUS, rng: Generator | None = None self.pattern.check_runnability() for cmd in pattern: - if cmd.kind == CommandKind.N: - self.__prepare_method.prepare(self.backend, cmd, rng=rng) - elif cmd.kind == CommandKind.E: - self.backend.entangle_nodes(edge=cmd.nodes) - elif cmd.kind == CommandKind.M: - self.__measure_method.measure(self.backend, cmd, noise_model=self.noise_model, rng=rng) - # Use of `==` here for mypy - elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 - if self.__measure_method.check_domain(cmd.domain): - self.backend.correct_byproduct(cmd) - elif cmd.kind == CommandKind.C: - self.backend.apply_clifford(cmd.node, cmd.clifford) - elif cmd.kind == CommandKind.T: - # The T command is a flag for one clock cycle in a simulated - # experiment, added via a hardware-agnostic - # pattern modifier. Noise models can perform special - # handling of ticks during noise transpilation. - pass - elif cmd.kind == CommandKind.ApplyNoise: - if cmd.domain is None or self.__measure_method.check_domain(cmd.domain): - self.backend.apply_noise(cmd) - elif cmd.kind == CommandKind.S: - raise ValueError("S commands unexpected in simulated patterns.") - else: - assert_never(cmd.kind) + match cmd.kind: + case CommandKind.N: + self.__prepare_method.prepare(self.backend, cmd, rng=rng) + case CommandKind.E: + self.backend.entangle_nodes(edge=cmd.nodes) + case CommandKind.M: + self.__measure_method.measure(self.backend, cmd, noise_model=self.noise_model, rng=rng) + case CommandKind.X | CommandKind.Z: + if self.__measure_method.check_domain(cmd.domain): + self.backend.correct_byproduct(cmd) + case CommandKind.C: + self.backend.apply_clifford(cmd.node, cmd.clifford) + case CommandKind.T: + # The T command is a flag for one clock cycle in a simulated + # experiment, added via a hardware-agnostic + # pattern modifier. Noise models can perform special + # handling of ticks during noise transpilation. + pass + case CommandKind.ApplyNoise: + if cmd.domain is None or self.__measure_method.check_domain(cmd.domain): + self.backend.apply_noise(cmd) + case CommandKind.S: + raise ValueError("S commands unexpected in simulated patterns.") + case _: + assert_never(cmd.kind) self.backend.finalize(output_nodes=self.pattern.output_nodes) diff --git a/graphix/transpiler.py b/graphix/transpiler.py index 8656b913..fd82c4d9 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -101,38 +101,39 @@ def __init__(self, width: int, instr: Iterable[Instruction] | None = None) -> No def add(self, instr: Instruction) -> None: """Add an instruction to the circuit.""" - if instr.kind == InstructionKind.CCX: - self.ccx(instr.controls[0], instr.controls[1], instr.target) - elif instr.kind == InstructionKind.RZZ: - self.rzz(instr.control, instr.target, instr.angle) - elif instr.kind == InstructionKind.CNOT: - self.cnot(instr.control, instr.target) - elif instr.kind == InstructionKind.SWAP: - self.swap(instr.targets[0], instr.targets[1]) - elif instr.kind == InstructionKind.CZ: - self.cz(instr.targets[0], instr.targets[1]) - elif instr.kind == InstructionKind.H: - self.h(instr.target) - elif instr.kind == InstructionKind.S: - self.s(instr.target) - elif instr.kind == InstructionKind.X: - self.x(instr.target) - elif instr.kind == InstructionKind.Y: - self.y(instr.target) - elif instr.kind == InstructionKind.Z: - self.z(instr.target) - elif instr.kind == InstructionKind.I: - self.i(instr.target) - elif instr.kind == InstructionKind.M: - self.m(instr.target, instr.axis) - elif instr.kind == InstructionKind.RX: - self.rx(instr.target, instr.angle) - elif instr.kind == InstructionKind.RY: - self.ry(instr.target, instr.angle) - elif instr.kind == InstructionKind.RZ: - self.rz(instr.target, instr.angle) - else: - assert_never(instr.kind) + match instr.kind: + case InstructionKind.CCX: + self.ccx(instr.controls[0], instr.controls[1], instr.target) + case InstructionKind.RZZ: + self.rzz(instr.control, instr.target, instr.angle) + case InstructionKind.CNOT: + self.cnot(instr.control, instr.target) + case InstructionKind.SWAP: + self.swap(instr.targets[0], instr.targets[1]) + case InstructionKind.CZ: + self.cz(instr.targets[0], instr.targets[1]) + case InstructionKind.H: + self.h(instr.target) + case InstructionKind.S: + self.s(instr.target) + case InstructionKind.X: + self.x(instr.target) + case InstructionKind.Y: + self.y(instr.target) + case InstructionKind.Z: + self.z(instr.target) + case InstructionKind.I: + self.i(instr.target) + case InstructionKind.M: + self.m(instr.target, instr.axis) + case InstructionKind.RX: + self.rx(instr.target, instr.angle) + case InstructionKind.RY: + self.ry(instr.target, instr.angle) + case InstructionKind.RZ: + self.rz(instr.target, instr.angle) + case _: + assert_never(instr.kind) def extend(self, instrs: Iterable[Instruction]) -> None: """Add instructions to the circuit.""" @@ -365,101 +366,102 @@ def transpile(self) -> TranspileResult: pattern = Pattern(input_nodes=list(range(self.width))) classical_outputs = [] for instr in _transpile_rzz(self.instruction): - if instr.kind == instruction.InstructionKind.CZ: - target0 = _check_target(out, instr.targets[0]) - target1 = _check_target(out, instr.targets[1]) - seq = self._cz_command(target0, target1) - pattern.extend(seq) - elif instr.kind == instruction.InstructionKind.CNOT: - ancilla = [n_node, n_node + 1] - control = _check_target(out, instr.control) - target = _check_target(out, instr.target) - out[instr.control], out[instr.target], seq = self._cnot_command(control, target, ancilla) - pattern.extend(seq) - n_node += 2 - elif instr.kind == instruction.InstructionKind.SWAP: - target0 = _check_target(out, instr.targets[0]) - target1 = _check_target(out, instr.targets[1]) - out[instr.targets[0]], out[instr.targets[1]] = ( - target1, - target0, - ) - elif instr.kind == instruction.InstructionKind.I: - pass - elif instr.kind == instruction.InstructionKind.H: - single_ancilla = n_node - target = _check_target(out, instr.target) - out[instr.target], seq = self._h_command(target, single_ancilla) - pattern.extend(seq) - n_node += 1 - elif instr.kind == instruction.InstructionKind.S: - ancilla = [n_node, n_node + 1] - target = _check_target(out, instr.target) - out[instr.target], seq = self._s_command(target, ancilla) - pattern.extend(seq) - n_node += 2 - elif instr.kind == instruction.InstructionKind.X: - ancilla = [n_node, n_node + 1] - target = _check_target(out, instr.target) - out[instr.target], seq = self._x_command(target, ancilla) - pattern.extend(seq) - n_node += 2 - elif instr.kind == instruction.InstructionKind.Y: - ancilla = [n_node, n_node + 1, n_node + 2, n_node + 3] - target = _check_target(out, instr.target) - out[instr.target], seq = self._y_command(target, ancilla) - pattern.extend(seq) - n_node += 4 - elif instr.kind == instruction.InstructionKind.Z: - ancilla = [n_node, n_node + 1] - target = _check_target(out, instr.target) - out[instr.target], seq = self._z_command(target, ancilla) - pattern.extend(seq) - n_node += 2 - elif instr.kind == instruction.InstructionKind.RX: - ancilla = [n_node, n_node + 1] - target = _check_target(out, instr.target) - out[instr.target], seq = self._rx_command(target, ancilla, instr.angle) - pattern.extend(seq) - n_node += 2 - elif instr.kind == instruction.InstructionKind.RY: - ancilla = [n_node, n_node + 1, n_node + 2, n_node + 3] - target = _check_target(out, instr.target) - out[instr.target], seq = self._ry_command(target, ancilla, instr.angle) - pattern.extend(seq) - n_node += 4 - elif instr.kind == instruction.InstructionKind.RZ: - ancilla = [n_node, n_node + 1] - target = _check_target(out, instr.target) - out[instr.target], seq = self._rz_command(target, ancilla, instr.angle) - pattern.extend(seq) - n_node += 2 - elif instr.kind == instruction.InstructionKind.CCX: - ancilla = [n_node + i for i in range(18)] - control0 = _check_target(out, instr.controls[0]) - control1 = _check_target(out, instr.controls[1]) - target = _check_target(out, instr.target) - ( - out[instr.controls[0]], - out[instr.controls[1]], - out[instr.target], - seq, - ) = self._ccx_command( - control0, - control1, - target, - ancilla, - ) - pattern.extend(seq) - n_node += 18 - elif instr.kind == instruction.InstructionKind.M: - target = _check_target(out, instr.target) - seq = self._m_command(target, instr.axis) - pattern.extend(seq) - classical_outputs.append(target) - out[instr.target] = None - else: - raise ValueError("Unknown instruction, commands not added") + match instr.kind: + case instruction.InstructionKind.CZ: + target0 = _check_target(out, instr.targets[0]) + target1 = _check_target(out, instr.targets[1]) + seq = self._cz_command(target0, target1) + pattern.extend(seq) + case instruction.InstructionKind.CNOT: + ancilla = [n_node, n_node + 1] + control = _check_target(out, instr.control) + target = _check_target(out, instr.target) + out[instr.control], out[instr.target], seq = self._cnot_command(control, target, ancilla) + pattern.extend(seq) + n_node += 2 + case instruction.InstructionKind.SWAP: + target0 = _check_target(out, instr.targets[0]) + target1 = _check_target(out, instr.targets[1]) + out[instr.targets[0]], out[instr.targets[1]] = ( + target1, + target0, + ) + case instruction.InstructionKind.I: + pass + case instruction.InstructionKind.H: + single_ancilla = n_node + target = _check_target(out, instr.target) + out[instr.target], seq = self._h_command(target, single_ancilla) + pattern.extend(seq) + n_node += 1 + case instruction.InstructionKind.S: + ancilla = [n_node, n_node + 1] + target = _check_target(out, instr.target) + out[instr.target], seq = self._s_command(target, ancilla) + pattern.extend(seq) + n_node += 2 + case instruction.InstructionKind.X: + ancilla = [n_node, n_node + 1] + target = _check_target(out, instr.target) + out[instr.target], seq = self._x_command(target, ancilla) + pattern.extend(seq) + n_node += 2 + case instruction.InstructionKind.Y: + ancilla = [n_node, n_node + 1, n_node + 2, n_node + 3] + target = _check_target(out, instr.target) + out[instr.target], seq = self._y_command(target, ancilla) + pattern.extend(seq) + n_node += 4 + case instruction.InstructionKind.Z: + ancilla = [n_node, n_node + 1] + target = _check_target(out, instr.target) + out[instr.target], seq = self._z_command(target, ancilla) + pattern.extend(seq) + n_node += 2 + case instruction.InstructionKind.RX: + ancilla = [n_node, n_node + 1] + target = _check_target(out, instr.target) + out[instr.target], seq = self._rx_command(target, ancilla, instr.angle) + pattern.extend(seq) + n_node += 2 + case instruction.InstructionKind.RY: + ancilla = [n_node, n_node + 1, n_node + 2, n_node + 3] + target = _check_target(out, instr.target) + out[instr.target], seq = self._ry_command(target, ancilla, instr.angle) + pattern.extend(seq) + n_node += 4 + case instruction.InstructionKind.RZ: + ancilla = [n_node, n_node + 1] + target = _check_target(out, instr.target) + out[instr.target], seq = self._rz_command(target, ancilla, instr.angle) + pattern.extend(seq) + n_node += 2 + case instruction.InstructionKind.CCX: + ancilla = [n_node + i for i in range(18)] + control0 = _check_target(out, instr.controls[0]) + control1 = _check_target(out, instr.controls[1]) + target = _check_target(out, instr.target) + ( + out[instr.controls[0]], + out[instr.controls[1]], + out[instr.target], + seq, + ) = self._ccx_command( + control0, + control1, + target, + ancilla, + ) + pattern.extend(seq) + n_node += 18 + case instruction.InstructionKind.M: + target = _check_target(out, instr.target) + seq = self._m_command(target, instr.axis) + pattern.extend(seq) + classical_outputs.append(target) + out[instr.target] = None + case _: + raise ValueError("Unknown instruction, commands not added") output_nodes = [node for node in out if node is not None] pattern.reorder_output_nodes(output_nodes) return TranspileResult(pattern, tuple(classical_outputs)) @@ -950,45 +952,48 @@ def evolve_single(op: Matrix, target: int) -> None: def evolve(op: Matrix, qargs: Iterable[int]) -> None: backend.state.evolve(op, [backend.node_index.index(qarg) for qarg in qargs]) - if instr.kind == instruction.InstructionKind.CNOT: - backend.state.cnot((backend.node_index.index(instr.control), backend.node_index.index(instr.target))) - elif instr.kind == instruction.InstructionKind.SWAP: - u, v = instr.targets - backend.state.swap((backend.node_index.index(u), backend.node_index.index(v))) - elif instr.kind == instruction.InstructionKind.CZ: - u, v = instr.targets - backend.state.entangle((backend.node_index.index(u), backend.node_index.index(v))) - elif instr.kind == instruction.InstructionKind.I: - pass - elif instr.kind == instruction.InstructionKind.S: - evolve_single(Ops.S, instr.target) - elif instr.kind == instruction.InstructionKind.H: - evolve_single(Ops.H, instr.target) - elif instr.kind == instruction.InstructionKind.X: - evolve_single(Ops.X, instr.target) - elif instr.kind == instruction.InstructionKind.Y: - evolve_single(Ops.Y, instr.target) - elif instr.kind == instruction.InstructionKind.Z: - evolve_single(Ops.Z, instr.target) - elif instr.kind == instruction.InstructionKind.RX: - evolve_single(Ops.rx(instr.angle), instr.target) - elif instr.kind == instruction.InstructionKind.RY: - evolve_single(Ops.ry(instr.angle), instr.target) - elif instr.kind == instruction.InstructionKind.RZ: - evolve_single(Ops.rz(instr.angle), instr.target) - elif instr.kind == instruction.InstructionKind.RZZ: - evolve(Ops.rzz(instr.angle), [instr.control, instr.target]) - elif instr.kind == instruction.InstructionKind.CCX: - evolve(Ops.CCX, [instr.controls[0], instr.controls[1], instr.target]) - elif instr.kind == instruction.InstructionKind.M: - result = backend.measure( - instr.target, - _measurement_of_axis(instr.axis), - rng=rng, - ) - classical_measures.append(result) - else: - raise ValueError(f"Unknown instruction: {instr}") + match instr.kind: + case instruction.InstructionKind.CNOT: + backend.state.cnot( + (backend.node_index.index(instr.control), backend.node_index.index(instr.target)) + ) + case instruction.InstructionKind.SWAP: + u, v = instr.targets + backend.state.swap((backend.node_index.index(u), backend.node_index.index(v))) + case instruction.InstructionKind.CZ: + u, v = instr.targets + backend.state.entangle((backend.node_index.index(u), backend.node_index.index(v))) + case instruction.InstructionKind.I: + pass + case instruction.InstructionKind.S: + evolve_single(Ops.S, instr.target) + case instruction.InstructionKind.H: + evolve_single(Ops.H, instr.target) + case instruction.InstructionKind.X: + evolve_single(Ops.X, instr.target) + case instruction.InstructionKind.Y: + evolve_single(Ops.Y, instr.target) + case instruction.InstructionKind.Z: + evolve_single(Ops.Z, instr.target) + case instruction.InstructionKind.RX: + evolve_single(Ops.rx(instr.angle), instr.target) + case instruction.InstructionKind.RY: + evolve_single(Ops.ry(instr.angle), instr.target) + case instruction.InstructionKind.RZ: + evolve_single(Ops.rz(instr.angle), instr.target) + case instruction.InstructionKind.RZZ: + evolve(Ops.rzz(instr.angle), [instr.control, instr.target]) + case instruction.InstructionKind.CCX: + evolve(Ops.CCX, [instr.controls[0], instr.controls[1], instr.target]) + case instruction.InstructionKind.M: + result = backend.measure( + instr.target, + _measurement_of_axis(instr.axis), + rng=rng, + ) + classical_measures.append(result) + case _: + raise ValueError(f"Unknown instruction: {instr}") return SimulateResult(backend.state, tuple(classical_measures)) def map_angle(self, f: Callable[[ParameterizedAngle], ParameterizedAngle]) -> Circuit: From 93dc8307a39bd9220d630bfb3b275338e1b77fc8 Mon Sep 17 00:00:00 2001 From: Thierry Martinez Date: Sun, 15 Feb 2026 10:55:28 +0100 Subject: [PATCH 2/2] Rewrite other occurrences --- graphix/flow/exceptions.py | 117 ++++++++++++++------------- graphix/noise_models/depolarising.py | 42 +++++----- graphix/pattern.py | 20 +++-- graphix/pretty_print.py | 13 ++- graphix/qasm3_exporter.py | 85 +++++++++---------- graphix/transpiler.py | 31 +++---- 6 files changed, 154 insertions(+), 154 deletions(-) diff --git a/graphix/flow/exceptions.py b/graphix/flow/exceptions.py index ddad7a4a..a3bc9e2d 100644 --- a/graphix/flow/exceptions.py +++ b/graphix/flow/exceptions.py @@ -160,29 +160,30 @@ def __str__(self) -> str: """Explain the error.""" error_help = f"Error found at c({self.node}) = {self.correction_set}." - if self.reason == FlowPropositionErrorReason.C0: - return f"Correction set c({self.node}) = {self.correction_set} has more than one element." + match self.reason: + case FlowPropositionErrorReason.C0: + return f"Correction set c({self.node}) = {self.correction_set} has more than one element." - if self.reason == FlowPropositionErrorReason.C1: - return f"{self.reason.name}: a node and its corrector must be neighbors. {error_help}" + case FlowPropositionErrorReason.C1: + return f"{self.reason.name}: a node and its corrector must be neighbors. {error_help}" - if self.reason == FlowPropositionErrorReason.G3 or self.reason == FlowPropositionErrorReason.P4: # noqa: PLR1714 - return f"{self.reason.name}: nodes measured on plane XY cannot be in their own correcting set and must belong to the odd neighbourhood of their own correcting set.\n{error_help}" + case FlowPropositionErrorReason.G3 | FlowPropositionErrorReason.P4: + return f"{self.reason.name}: nodes measured on plane XY cannot be in their own correcting set and must belong to the odd neighbourhood of their own correcting set.\n{error_help}" - if self.reason == FlowPropositionErrorReason.G4 or self.reason == FlowPropositionErrorReason.P5: # noqa: PLR1714 - return f"{self.reason.name}: nodes measured on plane XZ must belong to their own correcting set and its odd neighbourhood.\n{error_help}" + case FlowPropositionErrorReason.G4 | FlowPropositionErrorReason.P5: + return f"{self.reason.name}: nodes measured on plane XZ must belong to their own correcting set and its odd neighbourhood.\n{error_help}" - if self.reason == FlowPropositionErrorReason.G5 or self.reason == FlowPropositionErrorReason.P6: # noqa: PLR1714 - return f"{self.reason.name}: nodes measured on plane YZ must belong to their own correcting set and cannot be in the odd neighbourhood of their own correcting set.\n{error_help}" + case FlowPropositionErrorReason.G5 | FlowPropositionErrorReason.P6: + return f"{self.reason.name}: nodes measured on plane YZ must belong to their own correcting set and cannot be in the odd neighbourhood of their own correcting set.\n{error_help}" - if self.reason == FlowPropositionErrorReason.P7: - return f"{self.reason.name}: nodes measured along axis X must belong to the odd neighbourhood of their own correcting set.\n{error_help}" + case FlowPropositionErrorReason.P7: + return f"{self.reason.name}: nodes measured along axis X must belong to the odd neighbourhood of their own correcting set.\n{error_help}" - if self.reason == FlowPropositionErrorReason.P8: - return f"{self.reason.name}: nodes measured along axis Z must belong to their own correcting set.\n{error_help}" + case FlowPropositionErrorReason.P8: + return f"{self.reason.name}: nodes measured along axis Z must belong to their own correcting set.\n{error_help}" - if self.reason == FlowPropositionErrorReason.P9: - return f"{self.reason.name}: nodes measured along axis Y must belong to the closed odd neighbourhood of their own correcting set.\n{error_help}" + case FlowPropositionErrorReason.P9: + return f"{self.reason.name}: nodes measured along axis Y must belong to the closed odd neighbourhood of their own correcting set.\n{error_help}" assert_never(self.reason) @@ -200,23 +201,24 @@ def __str__(self) -> str: """Explain the error.""" error_help = f"The flow's partial order implies that {self.past_and_present_nodes - {self.node}} ≼ {self.node}. This is incompatible with the correction set c({self.node}) = {self.correction_set}." - if self.reason == FlowPropositionOrderErrorReason.C2 or self.reason == FlowPropositionOrderErrorReason.G1: # noqa: PLR1714 - return f"{self.reason.name}: nodes must be in the past of their correction set.\n{error_help}" + match self.reason: + case FlowPropositionOrderErrorReason.C2 | FlowPropositionOrderErrorReason.G1: + return f"{self.reason.name}: nodes must be in the past of their correction set.\n{error_help}" - if self.reason == FlowPropositionOrderErrorReason.C3: - return f"{self.reason.name}: neighbors of the correcting nodes (except the corrected node) must be in the future of the corrected node.\n{error_help}" + case FlowPropositionOrderErrorReason.C3: + return f"{self.reason.name}: neighbors of the correcting nodes (except the corrected node) must be in the future of the corrected node.\n{error_help}" - if self.reason == FlowPropositionOrderErrorReason.G2: - return f"{self.reason.name}: the odd neighbourhood (except the corrected node) of the correcting nodes must be in the future of the corrected node.\n{error_help}" + case FlowPropositionOrderErrorReason.G2: + return f"{self.reason.name}: the odd neighbourhood (except the corrected node) of the correcting nodes must be in the future of the corrected node.\n{error_help}" - if self.reason == FlowPropositionOrderErrorReason.P1: - return f"{self.reason.name}: nodes must be in the past of their correcting nodes unless these are measured along the X or the Y axes.\n{error_help}" + case FlowPropositionOrderErrorReason.P1: + return f"{self.reason.name}: nodes must be in the past of their correcting nodes unless these are measured along the X or the Y axes.\n{error_help}" - if self.reason == FlowPropositionOrderErrorReason.P2: - return f"{self.reason.name}: the odd neighbourhood (except the corrected node and nodes measured along axes Y or Z) of the correcting nodes must be in the future of the corrected node.\n{error_help}" + case FlowPropositionOrderErrorReason.P2: + return f"{self.reason.name}: the odd neighbourhood (except the corrected node and nodes measured along axes Y or Z) of the correcting nodes must be in the future of the corrected node.\n{error_help}" - if self.reason == FlowPropositionOrderErrorReason.P3: - return f"{self.reason.name}: nodes that are measured along axis Y and that are not in the future of the corrected node (except the corrected node itself) cannot be in the closed odd neighbourhood of the correcting set.\n{error_help}" + case FlowPropositionOrderErrorReason.P3: + return f"{self.reason.name}: nodes that are measured along axis Y and that are not in the future of the corrected node (except the corrected node itself) cannot be in the closed odd neighbourhood of the correcting set.\n{error_help}" assert_never(self.reason) @@ -229,14 +231,15 @@ class FlowGenericError(FlowError): def __str__(self) -> str: """Explain the error.""" - if self.reason == FlowGenericErrorReason.IncorrectCorrectionFunctionDomain: - return "The domain of the correction function must be the set of non-output nodes (measured qubits) of the open graph." + match self.reason: + case FlowGenericErrorReason.IncorrectCorrectionFunctionDomain: + return "The domain of the correction function must be the set of non-output nodes (measured qubits) of the open graph." - if self.reason == FlowGenericErrorReason.IncorrectCorrectionFunctionImage: - return "The image of the correction function must be a subset of non-input nodes (prepared qubits) of the open graph." + case FlowGenericErrorReason.IncorrectCorrectionFunctionImage: + return "The image of the correction function must be a subset of non-input nodes (prepared qubits) of the open graph." - if self.reason == FlowGenericErrorReason.XYPlane: - return "Causal flow is only defined on open graphs with XY measurements." + case FlowGenericErrorReason.XYPlane: + return "Causal flow is only defined on open graphs with XY measurements." assert_never(self.reason) @@ -249,11 +252,12 @@ class PartialOrderError(FlowError, XZCorrectionsError): def __str__(self) -> str: """Explain the error.""" - if self.reason == PartialOrderErrorReason.Empty: - return "The partial order cannot be empty." + match self.reason: + case PartialOrderErrorReason.Empty: + return "The partial order cannot be empty." - if self.reason == PartialOrderErrorReason.IncorrectNodes: - return "The partial order does not contain all the nodes of the open graph or contains nodes that are not in the open graph." + case PartialOrderErrorReason.IncorrectNodes: + return "The partial order does not contain all the nodes of the open graph or contains nodes that are not in the open graph." assert_never(self.reason) @@ -267,13 +271,14 @@ class PartialOrderLayerError(FlowError, XZCorrectionsError): def __str__(self) -> str: """Explain the error.""" - if self.reason == PartialOrderLayerErrorReason.FirstLayer: - return f"The first layer of the partial order must contain all the output nodes of the open graph and cannot be empty. First layer: {self.layer}" + match self.reason: + case PartialOrderLayerErrorReason.FirstLayer: + return f"The first layer of the partial order must contain all the output nodes of the open graph and cannot be empty. First layer: {self.layer}" - # Note: A flow defined on an open graph without outputs will trigger this error. This is not the case for an XZ-corrections object. + # Note: A flow defined on an open graph without outputs will trigger this error. This is not the case for an XZ-corrections object. - if self.reason == PartialOrderLayerErrorReason.NthLayer: - return f"Partial order layer {self.layer_index} = {self.layer} contains non-measured nodes of the open graph, is empty or contains nodes in previous layers." + case PartialOrderLayerErrorReason.NthLayer: + return f"Partial order layer {self.layer_index} = {self.layer} contains non-measured nodes of the open graph, is empty or contains nodes in previous layers." assert_never(self.reason) @@ -288,11 +293,12 @@ class XZCorrectionsOrderError(XZCorrectionsError): def __str__(self) -> str: """Explain the error.""" - if self.reason == XZCorrectionsOrderErrorReason.X: - return "The X-correction set {self.node} -> {self.correction_set} is incompatible with the partial order: {self.past_and_present_nodes - {self.node}} ≼ {self.node}." + match self.reason: + case XZCorrectionsOrderErrorReason.X: + return "The X-correction set {self.node} -> {self.correction_set} is incompatible with the partial order: {self.past_and_present_nodes - {self.node}} ≼ {self.node}." - if self.reason == XZCorrectionsOrderErrorReason.Z: - return "The Z-correction set {self.node} -> {self.correction_set} is incompatible with the partial order: {self.past_and_present_nodes - {self.node}} ≼ {self.node}." + case XZCorrectionsOrderErrorReason.Z: + return "The Z-correction set {self.node} -> {self.correction_set} is incompatible with the partial order: {self.past_and_present_nodes - {self.node}} ≼ {self.node}." assert_never(self.reason) @@ -305,13 +311,14 @@ class XZCorrectionsGenericError(XZCorrectionsError): def __str__(self) -> str: """Explain the error.""" - if self.reason == XZCorrectionsGenericErrorReason.IncorrectKeys: - return "Keys of correction dictionaries must be a subset of the measured nodes." - if self.reason == XZCorrectionsGenericErrorReason.IncorrectValues: - return "Values of correction dictionaries must contain labels which are nodes of the open graph." - if self.reason == XZCorrectionsGenericErrorReason.ClosedLoop: - return "XZ-corrections are not runnable since the induced directed graph contains closed loops." - if self.reason == XZCorrectionsGenericErrorReason.IncompatibleOrder: - return "The input total measurement order is not compatible with the partial order induced by the XZ-corrections." + match self.reason: + case XZCorrectionsGenericErrorReason.IncorrectKeys: + return "Keys of correction dictionaries must be a subset of the measured nodes." + case XZCorrectionsGenericErrorReason.IncorrectValues: + return "Values of correction dictionaries must contain labels which are nodes of the open graph." + case XZCorrectionsGenericErrorReason.ClosedLoop: + return "XZ-corrections are not runnable since the induced directed graph contains closed loops." + case XZCorrectionsGenericErrorReason.IncompatibleOrder: + return "The input total measurement order is not compatible with the partial order induced by the XZ-corrections." assert_never(self.reason) diff --git a/graphix/noise_models/depolarising.py b/graphix/noise_models/depolarising.py index ff72a416..4748e34a 100644 --- a/graphix/noise_models/depolarising.py +++ b/graphix/noise_models/depolarising.py @@ -108,24 +108,30 @@ def input_nodes(self, nodes: Iterable[int], rng: Generator | None = None) -> lis @typing_extensions.override def command(self, cmd: CommandOrNoise, rng: Generator | None = None) -> list[CommandOrNoise]: """Return the noise to apply to the command ``cmd``.""" - if cmd.kind == CommandKind.N: - return [cmd, ApplyNoise(noise=DepolarisingNoise(self.prepare_error_prob), nodes=[cmd.node])] - if cmd.kind == CommandKind.E: - return [ - cmd, - ApplyNoise(noise=TwoQubitDepolarisingNoise(self.entanglement_error_prob), nodes=list(cmd.nodes)), - ] - if cmd.kind == CommandKind.M: - return [ApplyNoise(noise=DepolarisingNoise(self.measure_channel_prob), nodes=[cmd.node]), cmd] - if cmd.kind == CommandKind.X: - return [cmd, ApplyNoise(noise=DepolarisingNoise(self.x_error_prob), nodes=[cmd.node], domain=cmd.domain)] - if cmd.kind == CommandKind.Z: - return [cmd, ApplyNoise(noise=DepolarisingNoise(self.z_error_prob), nodes=[cmd.node], domain=cmd.domain)] - # Use of `==` here for mypy - if cmd.kind == CommandKind.C or cmd.kind == CommandKind.T or cmd.kind == CommandKind.ApplyNoise: # noqa: PLR1714 - return [cmd] - if cmd.kind == CommandKind.S: - raise ValueError("Unexpected signal!") + match cmd.kind: + case CommandKind.N: + return [cmd, ApplyNoise(noise=DepolarisingNoise(self.prepare_error_prob), nodes=[cmd.node])] + case CommandKind.E: + return [ + cmd, + ApplyNoise(noise=TwoQubitDepolarisingNoise(self.entanglement_error_prob), nodes=list(cmd.nodes)), + ] + case CommandKind.M: + return [ApplyNoise(noise=DepolarisingNoise(self.measure_channel_prob), nodes=[cmd.node]), cmd] + case CommandKind.X: + return [ + cmd, + ApplyNoise(noise=DepolarisingNoise(self.x_error_prob), nodes=[cmd.node], domain=cmd.domain), + ] + case CommandKind.Z: + return [ + cmd, + ApplyNoise(noise=DepolarisingNoise(self.z_error_prob), nodes=[cmd.node], domain=cmd.domain), + ] + case CommandKind.C | CommandKind.T | CommandKind.ApplyNoise: + return [cmd] + case CommandKind.S: + raise ValueError("Unexpected signal!") typing_extensions.assert_never(cmd.kind) @typing_extensions.override diff --git a/graphix/pattern.py b/graphix/pattern.py index 51c8b7ed..f151a089 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -277,12 +277,12 @@ def update_command(cmd: Command) -> Command: cmd_new.nodes = (mapping_complete[i], mapping_complete[j]) elif cmd_new.kind is not CommandKind.T: cmd_new.node = mapping_complete[cmd_new.node] - if cmd_new.kind is CommandKind.M: - cmd_new.s_domain = {mapping_complete[i] for i in cmd_new.s_domain} - cmd_new.t_domain = {mapping_complete[i] for i in cmd_new.t_domain} - # Use of `==` here for mypy - elif cmd_new.kind == CommandKind.X or cmd_new.kind == CommandKind.Z or cmd_new.kind == CommandKind.S: # noqa: PLR1714 - cmd_new.domain = {mapping_complete[i] for i in cmd_new.domain} + match cmd_new.kind: + case CommandKind.M: + cmd_new.s_domain = {mapping_complete[i] for i in cmd_new.s_domain} + cmd_new.t_domain = {mapping_complete[i] for i in cmd_new.t_domain} + case CommandKind.X | CommandKind.Z | CommandKind.S: + cmd_new.domain = {mapping_complete[i] for i in cmd_new.domain} return cmd_new @@ -1225,8 +1225,12 @@ def connected_nodes(self, node: int, prepared: set[int] | None = None) -> list[i def correction_commands(self) -> list[command.X | command.Z]: """Return the list of byproduct correction commands.""" assert self.is_standard() - # Use of `==` here for mypy - return [seqi for seqi in self.__seq if seqi.kind == CommandKind.X or seqi.kind == CommandKind.Z] # noqa: PLR1714 + cmds = [] + for cmd in self: + match cmd.kind: + case CommandKind.X | CommandKind.Z: + cmds.append(cmd) + return cmds def parallelize_pattern(self) -> None: """Optimize the pattern to reduce the depth of the computation by gathering measurement commands that can be performed simultaneously. diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b3cd7734..2f53c99b 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -159,14 +159,11 @@ def command_to_str(cmd: command.Command, output: OutputFormat) -> str: arguments.append(str(cmd.angle * math.pi)) case command.CommandKind.C: arguments.append(str(cmd.clifford)) - # Use of `==` here for mypy - command_domain = ( - cmd.domain - if cmd.kind == command.CommandKind.X # noqa: PLR1714 - or cmd.kind == command.CommandKind.Z - or cmd.kind == command.CommandKind.S - else None - ) + match cmd.kind: + case command.CommandKind.X | command.CommandKind.Z | command.CommandKind.S: + command_domain: set[int] | None = cmd.domain + case _: + command_domain = None match output: case OutputFormat.LaTeX: out.append(f"_{{{cmd.node}}}") diff --git a/graphix/qasm3_exporter.py b/graphix/qasm3_exporter.py index b23303c9..bfabe2be 100644 --- a/graphix/qasm3_exporter.py +++ b/graphix/qasm3_exporter.py @@ -72,53 +72,46 @@ def angle_to_qasm3(angle: ParameterizedAngle) -> str: def instruction_to_qasm3(instruction: Instruction) -> str: """Get the OpenQASM3 representation of a single circuit instruction.""" - if instruction.kind == InstructionKind.M: - if instruction.axis != Axis.Z: - raise ValueError( - "OpenQASM3 only supports measurements on Z axis. Use `Circuit.transpile_measurements_to_z_axis` to rewrite measurements on X and Y axes." + match instruction.kind: + case InstructionKind.M: + if instruction.axis != Axis.Z: + raise ValueError( + "OpenQASM3 only supports measurements on Z axis. Use `Circuit.transpile_measurements_to_z_axis` to rewrite measurements on X and Y axes." + ) + return f"b[{instruction.target}] = measure q[{instruction.target}]" + case InstructionKind.RX | InstructionKind.RY | InstructionKind.RZ: + angle = angle_to_qasm3(instruction.angle) + return qasm3_gate_call( + instruction.kind.name.lower(), args=[angle], operands=[qasm3_qubit(instruction.target)] ) - return f"b[{instruction.target}] = measure q[{instruction.target}]" - # Use of `==` here for mypy - if ( - instruction.kind == InstructionKind.RX # noqa: PLR1714 - or instruction.kind == InstructionKind.RY - or instruction.kind == InstructionKind.RZ - ): - angle = angle_to_qasm3(instruction.angle) - return qasm3_gate_call(instruction.kind.name.lower(), args=[angle], operands=[qasm3_qubit(instruction.target)]) - - # Use of `==` here for mypy - if ( - instruction.kind == InstructionKind.H # noqa: PLR1714 - or instruction.kind == InstructionKind.S - or instruction.kind == InstructionKind.X - or instruction.kind == InstructionKind.Y - or instruction.kind == InstructionKind.Z - ): - return qasm3_gate_call(instruction.kind.name.lower(), [qasm3_qubit(instruction.target)]) - if instruction.kind == InstructionKind.I: - return qasm3_gate_call("id", [qasm3_qubit(instruction.target)]) - if instruction.kind == InstructionKind.CNOT: - return qasm3_gate_call("cx", [qasm3_qubit(instruction.control), qasm3_qubit(instruction.target)]) - if instruction.kind == InstructionKind.SWAP: - return qasm3_gate_call("swap", [qasm3_qubit(instruction.targets[i]) for i in (0, 1)]) - if instruction.kind == InstructionKind.CZ: - return qasm3_gate_call("cz", [qasm3_qubit(instruction.targets[i]) for i in (0, 1)]) - if instruction.kind == InstructionKind.RZZ: - angle = angle_to_qasm3(instruction.angle) - return qasm3_gate_call( - "crz", args=[angle], operands=[qasm3_qubit(instruction.control), qasm3_qubit(instruction.target)] - ) - if instruction.kind == InstructionKind.CCX: - return qasm3_gate_call( - "ccx", - [ - qasm3_qubit(instruction.controls[0]), - qasm3_qubit(instruction.controls[1]), - qasm3_qubit(instruction.target), - ], - ) - assert_never(instruction.kind) + + # Use of `==` here for mypy + case InstructionKind.H | InstructionKind.S | InstructionKind.X | InstructionKind.Y | InstructionKind.Z: + return qasm3_gate_call(instruction.kind.name.lower(), [qasm3_qubit(instruction.target)]) + case InstructionKind.I: + return qasm3_gate_call("id", [qasm3_qubit(instruction.target)]) + case InstructionKind.CNOT: + return qasm3_gate_call("cx", [qasm3_qubit(instruction.control), qasm3_qubit(instruction.target)]) + case InstructionKind.SWAP: + return qasm3_gate_call("swap", [qasm3_qubit(instruction.targets[i]) for i in (0, 1)]) + case InstructionKind.CZ: + return qasm3_gate_call("cz", [qasm3_qubit(instruction.targets[i]) for i in (0, 1)]) + case InstructionKind.RZZ: + angle = angle_to_qasm3(instruction.angle) + return qasm3_gate_call( + "crz", args=[angle], operands=[qasm3_qubit(instruction.control), qasm3_qubit(instruction.target)] + ) + case InstructionKind.CCX: + return qasm3_gate_call( + "ccx", + [ + qasm3_qubit(instruction.controls[0]), + qasm3_qubit(instruction.controls[1]), + qasm3_qubit(instruction.target), + ], + ) + case _: + assert_never(instruction.kind) def pattern_to_qasm3(pattern: Pattern, input_state: dict[int, State] | State = BasicStates.PLUS) -> str: diff --git a/graphix/transpiler.py b/graphix/transpiler.py index fd82c4d9..dae2f9ab 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -1000,17 +1000,12 @@ def map_angle(self, f: Callable[[ParameterizedAngle], ParameterizedAngle]) -> Ci """Apply `f` to all angles that occur in the circuit.""" result = Circuit(self.width) for instr in self.instruction: - # Use == for mypy - if ( - instr.kind == InstructionKind.RZZ # noqa: PLR1714 - or instr.kind == InstructionKind.RX - or instr.kind == InstructionKind.RY - or instr.kind == InstructionKind.RZ - ): - new_instr = dataclasses.replace(instr, angle=f(instr.angle)) - result.instruction.append(new_instr) - else: - result.instruction.append(instr) + match instr.kind: + case InstructionKind.RZZ | InstructionKind.RX | InstructionKind.RY | InstructionKind.RZ: + new_instr = dataclasses.replace(instr, angle=f(instr.angle)) + result.instruction.append(new_instr) + case _: + result.instruction.append(instr) return result def is_parameterized(self) -> bool: @@ -1024,14 +1019,12 @@ def is_parameterized(self) -> bool: """ # Use of `==` here for mypy - return any( - not isinstance(instr.angle, SupportsFloat) - for instr in self.instruction - if instr.kind == InstructionKind.RZZ # noqa: PLR1714 - or instr.kind == InstructionKind.RX - or instr.kind == InstructionKind.RY - or instr.kind == InstructionKind.RZ - ) + for instr in self.instruction: + match instr.kind: + case InstructionKind.RZZ | InstructionKind.RX | InstructionKind.RY | InstructionKind.RZ: + if not isinstance(instr.angle, SupportsFloat): + return True + return False def subs(self, variable: Parameter, substitute: ExpressionOrFloat) -> Circuit: """Return a copy of the circuit where all occurrences of the given variable in measurement angles are substituted by the given value."""