diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 39e0891cc77..5993cbad395 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -13,6 +13,7 @@ # limitations under the License. """Workarounds for compatibility issues between versions and libraries.""" +import contextlib import dataclasses import functools import importlib @@ -649,3 +650,23 @@ def __getattr__(self, name): return getattr(parent_module, name) sys.modules[old_parent] = Wrapped(parent_module.__name__, parent_module.__doc__) + + +@contextlib.contextmanager +def block_overlapping_deprecation(match_regex: str): + """Context to block deprecation warnings raised within it. + + Useful if a function call might raise more than one warning, + where only one warning is desired. + + Args: + match_regex: DeprecationWarnings with message fields matching + match_regex will be blocked. + """ + with warnings.catch_warnings(): + warnings.filterwarnings( + action='ignore', + category=DeprecationWarning, + message=f'(.|\n)*{match_regex}(.|\n)*', + ) + yield diff --git a/cirq-core/cirq/_compat_test.py b/cirq-core/cirq/_compat_test.py index c2314c3e9f5..cbe509bd71c 100644 --- a/cirq-core/cirq/_compat_test.py +++ b/cirq-core/cirq/_compat_test.py @@ -34,6 +34,7 @@ import cirq.testing from cirq._compat import ( + block_overlapping_deprecation, proper_repr, dataclass_repr, deprecated, @@ -887,3 +888,17 @@ def _dir_is_still_valid_inner(): for m in ['fake_a', 'info', 'module_a', 'sys']: assert m in dir(mod) + + +def test_block_overlapping_deprecation(): + @deprecated(fix="Don't use g.", deadline="v1000.0") + def g(y): + return y - 4 + + @deprecated(fix="Don't use f.", deadline="v1000.0") + def f(x): + with block_overlapping_deprecation('g'): + return [g(i + 1) for i in range(x)] + + with cirq.testing.assert_deprecated('f', deadline='v1000.0', count=1): + f(5) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index ba37f9d24c4..16167cb4ec2 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -20,13 +20,11 @@ """ import abc -import contextlib import enum import html import itertools import math import re -import warnings from collections import defaultdict from typing import ( AbstractSet, @@ -73,17 +71,6 @@ _DEVICE_DEP_MESSAGE = 'Attaching devices to circuits will no longer be supported.' -@contextlib.contextmanager -def _block_overlapping_deprecation(): - with warnings.catch_warnings(): - warnings.filterwarnings( - action='ignore', - category=DeprecationWarning, - message=f'(.|\n)*{re.escape(_DEVICE_DEP_MESSAGE)}(.|\n)*', - ) - yield - - class Alignment(enum.Enum): # Stop when left ends are lined up. LEFT = 1 @@ -1890,7 +1877,7 @@ def with_device( Returns: The translated circuit. """ - with _block_overlapping_deprecation(): + with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)): return Circuit( [ ops.Moment( @@ -1958,7 +1945,7 @@ def transform_qubits( if new_device is None and self._device == devices.UNCONSTRAINED_DEVICE: return Circuit(op_list) - with _block_overlapping_deprecation(): + with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)): return Circuit(op_list, device=self._device if new_device is None else new_device) def _prev_moment_available(self, op: 'cirq.Operation', end_moment_index: int) -> Optional[int]: @@ -2377,7 +2364,7 @@ def _resolve_parameters_( resolved_moments.append(new_moment) if self._device == devices.UNCONSTRAINED_DEVICE: return Circuit(resolved_moments) - with _block_overlapping_deprecation(): + with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)): return Circuit(resolved_moments, device=self._device) @property diff --git a/cirq-core/cirq/circuits/circuit_dag.py b/cirq-core/cirq/circuits/circuit_dag.py index e5c10904e09..4854d532539 100644 --- a/cirq-core/cirq/circuits/circuit_dag.py +++ b/cirq-core/cirq/circuits/circuit_dag.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import re from typing import Any, Callable, Dict, Generic, Iterator, TypeVar, cast, TYPE_CHECKING import functools @@ -138,7 +138,7 @@ def from_ops( if device == devices.UNCONSTRAINED_DEVICE: dag = CircuitDag(can_reorder=can_reorder) else: - with circuit._block_overlapping_deprecation(): + with _compat.block_overlapping_deprecation(re.escape(circuit._DEVICE_DEP_MESSAGE)): dag = CircuitDag(can_reorder=can_reorder, device=device) for op in ops.flatten_op_tree(operations): diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index be1d6da52d3..91a46c5fb6e 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -24,8 +24,6 @@ Tuple, Union, ) -import contextlib -import warnings import re from cirq.circuits import AbstractCircuit, Alignment, Circuit @@ -44,17 +42,6 @@ _DEVICE_DEP_MESSAGE = 'Attaching devices to circuits will no longer be supported.' -@contextlib.contextmanager -def _block_overlapping_deprecation(): - with warnings.catch_warnings(): - warnings.filterwarnings( - action='ignore', - category=DeprecationWarning, - message=f'(.|\n)*{re.escape(_DEVICE_DEP_MESSAGE)}(.|\n)*', - ) - yield - - class FrozenCircuit(AbstractCircuit, protocols.SerializableByKey): """An immutable version of the Circuit data structure. @@ -91,7 +78,7 @@ def __init__( if device == devices.UNCONSTRAINED_DEVICE: base = Circuit(contents, strategy=strategy) else: - with _block_overlapping_deprecation(): + with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)): base = Circuit(contents, strategy=strategy, device=device) self._moments = tuple(base.moments) @@ -225,7 +212,7 @@ def with_device( new_device: 'cirq.Device', qubit_mapping: Callable[['cirq.Qid'], 'cirq.Qid'] = lambda e: e, ) -> 'FrozenCircuit': - with _block_overlapping_deprecation(): + with _compat.block_overlapping_deprecation(re.escape(_DEVICE_DEP_MESSAGE)): return self.unfreeze().with_device(new_device, qubit_mapping).freeze() def _resolve_parameters_(