Skip to content

Commit

Permalink
Document CIRCUIT_TYPE and hide other typevars/aliases in circuits.py (q…
Browse files Browse the repository at this point in the history
…uantumlib#5229)

Fixes quantumlib#5150 (assuming this renders nicely on the docsite; how can I check that locally?)

This adds an underscore prefix to hide some type aliases and type vars that are not part of the public interface of the module. Also adds a docstring to the `CIRCUIT_TYPE` variable, which is used in a few other places.
  • Loading branch information
maffoo authored and rht committed May 1, 2023
1 parent d7f7e42 commit b0e595c
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

import cirq._version
from cirq import _compat, devices, ops, protocols, qis
from cirq._doc import document
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.circuit_operation import CircuitOperation
from cirq.circuits.insert_strategy import InsertStrategy
Expand All @@ -65,9 +66,30 @@
if TYPE_CHECKING:
import cirq

T_DESIRED_GATE_TYPE = TypeVar('T_DESIRED_GATE_TYPE', bound='ops.Gate')

_TGate = TypeVar('_TGate', bound='cirq.Gate')

CIRCUIT_TYPE = TypeVar('CIRCUIT_TYPE', bound='AbstractCircuit')
INT_TYPE = Union[int, np.integer]
document(
CIRCUIT_TYPE,
"""Type variable for an AbstractCircuit.
This can be used when writing generic functions that operate on circuits.
For example, suppose we define the following function:
def foo(circuit: CIRCUIT_TYPE) -> CIRCUIT_TYPE:
...
This lets the typechecker know that this function takes any kind of circuit
and returns the same type, that is, if passed a `cirq.Circuit` it will return
`cirq.Circuit`, and similarly if passed `cirq.FrozenCircuit` it will return
`cirq.FrozenCircuit`. This is particularly useful for things like the
transformer API, since it can preserve more type information than if we typed
the function as taking and returning `cirq.AbstractCircuit`.
""",
)

_INT_TYPE = Union[int, np.integer]

_DEVICE_DEP_MESSAGE = 'Attaching devices to circuits will no longer be supported.'

Expand Down Expand Up @@ -752,8 +774,8 @@ def findall_operations(
yield index, op

def findall_operations_with_gate_type(
self, gate_type: Type[T_DESIRED_GATE_TYPE]
) -> Iterable[Tuple[int, 'cirq.GateOperation', T_DESIRED_GATE_TYPE]]:
self, gate_type: Type[_TGate]
) -> Iterable[Tuple[int, 'cirq.GateOperation', _TGate]]:
"""Find the locations of all gate operations of a given type.
Args:
Expand All @@ -767,7 +789,7 @@ def findall_operations_with_gate_type(
result = self.findall_operations(lambda operation: isinstance(operation.gate, gate_type))
for index, op in result:
gate_op = cast(ops.GateOperation, op)
yield index, gate_op, cast(T_DESIRED_GATE_TYPE, gate_op.gate)
yield index, gate_op, cast(_TGate, gate_op.gate)

def has_measurements(self):
return protocols.is_measurement(self)
Expand Down Expand Up @@ -1818,20 +1840,20 @@ def __radd__(self, other):
# Needed for numpy to handle multiplication by np.int64 correctly.
__array_priority__ = 10000

def __imul__(self, repetitions: INT_TYPE):
def __imul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
self._moments *= int(repetitions)
return self

def __mul__(self, repetitions: INT_TYPE):
def __mul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
if self._device == cirq.UNCONSTRAINED_DEVICE:
return Circuit(self._moments * int(repetitions))
return Circuit(self._moments * int(repetitions), device=self._device)

def __rmul__(self, repetitions: INT_TYPE):
def __rmul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
return self * int(repetitions)
Expand Down Expand Up @@ -2750,27 +2772,26 @@ def _list_repr_with_indented_item_lines(items: Sequence[Any]) -> str:
return f'[\n{indented}\n]'


TIn = TypeVar('TIn')
TOut = TypeVar('TOut')
TKey = TypeVar('TKey')
_TIn = TypeVar('_TIn')
_TOut = TypeVar('_TOut')
_TKey = TypeVar('_TKey')


@overload
def _group_until_different(
items: Iterable[TIn],
key: Callable[[TIn], TKey],
) -> Iterable[Tuple[TKey, List[TIn]]]:
items: Iterable[_TIn], key: Callable[[_TIn], _TKey]
) -> Iterable[Tuple[_TKey, List[_TIn]]]:
pass


@overload
def _group_until_different(
items: Iterable[TIn], key: Callable[[TIn], TKey], val: Callable[[TIn], TOut]
) -> Iterable[Tuple[TKey, List[TOut]]]:
items: Iterable[_TIn], key: Callable[[_TIn], _TKey], val: Callable[[_TIn], _TOut]
) -> Iterable[Tuple[_TKey, List[_TOut]]]:
pass


def _group_until_different(items: Iterable[TIn], key: Callable[[TIn], TKey], val=lambda e: e):
def _group_until_different(items: Iterable[_TIn], key: Callable[[_TIn], _TKey], val=lambda e: e):
"""Groups runs of items that are identical according to a keying function.
Args:
Expand Down

0 comments on commit b0e595c

Please sign in to comment.