From c6bb1140c38f970a5b4e1a647a6e3d89cb54ae74 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Fri, 22 Mar 2024 10:50:31 -0500 Subject: [PATCH] style: eliminate all type casts --- pytket/phir/sharding/sharder.py | 11 +++-------- tests/test_sharder.py | 10 ++++------ 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/pytket/phir/sharding/sharder.py b/pytket/phir/sharding/sharder.py index 3bb99cb..09d8762 100644 --- a/pytket/phir/sharding/sharder.py +++ b/pytket/phir/sharding/sharder.py @@ -7,7 +7,6 @@ ############################################################################## import logging -from typing import cast from pytket.circuit import Circuit, Command, Conditional, Op, OpType from pytket.unit_id import Bit, Qubit, UnitID @@ -273,11 +272,8 @@ def should_op_create_shard(op: Op) -> bool: `True` if the operation is one that should result in shard creation """ return ( - op.type in (SHARD_TRIGGER_OP_TYPES) - or ( - op.type == OpType.Conditional - and cast(Conditional, op).op.type in (SHARD_TRIGGER_OP_TYPES) - ) + op.type in SHARD_TRIGGER_OP_TYPES + or (isinstance(op, Conditional) and op.op.type in SHARD_TRIGGER_OP_TYPES) or (op.is_gate() and op.n_qubits > 1) ) @@ -289,6 +285,5 @@ def _is_command_global_phase(command: Command) -> bool: command: Command to evaluate """ return command.op.type == OpType.Phase or ( - command.op.type == OpType.Conditional - and cast(Conditional, command.op).op.type == OpType.Phase + isinstance(command.op, Conditional) and command.op.op.type == OpType.Phase ) diff --git a/tests/test_sharder.py b/tests/test_sharder.py index c10f493..4b1a3ff 100644 --- a/tests/test_sharder.py +++ b/tests/test_sharder.py @@ -6,8 +6,6 @@ # ############################################################################## -from typing import cast - from pytket.circuit import Conditional, Op, OpType from pytket.phir.sharding.sharder import Sharder @@ -126,8 +124,8 @@ def test_simple_conditional(self) -> None: assert not shards[1].bits_read # shard 2: if (c==1) z=1; - assert shards[2].primary_command.op.type == OpType.Conditional - assert cast(Conditional, shards[2].primary_command.op).op.type == OpType.SetBits + assert isinstance(shards[2].primary_command.op, Conditional) + assert shards[2].primary_command.op.op.type == OpType.SetBits assert not shards[2].sub_commands assert not shards[2].qubits_used assert shards[2].bits_written == {circuit.bits[1]} @@ -143,8 +141,8 @@ def test_simple_conditional(self) -> None: assert len(shards[3].sub_commands.items()) == 1 s2_qubit, s2_sub_cmds = next(iter(shards[3].sub_commands.items())) assert s2_qubit == circuit.qubits[0] - assert s2_sub_cmds[0].op.type == OpType.Conditional - assert cast(Conditional, s2_sub_cmds[0].op).op.type == OpType.H + assert isinstance(s2_sub_cmds[0].op, Conditional) + assert s2_sub_cmds[0].op.op.type == OpType.H assert s2_sub_cmds[0].qubits == [circuit.qubits[0]] def test_complex_barriers(self) -> None: # noqa: PLR0915