Skip to content

Commit

Permalink
Add parameter resolution and other utility methods to _InverseComposi…
Browse files Browse the repository at this point in the history
…teGate (#4656)
  • Loading branch information
tanujkhattar authored Nov 10, 2021
1 parent c5477cb commit 5fdf764
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
13 changes: 13 additions & 0 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,19 @@ def _has_unitary_(self):
for op in protocols.decompose_once_with_qubits(self._original, qubits)
)

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self._original)

def _parameter_names_(self) -> AbstractSet[str]:
return protocols.parameter_names(self._original)

def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> '_InverseCompositeGate':
return _InverseCompositeGate(
protocols.resolve_parameters(self._original, resolver, recursive)
)

def _value_equality_values_(self):
return self._original

Expand Down
27 changes: 21 additions & 6 deletions cirq-core/cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,9 @@ def test_tagged_operation_resolves_parameterized_tags(resolve_fn):
def test_inverse_composite_standards():
@cirq.value_equality
class Gate(cirq.Gate):
def __init__(self, param: 'cirq.TParamVal'):
self._param = param

def _decompose_(self, qubits):
return cirq.S.on(qubits[0])

Expand All @@ -708,14 +711,26 @@ def _has_unitary_(self):
return True

def _value_equality_values_(self):
return ()
return (self._param,)

def __repr__(self):
return 'C()'
def _parameter_names_(self) -> AbstractSet[str]:
return cirq.parameter_names(self._param)

cirq.testing.assert_implements_consistent_protocols(
cirq.inverse(Gate()), global_vals={'C': Gate}
)
def _is_parameterized_(self) -> bool:
return cirq.is_parameterized(self._param)

def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Gate':
return Gate(cirq.resolve_parameters(self._param, resolver, recursive))

def __repr__(self):
return f'C({self._param})'

a = sympy.Symbol("a")
g = cirq.inverse(Gate(a))
assert cirq.is_parameterized(g)
assert cirq.parameter_names(g) == {'a'}
assert cirq.resolve_parameters(g, {a: 0}) == Gate(0) ** -1
cirq.testing.assert_implements_consistent_protocols(g, global_vals={'C': Gate, 'a': a})


def test_tagged_act_on():
Expand Down

0 comments on commit 5fdf764

Please sign in to comment.