From 5fdf76431c026b9adbe7d23c9968de1d00061a18 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 10 Nov 2021 15:48:43 -0800 Subject: [PATCH] Add parameter resolution and other utility methods to _InverseCompositeGate (#4656) --- cirq-core/cirq/ops/raw_types.py | 13 +++++++++++++ cirq-core/cirq/ops/raw_types_test.py | 27 +++++++++++++++++++++------ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index bc35187b9cb..452d38974e3 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -819,6 +819,19 @@ def _has_unitary_(self): for op in protocols.decompose_once_with_qubits(self._original, qubits) ) + def _is_parameterized_(self) -> bool: + return protocols.is_parameterized(self._original) + + def _parameter_names_(self) -> AbstractSet[str]: + return protocols.parameter_names(self._original) + + def _resolve_parameters_( + self, resolver: 'cirq.ParamResolver', recursive: bool + ) -> '_InverseCompositeGate': + return _InverseCompositeGate( + protocols.resolve_parameters(self._original, resolver, recursive) + ) + def _value_equality_values_(self): return self._original diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index 6be9e4d5a16..afe77fdc639 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -698,6 +698,9 @@ def test_tagged_operation_resolves_parameterized_tags(resolve_fn): def test_inverse_composite_standards(): @cirq.value_equality class Gate(cirq.Gate): + def __init__(self, param: 'cirq.TParamVal'): + self._param = param + def _decompose_(self, qubits): return cirq.S.on(qubits[0]) @@ -708,14 +711,26 @@ def _has_unitary_(self): return True def _value_equality_values_(self): - return () + return (self._param,) - def __repr__(self): - return 'C()' + def _parameter_names_(self) -> AbstractSet[str]: + return cirq.parameter_names(self._param) - cirq.testing.assert_implements_consistent_protocols( - cirq.inverse(Gate()), global_vals={'C': Gate} - ) + def _is_parameterized_(self) -> bool: + return cirq.is_parameterized(self._param) + + def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Gate': + return Gate(cirq.resolve_parameters(self._param, resolver, recursive)) + + def __repr__(self): + return f'C({self._param})' + + a = sympy.Symbol("a") + g = cirq.inverse(Gate(a)) + assert cirq.is_parameterized(g) + assert cirq.parameter_names(g) == {'a'} + assert cirq.resolve_parameters(g, {a: 0}) == Gate(0) ** -1 + cirq.testing.assert_implements_consistent_protocols(g, global_vals={'C': Gate, 'a': a}) def test_tagged_act_on():