Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow sympy expressions as classical controls #4740

Merged
merged 50 commits into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
b76950e
Allow sympy expressions as classical controls
daxfohl Dec 9, 2021
5fdff50
Format
daxfohl Dec 9, 2021
ef081d7
move Condition to value
daxfohl Dec 9, 2021
0dd430e
Condition subclasses
daxfohl Dec 9, 2021
3fb7f75
Fix sympy resolver
daxfohl Dec 9, 2021
9568ae0
lint
daxfohl Dec 9, 2021
04fcff7
fix CCO serialization
daxfohl Dec 9, 2021
b3c344e
fix CCO serialization
daxfohl Dec 9, 2021
b8ff20a
add json reprs for conditions
daxfohl Dec 9, 2021
cbb029b
add test
daxfohl Dec 9, 2021
efce2f9
tests
daxfohl Dec 9, 2021
b34994b
tests
daxfohl Dec 9, 2021
5537397
test
daxfohl Dec 9, 2021
73ed74b
format
daxfohl Dec 9, 2021
a415235
docstrings
daxfohl Dec 9, 2021
3a0a56d
subop
daxfohl Dec 9, 2021
f930f6a
regex
daxfohl Dec 10, 2021
39a7a95
docs
daxfohl Dec 10, 2021
42bac3e
Make test_sympy more intuitive.
daxfohl Dec 10, 2021
f4ea9d8
Sympy str roundtrip
daxfohl Dec 14, 2021
71f61f5
Resolve some code review comments
daxfohl Dec 16, 2021
2261355
Add escape key to parse_sympy_condition
daxfohl Dec 16, 2021
6b36357
repr
daxfohl Dec 16, 2021
afbf3c9
coverage
daxfohl Dec 16, 2021
58fb2dc
coverage
daxfohl Dec 16, 2021
bd80c0b
parser
daxfohl Dec 17, 2021
c39a572
Improve sympy repr
daxfohl Dec 17, 2021
12d38ca
lint
daxfohl Dec 20, 2021
724febb
sympy.basic
daxfohl Dec 20, 2021
b598697
Add sympy json resolvers for comparators
daxfohl Dec 20, 2021
d167de7
_from_json_dict_
daxfohl Dec 20, 2021
72d82eb
lint
daxfohl Dec 20, 2021
b55188e
reduce fixed_tokens
daxfohl Dec 20, 2021
fd1fefb
Merge branch 'master' into sympy3
daxfohl Dec 20, 2021
f6a6645
Merge branch 'master' into sympymerge
daxfohl Dec 20, 2021
de3f887
Merge branch 'sympy3' of https://github.com/daxfohl/Cirq into sympy3
daxfohl Dec 20, 2021
ca56bd8
more tests
daxfohl Dec 20, 2021
6f8e344
format
daxfohl Dec 20, 2021
689719f
Key
daxfohl Dec 21, 2021
96ba4e9
combined test
daxfohl Dec 21, 2021
793c138
Merge remote-tracking branch 'origin/sympy3' into sympy3
daxfohl Dec 21, 2021
024faf3
fix diagram for multiple control keys
daxfohl Dec 21, 2021
ca29317
Merge branch 'master' into sympy3
daxfohl Dec 21, 2021
4bb1ed3
add diagrams
daxfohl Dec 22, 2021
d4195b8
Merge remote-tracking branch 'origin/sympy3' into sympy3
daxfohl Dec 22, 2021
8c04286
add a test for sympy scope simulation result
daxfohl Dec 22, 2021
18d2f62
better labels
daxfohl Dec 22, 2021
7a632ae
better labels
daxfohl Dec 22, 2021
780e43f
format
daxfohl Dec 22, 2021
0798973
Merge branch 'master' into sympy3
CirqBot Dec 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 85 additions & 32 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
Any,
Dict,
FrozenSet,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
Union,
)

import sympy

from cirq import protocols, value
from cirq.ops import raw_types

Expand All @@ -46,7 +48,7 @@ class ClassicallyControlledOperation(raw_types.Operation):
def __init__(
self,
sub_operation: 'cirq.Operation',
conditions: Sequence[Union[str, 'cirq.MeasurementKey']],
conditions: Tuple[Union[str, 'cirq.MeasurementKey', raw_types.Condition], ...],
):
"""Initializes a `ClassicallyControlledOperation`.

Expand All @@ -68,13 +70,26 @@ def __init__(
raise ValueError(
f'Cannot conditionally run operations with measurements: {sub_operation}'
)
keys = tuple(value.MeasurementKey(c) if isinstance(c, str) else c for c in conditions)
if isinstance(sub_operation, ClassicallyControlledOperation):
keys += sub_operation._control_keys
conditions += sub_operation._conditions
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
sub_operation = sub_operation._sub_operation
self._control_keys: Tuple['cirq.MeasurementKey', ...] = keys
conds: List[raw_types.Condition] = []
for c in conditions:
if isinstance(c, str):
c1 = parse_condition(c) or value.MeasurementKey.parse_serialized(c)
if c1 is None:
raise ValueError(f"'{c}' is not a valid condition")
c = c1
if isinstance(c, value.MeasurementKey):
c = raw_types.Condition(sympy.sympify('x0'), (c,))
conds.append(c)
self._conditions = tuple(conds)
self._sub_operation: 'cirq.Operation' = sub_operation

@property
def classical_controls(self) -> FrozenSet[raw_types.Condition]:
return frozenset(self._conditions).union(self._sub_operation.classical_controls)

def without_classical_controls(self) -> 'cirq.Operation':
return self._sub_operation.without_classical_controls()

Expand All @@ -84,27 +99,27 @@ def qubits(self):

def with_qubits(self, *new_qubits):
return self._sub_operation.with_qubits(*new_qubits).with_classical_controls(
*self._control_keys
*self._conditions
)

def _decompose_(self):
result = protocols.decompose_once(self._sub_operation, NotImplemented)
if result is NotImplemented:
return NotImplemented

return [ClassicallyControlledOperation(op, self._control_keys) for op in result]
return [ClassicallyControlledOperation(op, self._conditions) for op in result]

def _value_equality_values_(self):
return (frozenset(self._control_keys), self._sub_operation)
return (frozenset(self._conditions), self._sub_operation)

def __str__(self) -> str:
keys = ', '.join(map(str, self._control_keys))
keys = ', '.join(map(str, self._conditions))
return f'{self._sub_operation}.with_classical_controls({keys})'

def __repr__(self):
return (
f'cirq.ClassicallyControlledOperation('
f'{self._sub_operation!r}, {list(self._control_keys)!r})'
f'{self._sub_operation!r}, {list(self._conditions)!r})'
)

def _is_parameterized_(self) -> bool:
Expand All @@ -117,7 +132,7 @@ def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'ClassicallyControlledOperation':
new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive)
return new_sub_op.with_classical_controls(*self._control_keys)
return new_sub_op.with_classical_controls(*self._conditions)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
Expand All @@ -133,12 +148,12 @@ def _circuit_diagram_info_(
if sub_info is None:
return NotImplemented # coverage: ignore

wire_symbols = sub_info.wire_symbols + ('^',) * len(self._control_keys)
wire_symbols = sub_info.wire_symbols + ('^',) * len(self._conditions)
exponent_qubit_index = None
if sub_info.exponent_qubit_index is not None:
exponent_qubit_index = sub_info.exponent_qubit_index + len(self._control_keys)
exponent_qubit_index = sub_info.exponent_qubit_index + len(self._conditions)
elif sub_info.exponent is not None:
exponent_qubit_index = len(self._control_keys)
exponent_qubit_index = len(self._conditions)
return protocols.CircuitDiagramInfo(
wire_symbols=wire_symbols,
exponent=sub_info.exponent,
Expand All @@ -148,39 +163,77 @@ def _circuit_diagram_info_(
def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'conditions': self._control_keys,
'conditions': self._conditions,
'sub_operation': self._sub_operation,
}

def _act_on_(self, args: 'cirq.ActOnArgs') -> bool:
def not_zero(measurement):
return any(i != 0 for i in measurement)

measurements = [
args.log_of_measurement_results.get(str(key), str(key)) for key in self._control_keys
]
missing = [m for m in measurements if isinstance(m, str)]
if missing:
raise ValueError(f'Measurement keys {missing} missing when performing {self}')
if all(not_zero(measurement) for measurement in measurements):
protocols.act_on(self._sub_operation, args)
for condition in self._conditions:
keys, expr = condition.keys, condition.expr
missing = [str(k) for k in keys if str(k) not in args.log_of_measurement_results]
if missing:
raise ValueError(f'Measurement keys {missing} missing when performing {self}')
replacements = {
f'x{i}': args.log_of_measurement_results[str(k)][0] for i, k in enumerate(keys)
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
}
result = expr.subs(replacements)
if not result:
return True
protocols.act_on(self._sub_operation, args)
return True

def _with_measurement_key_mapping_(
self, key_map: Dict[str, str]
) -> 'ClassicallyControlledOperation':
keys = [protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys]
return self._sub_operation.with_classical_controls(*keys)
def map_condition(condition: raw_types.Condition) -> raw_types.Condition:
keys = [protocols.with_measurement_key_mapping(k, key_map) for k in condition.keys]
return condition.with_keys(tuple(keys))
conditions = [map_condition(c) for c in self._conditions]
return self._sub_operation.with_classical_controls(*conditions)
daxfohl marked this conversation as resolved.
Show resolved Hide resolved

def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation':
keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys]
return self._sub_operation.with_classical_controls(*keys)
def map_condition(condition: raw_types.Condition) -> raw_types.Condition:
keys = tuple(protocols.with_key_path_prefix(k, path) for k in condition.keys)
return condition.with_keys(keys)
conditions = [map_condition(c) for c in self._conditions]
return self._sub_operation.with_classical_controls(*conditions)
daxfohl marked this conversation as resolved.
Show resolved Hide resolved

def _control_keys_(self) -> FrozenSet[value.MeasurementKey]:
return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation))
local_keys = frozenset()
for condition in self._conditions:
local_keys = local_keys.union(condition.keys)
return local_keys.union(protocols.control_keys(self._sub_operation))

def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
args.validate_version('2.0')
keys = [f'm_{key}!=0' for key in self._control_keys]
keys = [f'm_{key}!=0' for key in self._conditions]
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
all_keys = " && ".join(keys)
return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args))


def parse_condition(s: str) -> Optional[raw_types.Condition]:
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
in_key = False
key_count = 0
s_out = ''
key_name = ''
keys = []
for c in s:
if not in_key:
if c == '{':
in_key = True
else:
s_out += c
else:
if c == '}':
symbol_name = f'x{key_count}'
s_out += symbol_name
keys.append(value.MeasurementKey.parse_serialized(key_name))
key_name = ''
key_count += 1
in_key = False
else:
key_name += c
expr = sympy.sympify(s_out)
if len(expr.free_symbols) != len(keys):
return None
return raw_types.Condition(expr, tuple(keys))
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def test_key_set_in_subcircuit_outer_scope():
def test_condition_flattening():
q0 = cirq.LineQubit(0)
op = cirq.X(q0).with_classical_controls('a').with_classical_controls('b')
assert set(map(str, op._control_keys)) == {'a', 'b'}
assert set(map(str, op.classical_controls)) == {'a', 'b'}
assert isinstance(op._sub_operation, cirq.GateOperation)


Expand Down
41 changes: 40 additions & 1 deletion cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Callable,
Collection,
Dict,
FrozenSet,
Hashable,
Iterable,
List,
Expand All @@ -34,6 +35,7 @@
)

import numpy as np
import sympy

from cirq import protocols, value
from cirq._import import LazyLoader
Expand Down Expand Up @@ -421,6 +423,35 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, attribute_names=[])


class Condition:
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, expr: sympy.Expr, keys: Tuple[value.MeasurementKey, ...]):
self._expr = expr
self._keys = keys

@property
def keys(self):
return self._keys

@property
def expr(self):
return self._expr

def with_keys(self, keys: Tuple[value.MeasurementKey, ...]):
assert len(keys) == len(self._keys)
return Condition(self._expr, keys)

def __eq__(self, other):
return isinstance(other, Condition) and self._keys == other._keys and self._expr == other._expr

def __hash__(self):
return hash(self._keys) ^ hash(self._expr)

def __str__(self):
if self._expr == sympy.symbols('x0') and len(self._keys) == 1:
return str(self._keys[0])
return f'({self._expr}, {self._keys})'


TSelf = TypeVar('TSelf', bound='Operation')


Expand Down Expand Up @@ -590,8 +621,12 @@ def _commutes_(

return np.allclose(m12, m21, atol=atol)

@property
def classical_controls(self) -> FrozenSet[Condition]:
return frozenset()

def with_classical_controls(
self, *conditions: Union[str, 'cirq.MeasurementKey']
self, *conditions: Union[str, 'cirq.MeasurementKey', Condition]
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
) -> 'cirq.ClassicallyControlledOperation':
"""Returns a classically controlled version of this operation.

Expand Down Expand Up @@ -821,6 +856,10 @@ def _equal_up_to_global_phase_(
) -> Union[NotImplementedType, bool]:
return protocols.equal_up_to_global_phase(self.sub_operation, other, atol=atol)

@property
def classical_controls(self) -> FrozenSet[Condition]:
return self.sub_operation.classical_controls

def without_classical_controls(self) -> 'cirq.Operation':
new_sub_operation = self.sub_operation.without_classical_controls()
return self if new_sub_operation is self.sub_operation else new_sub_operation
Expand Down