Skip to content

Commit

Permalink
Ensure subcircuit.mapped_circuit applies parent path if no repetitions (
Browse files Browse the repository at this point in the history
#4619)

Fixes #4618

Adds protocol `with_key_path_prefix` and uses this to correctly implement `mapped_circuit`, as outlined in the bug:

> Currently the logic in mapped_circuit applies both the parent path and repetition key in a single step, and only if there's repetitions. Thus it misses the parent path if there's no repetitions. What we need is for mapped_circuit to apply the repetition prefix first (if there's repetitions), and then later apply the parent_path prefix independently, regardless of whether there are repetitions.
> 
> Pretty straightforward, but the complicating issue is that with_key_path replaces the path rather than prefixing the existing one, so that in the above solution, the second step would overwrite the first. So likely the fix for this will require creation and use of a new with_key_path_prefix protocol. This should also be fairly straightforward, but will require this protocol implementation in a number of places. (Open question of whether replacing the path is ever required. All use cases that come to mind are exclusively prefixing a path).

Note this takes the simplification provided in #4611 as a baseline. We can either commit that first and then rebase this on top, or just kill that PR and replace it outright with this one.
  • Loading branch information
daxfohl authored Nov 10, 2021
1 parent 2f123fb commit 83c440f
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 31 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@
unitary,
validate_mixture,
with_key_path,
with_key_path_prefix,
with_measurement_key_mapping,
)

Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,11 @@ def _with_key_path_(self, path: Tuple[str, ...]):
[protocols.with_key_path(moment, path) for moment in self.moments]
)

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return self._with_sliced_moments(
[protocols.with_key_path_prefix(moment, prefix) for moment in self.moments]
)

def _qid_shape_(self) -> Tuple[int, ...]:
return self.qid_shape()

Expand Down
45 changes: 15 additions & 30 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,37 +227,19 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit':
circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map)
circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False)
if deep:

def map_deep(op: 'cirq.Operation') -> 'cirq.OP_TREE':
return op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op

if self.repetition_ids is None:
return circuit.map_operations(map_deep)
if not has_measurements:
return circuit.map_operations(map_deep) * abs(self.repetitions)

# Path must be constructed from the top down.
rekeyed_circuit = circuits.Circuit(
protocols.with_key_path(circuit, self.parent_path + (rep,))
for rep in self.repetition_ids
circuit = circuit.map_operations(
lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op
)
return rekeyed_circuit.map_operations(map_deep)

if self.repetition_ids is None:
return circuit
if not has_measurements:
return circuit * abs(self.repetitions)

def rekey_op(op: 'cirq.Operation', rep: str):
"""Update measurement keys in `op` to include repetition ID `rep`."""
rekeyed_op = protocols.with_key_path(op, self.parent_path + (rep,))
if rekeyed_op is NotImplemented:
return op
return rekeyed_op

return circuits.Circuit(
circuit.map_operations(lambda op: rekey_op(op, rep)) for rep in self.repetition_ids
)
if self.repetition_ids:
if not has_measurements:
circuit = circuit * abs(self.repetitions)
else:
circuit = circuits.Circuit(
protocols.with_key_path_prefix(circuit, (rep,)) for rep in self.repetition_ids
)
if self.parent_path:
circuit = protocols.with_key_path_prefix(circuit, self.parent_path)
return circuit

def mapped_op(self, deep: bool = False) -> 'cirq.CircuitOperation':
"""As `mapped_circuit`, but wraps the result in a CircuitOperation."""
Expand Down Expand Up @@ -444,6 +426,9 @@ def __pow__(self, power: int) -> 'CircuitOperation':
def _with_key_path_(self, path: Tuple[str, ...]):
return dataclasses.replace(self, parent_path=path)

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return dataclasses.replace(self, parent_path=prefix + self.parent_path)

def with_key_path(self, path: Tuple[str, ...]):
return self._with_key_path_(path)

Expand Down
14 changes: 14 additions & 0 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,4 +802,18 @@ def test_tag_propagation():
assert test_tag not in op.tags


def test_keys_under_parent_path():
q = cirq.LineQubit(0)
op1 = cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.measure(q, key='A'),
cirq.measure_single_paulistring(cirq.X(q), key='B'),
cirq.MixedUnitaryChannel.from_mixture(cirq.bit_flip(0.5), key='C').on(q),
cirq.KrausChannel.from_channel(cirq.phase_damp(0.5), key='D').on(q),
)
)
op2 = op1.with_key_path(('X',))
assert cirq.measurement_key_names(op2.mapped_circuit()) == {'X:A', 'X:B', 'X:C', 'X:D'}


# TODO: Operation has a "gate" property. What is this for a CircuitOperation?
9 changes: 9 additions & 0 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ def _with_key_path_(self, path: Tuple[str, ...]):
return self
return new_gate.on(*self.qubits)

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
new_gate = protocols.with_key_path_prefix(self.gate, prefix)
if new_gate is NotImplemented:
return NotImplemented
if new_gate is self.gate:
# As GateOperation is immutable, this can return the original.
return self
return new_gate.on(*self.qubits)

def __repr__(self):
if hasattr(self.gate, '_op_repr_'):
result = self.gate._op_repr_(self.qubits)
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ def test_with_key_path():
assert cirq.with_key_path(cirq.X(a), ('a', 'b')) is NotImplemented


def test_with_key_path_prefix():
a = cirq.LineQubit(0)
op = cirq.measure(a, key='m')
remap_op = cirq.with_key_path_prefix(op, ('a', 'b'))
assert cirq.measurement_key_names(remap_op) == {'a:b:m'}
assert cirq.with_key_path_prefix(remap_op, tuple()) is remap_op
assert cirq.with_key_path_prefix(op, tuple()) is op
assert cirq.with_key_path_prefix(cirq.X(a), ('a', 'b')) is NotImplemented


def test_cannot_remap_non_measurement_gate():
a = cirq.LineQubit(0)
op = cirq.X(a)
Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
def _with_key_path_(self, path: Tuple[str, ...]):
return KrausChannel(kraus_ops=self._kraus_ops, key=protocols.with_key_path(self._key, path))

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return KrausChannel(
kraus_ops=self._kraus_ops, key=protocols.with_key_path_prefix(self._key, prefix)
)

def __str__(self):
if self._key is not None:
return f'KrausChannel({self._kraus_ops}, key={self._key})'
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def with_key(self, key: Union[str, value.MeasurementKey]) -> 'MeasurementGate':
def _with_key_path_(self, path: Tuple[str, ...]):
return self.with_key(self.mkey._with_key_path_(path))

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return self.with_key(self.mkey._with_key_path_prefix_(prefix))

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))

Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/ops/mixed_unitary_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def _with_key_path_(self, path: Tuple[str, ...]):
mixture=self._mixture, key=protocols.with_key_path(self._key, path)
)

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return MixedUnitaryChannel(
mixture=self._mixture, key=protocols.with_key_path_prefix(self._key, prefix)
)

def __str__(self):
if self._key is not None:
return f'MixedUnitaryChannel({self._mixture}, key={self._key})'
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/ops/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ def _with_key_path_(self, path: Tuple[str, ...]):
for op in self.operations
)

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return Moment(
protocols.with_key_path_prefix(op, prefix) if protocols.is_measurement(op) else op
for op in self.operations
)

def __copy__(self):
return type(self)(self.operations)

Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/pauli_measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def with_key(self, key: Union[str, value.MeasurementKey]) -> 'PauliMeasurementGa
def _with_key_path_(self, path: Tuple[str, ...]) -> 'PauliMeasurementGate':
return self.with_key(self.mkey._with_key_path_(path))

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]) -> 'PauliMeasurementGate':
return self.with_key(self.mkey._with_key_path_prefix_(prefix))

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate':
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))

Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
measurement_key_names,
measurement_key_objs,
with_key_path,
with_key_path_prefix,
with_measurement_key_mapping,
SupportsMeasurementKey,
)
Expand Down
15 changes: 15 additions & 0 deletions cirq-core/cirq/protocols/measurement_key_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,18 @@ def with_key_path(val: Any, path: Tuple[str, ...]):
"""
getter = getattr(val, '_with_key_path_', None)
return NotImplemented if getter is None else getter(path)


def with_key_path_prefix(val: Any, prefix: Tuple[str, ...]):
"""Prefixes the path to the target's measurement keys.
The path usually refers to an identifier or a list of identifiers from a subcircuit that
used to contain the target. Since a subcircuit can be repeated and reused, these paths help
differentiate the actual measurement keys.
Args:
val: The value whose path to prefix.
prefix: The prefix to apply to the value's path.
"""
getter = getattr(val, '_with_key_path_prefix_', None)
return NotImplemented if getter is None else getter(prefix)
5 changes: 4 additions & 1 deletion cirq-core/cirq/value/measurement_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,16 @@ def parse_serialized(cls, key_str: str):
def _with_key_path_(self, path: Tuple[str, ...]):
return self.replace(path=path)

def _with_key_path_prefix_(self, prefix: Tuple[str, ...]):
return self._with_key_path_(path=prefix + self.path)

def with_key_path_prefix(self, path_component: str):
"""Adds the input path component to the start of the path.
Useful when constructing the path from inside to out (in case of nested subcircuits),
recursively.
"""
return self._with_key_path_((path_component,) + self.path)
return self._with_key_path_prefix_((path_component,))

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
if self.name not in key_map:
Expand Down

0 comments on commit 83c440f

Please sign in to comment.