Skip to content

Commit

Permalink
Use a Protocol for TRANSFORMER to ensure common arg names (#4871)
Browse files Browse the repository at this point in the history
* Use a Protocol for TRANSFORMER to ensure common arg names

Also cleans up some of the internals of the transformer decorator and
simplifies the types.

Follow-up to #4797

* Fix Protocol import for 3.7

* Fixes from review

* Add type annotations in transformer implementation

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
  • Loading branch information
maffoo and tanujkhattar authored Jan 24, 2022
1 parent 90679d2 commit 693ceed
Showing 1 changed file with 63 additions and 64 deletions.
127 changes: 63 additions & 64 deletions cirq-core/cirq/transformers/transformer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@

"""Defines the API for circuit transformers in Cirq."""

import textwrap
import dataclasses
import enum
import functools
import textwrap
from typing import (
Any,
Callable,
Tuple,
Hashable,
List,
Type,
overload,
Type,
TYPE_CHECKING,
TypeVar,
)
import dataclasses
import enum
from cirq.circuits.circuit import CIRCUIT_TYPE
from typing_extensions import Protocol

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -218,96 +218,95 @@ class TransformerContext:
ignore_tags: Tuple[Hashable, ...] = ()


TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit']
_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE]


def _transform_and_log(
func: _TRANSFORMER_TYPE[CIRCUIT_TYPE],
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
"""Helper to log initial and final circuits before and after calling the transformer."""

context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, context)
context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit

class TRANSFORMER(Protocol):
def __call__(
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
) -> 'cirq.AbstractCircuit':
...

def _transformer_class(
cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
old_func = cls.__call__

def transformer_with_logging_cls(
self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE:
return old_func(self, c, ct)

return _transform_and_log(call_old_func, cls.__name__, circuit, context)

setattr(cls, '__call__', transformer_with_logging_cls)
return cls


def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
@functools.wraps(func)
def transformer_with_logging_func(
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
return _transform_and_log(func, func.__name__, circuit, context)

return transformer_with_logging_func
_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])


@overload
def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
pass


@overload
def transformer(
cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
pass


def transformer(cls_or_func: Any) -> Any:
"""Decorator to verify API and append logging functionality to transformer functions & classes.
The decorated function or class must satisfy
`Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
`cirq.TransformerContext`, and returns another `cirq.AbstractCircuit` without
modifying the input circuit. A transformer could be a function, for example:
>>> @cirq.transformer
>>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
>>> def convert_to_cz(
>>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
>>> ) -> cirq.Circuit:
>>> ...
The decorated class must implement the `__call__` method to satisfy the above API.
Or it could be a class that implements `__call__` with the same API, for example:
>>> @cirq.transformer
>>> class ConvertToSqrtISwaps:
>>> def __init__(self):
>>> ...
>>> def __call__(
>>> self, circuit: cirq.Circuit, context: cirq.TransformerContext
>>> self, circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
>>> ) -> cirq.Circuit:
>>> ...
Args:
cls_or_func: The callable class or method to be decorated.
cls_or_func: The callable class or function to be decorated.
Returns:
Decorated class / method which includes additional logging boilerplate. The decorated
callable always receives a copy of the input circuit so that the input is never mutated.
Decorated class / function which includes additional logging boilerplate.
"""
if isinstance(cls_or_func, type):
return _transformer_class(cls_or_func)
cls = cls_or_func
method = cls.__call__

@functools.wraps(method)
def method_with_logging(
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
) -> 'cirq.AbstractCircuit':
return _transform_and_log(
lambda circuit, context: method(self, circuit, context),
cls.__name__,
circuit,
context,
)

setattr(cls, '__call__', method_with_logging)
return cls
else:
assert callable(cls_or_func)
return _transformer_func(cls_or_func)
func = cls_or_func

@functools.wraps(func)
def func_with_logging(
circuit: 'cirq.AbstractCircuit', context: TransformerContext
) -> 'cirq.AbstractCircuit':
return _transform_and_log(func, func.__name__, circuit, context)

return func_with_logging


def _transform_and_log(
func: TRANSFORMER,
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> 'cirq.AbstractCircuit':
"""Helper to log initial and final circuits before and after calling the transformer."""
context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, context)
context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit

0 comments on commit 693ceed

Please sign in to comment.