Skip to content

Commit

Permalink
Sanitize type annotations in cirq.sim (quantumlib#4773)
Browse files Browse the repository at this point in the history
* Sanitize type annotations in cirq.sim

* keys
  • Loading branch information
daxfohl authored and MichaelBroughton committed Jan 22, 2022
1 parent fed271d commit 82e6fed
Show file tree
Hide file tree
Showing 30 changed files with 196 additions and 190 deletions.
8 changes: 4 additions & 4 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import numpy as np

import cirq._version
from cirq import devices, ops, protocols, value, qis
from cirq import devices, ops, protocols, qis
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.circuit_operation import CircuitOperation
from cirq.circuits.insert_strategy import InsertStrategy
Expand Down Expand Up @@ -886,10 +886,10 @@ def qid_shape(
qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]:
def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)}

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
return self.all_measurement_key_objs()

def all_measurement_key_names(self) -> AbstractSet[str]:
Expand Down Expand Up @@ -1537,7 +1537,7 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]:
self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors
)

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op))
return controls - protocols.measurement_key_objs(self)

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class CircuitOperation(ops.Operation):
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
_cached_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = dataclasses.field(
_cached_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field(
default=None, init=False
)

Expand Down Expand Up @@ -184,7 +184,7 @@ def _qid_shape_(self) -> Tuple[int, ...]:
def _is_measurement_(self) -> bool:
return self.circuit._is_measurement_()

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
if self._cached_measurement_key_objs is None:
circuit_keys = protocols.measurement_key_objs(self.circuit)
if self.repetition_ids is not None:
Expand All @@ -207,7 +207,7 @@ def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}

def _control_keys_(self) -> AbstractSet[value.MeasurementKey]:
def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
if not protocols.control_keys(self.circuit):
return frozenset()
return protocols.control_keys(self.mapped_circuit())
Expand Down
12 changes: 6 additions & 6 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import numpy as np

from cirq import devices, ops, protocols, value
from cirq import devices, ops, protocols
from cirq.circuits import AbstractCircuit, Alignment, Circuit
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.type_workarounds import NotImplementedType
Expand Down Expand Up @@ -74,9 +74,9 @@ def __init__(
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
self._has_measurements: Optional[bool] = None
self._all_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None
self._all_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None
self._are_all_measurements_terminal: Optional[bool] = None
self._control_keys: Optional[FrozenSet[value.MeasurementKey]] = None
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None

@property
def moments(self) -> Sequence['cirq.Moment']:
Expand Down Expand Up @@ -126,15 +126,15 @@ def has_measurements(self) -> bool:
self._has_measurements = super().has_measurements()
return self._has_measurements

def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]:
def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
if self._all_measurement_key_objs is None:
self._all_measurement_key_objs = super().all_measurement_key_objs()
return self._all_measurement_key_objs

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
return self.all_measurement_key_objs()

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
if self._control_keys is None:
self._control_keys = super()._control_keys_()
return self._control_keys
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/qcircuit/qcircuit_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _render(diagram: circuits.TextDiagramDrawer) -> str:


def circuit_to_latex_using_qcircuit(
circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT
circuit: 'cirq.Circuit', qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT
) -> str:
"""Returns a QCircuit-based latex diagram of the given circuit.
Expand Down
12 changes: 6 additions & 6 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import quimb.tensor as qtn

from cirq import devices, study, ops, protocols, value
from cirq import devices, ops, protocols, value
from cirq.sim import simulator_base
from cirq.sim.act_on_args import ActOnArgs

Expand Down Expand Up @@ -126,7 +126,7 @@ def _create_step_result(

def _create_simulator_trial_result(
self,
params: study.ParamResolver,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'MPSSimulatorStepResult',
) -> 'MPSTrialResult':
Expand All @@ -151,7 +151,7 @@ class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState', 'MPSSt

def __init__(
self,
params: study.ParamResolver,
params: 'cirq.ParamResolver',
measurements: Dict[str, np.ndarray],
final_step_result: 'MPSSimulatorStepResult',
) -> None:
Expand Down Expand Up @@ -321,7 +321,7 @@ def state_vector(self) -> np.ndarray:
sorted_ind = tuple(sorted(state_vector.inds))
return state_vector.fuse({'i': sorted_ind}).data

def partial_trace(self, keep_qubits: Set[ops.Qid]) -> np.ndarray:
def partial_trace(self, keep_qubits: Set['cirq.Qid']) -> np.ndarray:
"""Traces out all qubits except keep_qubits.
Args:
Expand Down Expand Up @@ -475,7 +475,7 @@ def estimation_stats(self):
}

def perform_measurement(
self, qubits: Sequence[ops.Qid], prng: np.random.RandomState, collapse_state_vector=True
self, qubits: Sequence['cirq.Qid'], prng: np.random.RandomState, collapse_state_vector=True
) -> List[int]:
"""Performs a measurement over one or more qubits.
Expand Down Expand Up @@ -533,7 +533,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:

def sample(
self,
qubits: Sequence[ops.Qid],
qubits: Sequence['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> np.ndarray:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _with_rescoped_keys_(
path: Tuple[str, ...],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
) -> 'ClassicallyControlledOperation':
def map_key(key: value.MeasurementKey) -> value.MeasurementKey:
def map_key(key: 'cirq.MeasurementKey') -> 'cirq.MeasurementKey':
for i in range(len(path) + 1):
back_path = path[: len(path) - i]
new_key = key.with_key_path_prefix(*back_path)
Expand All @@ -195,7 +195,7 @@ def map_key(key: value.MeasurementKey) -> value.MeasurementKey:
sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys)
return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys])

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation))

def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ def _measurement_key_names_(self) -> Optional[AbstractSet[str]]:
return getter()
return NotImplemented

def _measurement_key_obj_(self) -> Optional[value.MeasurementKey]:
def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']:
getter = getattr(self.gate, '_measurement_key_obj_', None)
if getter is not None:
return getter()
return NotImplemented

def _measurement_key_objs_(self) -> Optional[AbstractSet[value.MeasurementKey]]:
def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']]:
getter = getattr(self.gate, '_measurement_key_objs_', None)
if getter is not None:
return getter()
Expand Down
13 changes: 8 additions & 5 deletions cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union
from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union
import numpy as np

from cirq import linalg, protocols, value
from cirq._compat import proper_repr
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq


# TODO(#3241): support qudits and non-square operators.
class KrausChannel(raw_types.Gate):
Expand All @@ -25,7 +28,7 @@ class KrausChannel(raw_types.Gate):
def __init__(
self,
kraus_ops: Iterable[np.ndarray],
key: Union[str, value.MeasurementKey, None] = None,
key: Union[str, 'cirq.MeasurementKey', None] = None,
validate: bool = False,
):
kraus_ops = list(kraus_ops)
Expand All @@ -52,7 +55,7 @@ def __init__(
self._key = key

@staticmethod
def from_channel(channel: 'KrausChannel', key: Union[str, value.MeasurementKey, None] = None):
def from_channel(channel: 'KrausChannel', key: Union[str, 'cirq.MeasurementKey', None] = None):
"""Creates a copy of a channel with the given measurement key."""
return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key)

Expand All @@ -76,7 +79,7 @@ def _measurement_key_name_(self) -> str:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> value.MeasurementKey:
def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
if self._key is None:
return NotImplemented
return self._key
Expand All @@ -99,7 +102,7 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
def _with_rescoped_keys_(
self,
path: Tuple[str, ...],
bindable_keys: FrozenSet[value.MeasurementKey],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
):
return KrausChannel(
kraus_ops=self._kraus_ops,
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np

from cirq import protocols, value
from cirq import protocols
from cirq.ops import raw_types, pauli_string
from cirq.ops.measurement_gate import MeasurementGate
from cirq.ops.pauli_measurement_gate import PauliMeasurementGate
Expand All @@ -31,7 +31,7 @@ def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:

def measure_single_paulistring(
pauli_observable: pauli_string.PauliString,
key: Optional[Union[str, value.MeasurementKey]] = None,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
) -> raw_types.Operation:
"""Returns a single PauliMeasurementGate which measures the pauli observable
Expand Down Expand Up @@ -83,7 +83,7 @@ def measure_paulistring_terms(

def measure(
*target: 'cirq.Qid',
key: Optional[Union[str, value.MeasurementKey]] = None,
key: Optional[Union[str, 'cirq.MeasurementKey']] = None,
invert_mask: Tuple[bool, ...] = (),
) -> raw_types.Operation:
"""Returns a single MeasurementGate applied to all the given qubits.
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MeasurementGate(raw_types.Gate):
def __init__(
self,
num_qubits: Optional[int] = None,
key: Union[str, value.MeasurementKey] = '',
key: Union[str, 'cirq.MeasurementKey'] = '',
invert_mask: Tuple[bool, ...] = (),
qid_shape: Tuple[int, ...] = None,
) -> None:
Expand Down Expand Up @@ -75,7 +75,7 @@ def key(self) -> str:
return str(self.mkey)

@key.setter
def key(self, key: Union[str, value.MeasurementKey]):
def key(self, key: Union[str, 'cirq.MeasurementKey']):
if isinstance(key, value.MeasurementKey):
self.mkey = key
else:
Expand All @@ -84,7 +84,7 @@ def key(self, key: Union[str, value.MeasurementKey]):
def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def with_key(self, key: Union[str, value.MeasurementKey]) -> 'MeasurementGate':
def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
"""Creates a measurement gate with a new key but otherwise identical."""
if key == self.key:
return self
Expand Down Expand Up @@ -139,7 +139,7 @@ def _is_measurement_(self) -> bool:
def _measurement_key_name_(self) -> str:
return self.key

def _measurement_key_obj_(self) -> value.MeasurementKey:
def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
return self.mkey

def _kraus_(self):
Expand Down
13 changes: 8 additions & 5 deletions cirq-core/cirq/ops/mixed_unitary_channel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union
from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union
import numpy as np

from cirq import linalg, protocols, value
from cirq._compat import proper_repr
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq


class MixedUnitaryChannel(raw_types.Gate):
"""A generic mixture that can record the index of its selected operator.
Expand All @@ -24,7 +27,7 @@ class MixedUnitaryChannel(raw_types.Gate):
def __init__(
self,
mixture: Iterable[Tuple[float, np.ndarray]],
key: Union[str, value.MeasurementKey, None] = None,
key: Union[str, 'cirq.MeasurementKey', None] = None,
validate: bool = False,
):
mixture = list(mixture)
Expand Down Expand Up @@ -54,7 +57,7 @@ def __init__(

@staticmethod
def from_mixture(
mixture: 'protocols.SupportsMixture', key: Union[str, value.MeasurementKey, None] = None
mixture: 'protocols.SupportsMixture', key: Union[str, 'cirq.MeasurementKey', None] = None
):
"""Creates a copy of a mixture with the given measurement key."""
return MixedUnitaryChannel(mixture=list(protocols.mixture(mixture)), key=key)
Expand Down Expand Up @@ -85,7 +88,7 @@ def _measurement_key_name_(self) -> str:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> value.MeasurementKey:
def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':
if self._key is None:
return NotImplemented
return self._key
Expand All @@ -110,7 +113,7 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
def _with_rescoped_keys_(
self,
path: Tuple[str, ...],
bindable_keys: FrozenSet[value.MeasurementKey],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
):
return MixedUnitaryChannel(
mixture=self._mixture,
Expand Down
Loading

0 comments on commit 82e6fed

Please sign in to comment.