Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Classical sympy conditions with qudits #4778

Closed
wants to merge 47 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
b76950e
Allow sympy expressions as classical controls
daxfohl Dec 9, 2021
5fdff50
Format
daxfohl Dec 9, 2021
ef081d7
move Condition to value
daxfohl Dec 9, 2021
0dd430e
Condition subclasses
daxfohl Dec 9, 2021
3fb7f75
Fix sympy resolver
daxfohl Dec 9, 2021
9568ae0
lint
daxfohl Dec 9, 2021
04fcff7
fix CCO serialization
daxfohl Dec 9, 2021
b3c344e
fix CCO serialization
daxfohl Dec 9, 2021
b8ff20a
add json reprs for conditions
daxfohl Dec 9, 2021
aa99805
add support for qudits in conditions
daxfohl Dec 9, 2021
cbb029b
add test
daxfohl Dec 9, 2021
efce2f9
tests
daxfohl Dec 9, 2021
b34994b
tests
daxfohl Dec 9, 2021
5537397
test
daxfohl Dec 9, 2021
73ed74b
format
daxfohl Dec 9, 2021
a415235
docstrings
daxfohl Dec 9, 2021
3a0a56d
subop
daxfohl Dec 9, 2021
f930f6a
regex
daxfohl Dec 10, 2021
39a7a95
docs
daxfohl Dec 10, 2021
42bac3e
Make test_sympy more intuitive.
daxfohl Dec 10, 2021
f4ea9d8
Sympy str roundtrip
daxfohl Dec 14, 2021
71f61f5
Resolve some code review comments
daxfohl Dec 16, 2021
2261355
Add escape key to parse_sympy_condition
daxfohl Dec 16, 2021
6b36357
repr
daxfohl Dec 16, 2021
afbf3c9
coverage
daxfohl Dec 16, 2021
58fb2dc
coverage
daxfohl Dec 16, 2021
bd80c0b
parser
daxfohl Dec 17, 2021
c39a572
Improve sympy repr
daxfohl Dec 17, 2021
12d38ca
lint
daxfohl Dec 20, 2021
724febb
sympy.basic
daxfohl Dec 20, 2021
b598697
Add sympy json resolvers for comparators
daxfohl Dec 20, 2021
d167de7
_from_json_dict_
daxfohl Dec 20, 2021
72d82eb
lint
daxfohl Dec 20, 2021
b55188e
reduce fixed_tokens
daxfohl Dec 20, 2021
fd1fefb
Merge branch 'master' into sympy3
daxfohl Dec 20, 2021
f6a6645
Merge branch 'master' into sympymerge
daxfohl Dec 20, 2021
de3f887
Merge branch 'sympy3' of https://github.com/daxfohl/Cirq into sympy3
daxfohl Dec 20, 2021
ca56bd8
more tests
daxfohl Dec 20, 2021
6f8e344
format
daxfohl Dec 20, 2021
689719f
Key
daxfohl Dec 21, 2021
96ba4e9
combined test
daxfohl Dec 21, 2021
793c138
Merge remote-tracking branch 'origin/sympy3' into sympy3
daxfohl Dec 21, 2021
0b5526f
Merge branch 'sympy3' into qudits2
daxfohl Dec 21, 2021
8c17a3f
Merge branch 'master' into qudits2
daxfohl Dec 23, 2021
c56d6bb
lint
daxfohl Dec 23, 2021
681a008
Docstrings
daxfohl Dec 23, 2021
1bddb1c
Merge branch 'master' into qudits2
daxfohl Dec 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import dataclasses
import math
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union

import numpy as np
import quimb.tensor as qtn
Expand Down Expand Up @@ -92,6 +92,7 @@ def _create_partial_act_on_args(
initial_state: Union[int, 'MPSState'],
qubits: Sequence['cirq.Qid'],
logs: Dict[str, Any],
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]],
) -> 'MPSState':
"""Creates MPSState args for simulating the Circuit.

Expand All @@ -102,6 +103,8 @@ def _create_partial_act_on_args(
is often used in specifying the initial state, i.e. the
ordering of the computational basis states.
logs: A mutable object that measurements are recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.

Returns:
MPSState args for simulating the Circuit.
Expand All @@ -116,6 +119,7 @@ def _create_partial_act_on_args(
grouping=self.grouping,
initial_state=initial_state,
log_of_measurement_results=logs,
measured_qubits=measured_qubits,
)

def _create_step_result(
Expand Down Expand Up @@ -229,6 +233,7 @@ def __init__(
grouping: Optional[Dict['cirq.Qid', int]] = None,
initial_state: int = 0,
log_of_measurement_results: Dict[str, Any] = None,
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None,
):
"""Creates and MPSState

Expand All @@ -242,11 +247,13 @@ def __init__(
initial_state: An integer representing the initial state.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.

Raises:
ValueError: If the grouping does not cover the qubits.
"""
super().__init__(prng, qubits, log_of_measurement_results)
super().__init__(prng, qubits, log_of_measurement_results, measured_qubits)
qubit_map = self.qubit_map
self.grouping = qubit_map if grouping is None else grouping
if self.grouping.keys() != self.qubit_map.keys():
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def _circuit_diagram_info_(
sub_info = protocols.circuit_diagram_info(self._sub_operation, sub_args, None)
if sub_info is None:
return NotImplemented # coverage: ignore

control_count = len({k for c in self._conditions for k in c.keys})
wire_symbols = sub_info.wire_symbols + ('^',) * control_count
if any(not isinstance(c, value.KeyCondition) for c in self._conditions):
Expand Down Expand Up @@ -176,8 +175,9 @@ def _json_dict_(self) -> Dict[str, Any]:
'sub_operation': self._sub_operation,
}

def _act_on_(self, args: 'cirq.ActOnArgs') -> bool:
if all(c.resolve(args.log_of_measurement_results) for c in self._conditions):
def _act_on_(self, args: 'cirq.OperationTarget') -> bool:
measurements, qubits = args.log_of_measurement_results, args.measured_qubits
if all(c.resolve(measurements, qubits) for c in self._conditions):
protocols.act_on(self._sub_operation, args)
return True

Expand Down
36 changes: 36 additions & 0 deletions cirq-core/cirq/ops/classically_controlled_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import sympy
from sympy.parsing import sympy_parser
Expand Down Expand Up @@ -702,6 +704,40 @@ def test_sympy():
assert result.measurements['m_result'][0][0] == (j > i)


def test_sympy_qudits():
q0 = cirq.LineQid(0, 3)
q1 = cirq.LineQid(1, 3)
q_result = cirq.LineQubit(2)

class PlusGate(cirq.Gate):
def __init__(self, dimension, increment=1):
self.dimension = dimension
self.increment = increment % dimension

def _qid_shape_(self):
return (self.dimension,)

def _unitary_(self):
inc = (self.increment - 1) % self.dimension + 1
u = np.empty((self.dimension, self.dimension))
u[inc:] = np.eye(self.dimension)[:-inc]
u[:inc] = np.eye(self.dimension)[-inc:]
return u

for i in range(9):
digits = cirq.big_endian_int_to_digits(i, digit_count=2, base=3)
circuit = cirq.Circuit(
PlusGate(3, digits[0]).on(q0),
PlusGate(3, digits[1]).on(q1),
cirq.measure(q0, q1, key='m'),
cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m > 4')),
cirq.measure(q_result, key='m_result'),
)

result = cirq.Simulator().run(circuit)
assert result.measurements['m_result'][0][0] == (i > 4)


def test_sympy_path_prefix():
q = cirq.LineQubit(0)
op = cirq.X(q).with_classical_controls(sympy.Symbol('b'))
Expand Down
3 changes: 0 additions & 3 deletions cirq-core/cirq/protocols/measurement_key_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
from cirq import value
from cirq._doc import doc_private

if TYPE_CHECKING:
import cirq

if TYPE_CHECKING:
import cirq

Expand Down
12 changes: 12 additions & 0 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
prng: np.random.RandomState = None,
qubits: Sequence['cirq.Qid'] = None,
log_of_measurement_results: Dict[str, List[int]] = None,
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None,
Copy link
Contributor Author

@daxfohl daxfohl Dec 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@95-martin-orion I did this the least-invasive way, of adding this measured_qubits member and propagating it everywhere it needed to be. This is a little smelly because it and log_of_measurement_results go hand in hand and are initialized, copied, etc. in pairs. I think the better option though is to introduce a new classical_state: ClassicalState that contains both of these dictionaries, plus whatever we add later, and then deprecate log_of_measurement_results.. I know it breaks the "rule of three" to do so, but I'm pretty sure we'll have a third thing to put into it soon.

It may be a big deprecation, but it may be worth it to do now rather than having to clean up the mess after things like repeated measurements, keeping track of qubits, etc have all become dependent on it.

Additionally I noticed a minor bug here is that I don't keep track of the qubits in keyed channels. If we have a ClassicalState class then we can make the interface such that writing to the dictionaries requires both a value and a qubit list, so that such an omission is impossible. Of course, that will make for a much bigger review. WDYT?

ignore_measurement_results: bool = False,
):
"""Inits ActOnArgs.
Expand All @@ -59,6 +60,8 @@ def __init__(
ordering of the computational basis states.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.
ignore_measurement_results: If True, then the simulation
will treat measurement as dephasing instead of collapsing
process, and not log the result. This is only applicable to
Expand All @@ -70,9 +73,12 @@ def __init__(
qubits = ()
if log_of_measurement_results is None:
log_of_measurement_results = {}
if measured_qubits is None:
measured_qubits = {}
self._set_qubits(qubits)
self.prng = prng
self._log_of_measurement_results = log_of_measurement_results
self._measured_qubits = measured_qubits
self._ignore_measurement_results = ignore_measurement_results

def _set_qubits(self, qubits: Sequence['cirq.Qid']):
Expand Down Expand Up @@ -104,6 +110,7 @@ def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[
if key in self._log_of_measurement_results:
raise ValueError(f"Measurement already logged to key {key!r}")
self._log_of_measurement_results[key] = corrected
self._measured_qubits[key] = tuple(qubits)

def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
return [self.qubit_map[q] for q in qubits]
Expand All @@ -118,6 +125,7 @@ def copy(self: TSelf) -> TSelf:
args = copy.copy(self)
self._on_copy(args)
args._log_of_measurement_results = self.log_of_measurement_results.copy()
args._measured_qubits = self.measured_qubits.copy()
return args

def _on_copy(self: TSelf, args: TSelf):
Expand Down Expand Up @@ -197,6 +205,10 @@ def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], targ
def log_of_measurement_results(self) -> Dict[str, List[int]]:
return self._log_of_measurement_results

@property
def measured_qubits(self) -> Dict[str, Tuple['cirq.Qid', ...]]:
return self._measured_qubits

@property
def ignore_measurement_results(self) -> bool:
return self._ignore_measurement_results
Expand Down
21 changes: 16 additions & 5 deletions cirq-core/cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

from collections import abc
from typing import (
Any,
Dict,
TYPE_CHECKING,
Generic,
Sequence,
Optional,
Iterator,
Any,
Tuple,
List,
Mapping,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
qubits: Sequence['cirq.Qid'],
split_untangled_states: bool,
log_of_measurement_results: Dict[str, Any],
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None,
):
"""Initializes the class.

Expand All @@ -63,11 +65,14 @@ def __init__(
at the end.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.
"""
self.args = args
self._qubits = tuple(qubits)
self.split_untangled_states = split_untangled_states
self._log_of_measurement_results = log_of_measurement_results
self._measured_qubits = measured_qubits if measured_qubits is not None else {}

def create_merged_state(self) -> TActOnArgs:
if not self.split_untangled_states:
Expand Down Expand Up @@ -132,9 +137,11 @@ def _act_on_fallback_(

def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]':
logs = self.log_of_measurement_results.copy()
measured_qubits = self._measured_qubits.copy()
copies = {a: a.copy() for a in set(self.args.values())}
for copy in copies.values():
copy._log_of_measurement_results = logs
copy._measured_qubits = measured_qubits
args = {q: copies[a] for q, a in self.args.items()}
return ActOnArgsContainer(args, self.qubits, self.split_untangled_states, logs)

Expand All @@ -146,6 +153,10 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
def log_of_measurement_results(self) -> Dict[str, Any]:
return self._log_of_measurement_results

@property
def measured_qubits(self) -> Mapping[str, Tuple['cirq.Qid', ...]]:
return self._measured_qubits

def sample(
self,
qubits: List['cirq.Qid'],
Expand Down
11 changes: 10 additions & 1 deletion cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
prng: np.random.RandomState = None,
log_of_measurement_results: Dict[str, Any] = None,
qubits: Sequence['cirq.Qid'] = None,
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None,
ignore_measurement_results: bool = False,
):
"""Inits ActOnDensityMatrixArgs.
Expand All @@ -61,12 +62,20 @@ def __init__(
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.
ignore_measurement_results: If True, then the simulation
will treat measurement as dephasing instead of collapsing
process. This is only applicable to simulators that can
model dephasing.
"""
super().__init__(prng, qubits, log_of_measurement_results, ignore_measurement_results)
super().__init__(
prng=prng,
qubits=qubits,
log_of_measurement_results=log_of_measurement_results,
measured_qubits=measured_qubits,
ignore_measurement_results=ignore_measurement_results,
)
self.target_tensor = target_tensor
self.available_buffer = available_buffer
self.qid_shape = qid_shape
Expand Down
5 changes: 4 additions & 1 deletion cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
prng: np.random.RandomState = None,
log_of_measurement_results: Dict[str, Any] = None,
qubits: Sequence['cirq.Qid'] = None,
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None,
):
"""Inits ActOnStateVectorArgs.

Expand All @@ -63,8 +64,10 @@ def __init__(
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.
"""
super().__init__(prng, qubits, log_of_measurement_results)
super().__init__(prng, qubits, log_of_measurement_results, measured_qubits)
self.target_tensor = target_tensor
self.available_buffer = available_buffer

Expand Down
7 changes: 5 additions & 2 deletions cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""A protocol for implementing high performance clifford tableau evolutions
for Clifford Simulator."""

from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Union
from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -42,6 +42,7 @@ def __init__(
prng: np.random.RandomState,
log_of_measurement_results: Dict[str, Any],
qubits: Sequence['cirq.Qid'] = None,
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None,
):
"""Inits ActOnCliffordTableauArgs.

Expand All @@ -55,8 +56,10 @@ def __init__(
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.
"""
super().__init__(prng, qubits, log_of_measurement_results)
super().__init__(prng, qubits, log_of_measurement_results, measured_qubits)
self.tableau = tableau

def _act_on_fallback_(
Expand Down
7 changes: 5 additions & 2 deletions cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Union
from typing import Any, Dict, TYPE_CHECKING, List, Sequence, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -41,6 +41,7 @@ def __init__(
prng: np.random.RandomState,
log_of_measurement_results: Dict[str, Any],
qubits: Sequence['cirq.Qid'] = None,
measured_qubits: Dict[str, Tuple['cirq.Qid', ...]] = None,
):
"""Initializes with the given state and the axes for the operation.
Args:
Expand All @@ -53,8 +54,10 @@ def __init__(
effects.
log_of_measurement_results: A mutable object that measurements are
being recorded into.
measured_qubits: A dictionary that contains the qubits that were
measured in each measurement.
"""
super().__init__(prng, qubits, log_of_measurement_results)
super().__init__(prng, qubits, log_of_measurement_results, measured_qubits)
self.state = state

def _act_on_fallback_(
Expand Down
Loading