Skip to content

Commit

Permalink
Add Transformer API Interface and @cirq.transformer decorator (quan…
Browse files Browse the repository at this point in the history
…tumlib#4797)

Defines `TRANSFORMER_TYPE` and Implements the `@transformer` decorator, as proposed in https://tinyurl.com/cirq-circuit-transformers-api

All existing transformers will be rewritten to follow the new API once this is checked-in. 
Implementation of `TransformerStatsLoggerBase` will follow in a separate PR. 

Part of quantumlib#4483

cc @maffoo PTAL at all the mypy magic.
  • Loading branch information
tanujkhattar authored Jan 21, 2022
1 parent 3b7ab6a commit 9d267da
Show file tree
Hide file tree
Showing 5 changed files with 564 additions and 0 deletions.
4 changes: 4 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,11 @@
single_qubit_matrix_to_phased_x_z,
single_qubit_matrix_to_phxz,
single_qubit_op_to_framed_phase_form,
TRANSFORMER,
TransformerContext,
TransformerLogger,
three_qubit_matrix_to_operations,
transformer,
two_qubit_matrix_to_diagonal_and_operations,
two_qubit_matrix_to_operations,
two_qubit_matrix_to_sqrt_iswap_operations,
Expand Down
4 changes: 4 additions & 0 deletions cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@
'MergeSingleQubitGates',
'PointOptimizer',
'SynchronizeTerminalMeasurements',
# Transformers
'TransformerLogger',
'TransformerContext',
# global objects
'CONTROL_TAG',
'PAULI_BASIS',
Expand Down Expand Up @@ -172,6 +175,7 @@
'Sweepable',
'TParamKey',
'TParamVal',
'TRANSFORMER',
'ParamDictType',
# utility:
'CliffordSimulator',
Expand Down
8 changes: 8 additions & 0 deletions cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@
two_qubit_gate_product_tabulation,
)

from cirq.transformers.transformer_api import (
LogLevel,
TRANSFORMER,
TransformerContext,
TransformerLogger,
transformer,
)

from cirq.transformers.transformer_primitives import (
map_moments,
map_operations,
Expand Down
313 changes: 313 additions & 0 deletions cirq/transformers/transformer_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the API for circuit transformers in Cirq."""

import textwrap
import functools
from typing import (
Any,
Callable,
Tuple,
Hashable,
List,
Type,
overload,
TYPE_CHECKING,
)
import dataclasses
import enum
from cirq.circuits.circuit import CIRCUIT_TYPE

if TYPE_CHECKING:
import cirq


class LogLevel(enum.Enum):
"""Different logging resolution options for `cirq.TransformerLogger`.
The enum values of the logging levels are used to filter the stored logs when printing.
In general, a logging level `X` includes all logs stored at a level >= 'X'.
Args:
ALL: All levels. Used to filter logs when printing.
DEBUG: Designates fine-grained informational events that are most useful to debug /
understand in-depth any unexpected behavior of the transformer.
INFO: Designates informational messages that highlight the actions of a transformer.
WARNING: Designates unwanted or potentially harmful situations.
NONE: No levels. Used to filter logs when printing.
"""

ALL = 0
DEBUG = 10
INFO = 20
WARNING = 30
NONE = 40


@dataclasses.dataclass
class _LoggerNode:
"""Stores logging data of a single transformer stage in `cirq.TransformerLogger`.
The class is used to define a logging graph to store logs of sequential or nested transformers.
Each node corresponds to logs of a single transformer stage.
Args:
transformer_id: Integer specifying a unique id for corresponding transformer stage.
transformer_name: Name of the corresponding transformer stage.
initial_circuit: Initial circuit before the transformer stage began.
final_circuit: Final circuit after the transformer stage ended.
logs: Messages logged by the transformer stage.
nested_loggers: `transformer_id`s of nested transformer stages which were called by
the current stage.
"""

transformer_id: int
transformer_name: str
initial_circuit: 'cirq.AbstractCircuit'
final_circuit: 'cirq.AbstractCircuit'
logs: List[Tuple[LogLevel, Tuple[str, ...]]] = dataclasses.field(default_factory=list)
nested_loggers: List[int] = dataclasses.field(default_factory=list)


class TransformerLogger:
"""Base Class for transformer logging infrastructure. Defaults to text-based logging.
The logger implementation should be stateful, s.t.:
- Each call to `register_initial` registers a new transformer stage and initial circuit.
- Each subsequent call to `log` should store additional logs corresponding to the stage.
- Each call to `register_final` should register the end of the currently active stage.
The logger assumes that
- Transformers are run sequentially.
- Nested transformers are allowed, in which case the behavior would be similar to a
doing a depth first search on the graph of transformers -- i.e. the top level transformer
would end (i.e. receive a `register_final` call) once all nested transformers (i.e. all
`register_initial` calls received while the top level transformer was active) have
finished (i.e. corresponding `register_final` calls have also been received).
- This behavior can be simulated by maintaining a stack of currently active stages and
adding data from `log` calls to the stage at the top of the stack.
The `LogLevel`s can be used to control the input processing and output resolution of the logs.
"""

def __init__(self):
"""Initializes TransformerLogger."""
self._curr_id: int = 0
self._logs: List[_LoggerNode] = []
self._stack: List[int] = []

def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
"""Register the beginning of a new transformer stage.
Args:
circuit: Input circuit to the new transformer stage.
transformer_name: Name of the new transformer stage.
"""
if self._stack:
self._logs[self._stack[-1]].nested_loggers.append(self._curr_id)
self._logs.append(_LoggerNode(self._curr_id, transformer_name, circuit, circuit))
self._stack.append(self._curr_id)
self._curr_id += 1

def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None:
"""Log additional metadata corresponding to the currently active transformer stage.
Args:
*args: The additional metadata to log.
level: Logging level to control the amount of metadata that gets put into the context.
Raises:
ValueError: If there's no active transformer on the stack.
"""
if len(self._stack) == 0:
raise ValueError('No active transformer found.')
self._logs[self._stack[-1]].logs.append((level, args))

def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
"""Register the end of the currently active transformer stage.
Args:
circuit: Final transformed output circuit from the transformer stage.
transformer_name: Name of the (currently active) transformer stage which ends.
Raises:
ValueError: If `transformer_name` is different from currently active transformer name.
"""
tid = self._stack.pop()
if self._logs[tid].transformer_name != transformer_name:
raise ValueError(
f"Expected `register_final` call for currently active transformer "
f"{self._logs[tid].transformer_name}."
)
self._logs[tid].final_circuit = circuit

def show(self, level: LogLevel = LogLevel.INFO) -> None:
"""Show the stored logs >= level in the desired format.
Args:
level: The logging level to filter the logs with. The method shows all logs with a
`LogLevel` >= `level`.
"""

def print_log(log: _LoggerNode, pad=''):
print(pad, f"Transformer-{1+log.transformer_id}: {log.transformer_name}", sep='')
print(pad, "Initial Circuit:", sep='')
print(textwrap.indent(str(log.initial_circuit), pad), "\n", sep='')
for log_level, log_text in log.logs:
if log_level.value >= level.value:
print(pad, log_level, *log_text)
print("\n", pad, "Final Circuit:", sep='')
print(textwrap.indent(str(log.final_circuit), pad))
print("----------------------------------------")

done = [0] * self._curr_id
for i in range(self._curr_id):
# Iterative DFS.
stack = [(i, '')] if not done[i] else []
while len(stack) > 0:
log_id, pad = stack.pop()
print_log(self._logs[log_id], pad)
done[log_id] = True
for child_id in self._logs[log_id].nested_loggers[::-1]:
stack.append((child_id, pad + ' ' * 4))


class NoOpTransformerLogger(TransformerLogger):
"""All calls to this logger are a no-op"""

def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
pass

def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None:
pass

def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
pass

def show(self, level: LogLevel = LogLevel.INFO) -> None:
pass


@dataclasses.dataclass()
class TransformerContext:
"""Stores common configurable options for transformers.
Args:
logger: `cirq.TransformerLogger` instance, which is a stateful logger used for logging
the actions of individual transformer stages. The same logger instance should be
shared across different transformer calls.
ignore_tags: Tuple of tags which should be ignored while applying transformations on a
circuit. Transformers should not transform any operation marked with a tag that
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
`cirq.VirtualTag` etc.) is a valid tag.
"""

logger: TransformerLogger = NoOpTransformerLogger()
ignore_tags: Tuple[Hashable, ...] = ()


TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit']
_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE]


def _transform_and_log(
func: _TRANSFORMER_TYPE[CIRCUIT_TYPE],
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
"""Helper to log initial and final circuits before and after calling the transformer."""

context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, context)
context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit


def _transformer_class(
cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
old_func = cls.__call__

def transformer_with_logging_cls(
self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE:
return old_func(self, c, ct)

return _transform_and_log(call_old_func, cls.__name__, circuit, context)

setattr(cls, '__call__', transformer_with_logging_cls)
return cls


def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
@functools.wraps(func)
def transformer_with_logging_func(
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
return _transform_and_log(func, func.__name__, circuit, context)

return transformer_with_logging_func


@overload
def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
pass


@overload
def transformer(
cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
pass


def transformer(cls_or_func: Any) -> Any:
"""Decorator to verify API and append logging functionality to transformer functions & classes.
The decorated function or class must satisfy
`Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
>>> @cirq.transformer
>>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
>>> ...
The decorated class must implement the `__call__` method to satisfy the above API.
>>> @cirq.transformer
>>> class ConvertToSqrtISwaps:
>>> def __init__(self):
>>> ...
>>> def __call__(
>>> self, circuit: cirq.Circuit, context: cirq.TransformerContext
>>> ) -> cirq.Circuit:
>>> ...
Args:
cls_or_func: The callable class or method to be decorated.
Returns:
Decorated class / method which includes additional logging boilerplate. The decorated
callable always receives a copy of the input circuit so that the input is never mutated.
"""
if isinstance(cls_or_func, type):
return _transformer_class(cls_or_func)
else:
assert callable(cls_or_func)
return _transformer_func(cls_or_func)
Loading

0 comments on commit 9d267da

Please sign in to comment.