Skip to content

Commit

Permalink
Avoid copying unnecessary buffers between simulation iterations (#4789)
Browse files Browse the repository at this point in the history
This PR implements the task proposed in #4779. In each iteration of a simulation, the entire `ActOnArgs` is copied, including buffers. It is not necessary and adds additional cost, especially for `DensityMatrixSimulator`. Therefore, a parameter `with_buffer` is added to the `copy` method to indicate whether buffers are also needed to be copied. For third-party simulators that have not added the parameter, a deprecation warning is raised.

This PR also modifies the `__init__` method of `DensityMatrixSimulator` and `ActOnStateVectorArgs` to create the buffer and qid_shape parameters when they are not provided.

close #4779
  • Loading branch information
yjt98765 authored Jan 14, 2022
1 parent 972b6d4 commit 20b577c
Show file tree
Hide file tree
Showing 15 changed files with 246 additions and 36 deletions.
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __str__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.qubit_map, self.M, self.simulation_options, self.grouping

def _on_copy(self, target: 'MPSState'):
def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True):
target.simulation_options = self.simulation_options
target.grouping = self.grouping
target.M = [x.copy() for x in self.M]
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/act_on_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, fallback_result: Any = NotImplemented, measurements=None):
def _perform_measurement(self, qubits):
return self.measurements # coverage: ignore

def copy(self):
def copy(self, deep_copy_buffers: bool = True):
return DummyActOnArgs(self.fallback_result, self.measurements.copy()) # coverage: ignore

def _act_on_fallback_(
Expand Down
29 changes: 25 additions & 4 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Objects and methods for acting efficiently on a state tensor."""
import abc
import copy
import inspect
from typing import (
Any,
Dict,
Expand All @@ -26,6 +27,7 @@
Optional,
Iterator,
)
import warnings

import numpy as np

Expand Down Expand Up @@ -113,14 +115,33 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Child classes that perform measurements should implement this with
the implementation."""

def copy(self: TSelf) -> TSelf:
"""Creates a copy of the object."""
def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
"""Creates a copy of the object.
Args:
deep_copy_buffers: If True, buffers will also be deep-copied.
Otherwise the copy will share a reference to the original object's
buffers.
Returns:
A copied instance.
"""
args = copy.copy(self)
self._on_copy(args)
if 'deep_copy_buffers' in inspect.signature(self._on_copy).parameters:
self._on_copy(args, deep_copy_buffers)
else:
warnings.warn(
(
'A new parameter deep_copy_buffers has been added to ActOnArgs._on_copy(). '
'The classes that inherit from ActOnArgs should support it before Cirq 0.15.'
),
DeprecationWarning,
)
self._on_copy(args)
args._log_of_measurement_results = self.log_of_measurement_results.copy()
return args

def _on_copy(self: TSelf, args: TSelf):
def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True):
"""Subclasses should implement this with any additional state copy
functionality."""

Expand Down
18 changes: 16 additions & 2 deletions cirq-core/cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from collections import abc
import inspect
from typing import (
Dict,
TYPE_CHECKING,
Expand All @@ -25,6 +26,7 @@
List,
Union,
)
import warnings

import numpy as np

Expand Down Expand Up @@ -131,9 +133,21 @@ def _act_on_fallback_(
self.args[q] = op_args
return True

def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]':
def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]':
logs = self.log_of_measurement_results.copy()
copies = {a: a.copy() for a in set(self.args.values())}
copies = {}
for act_on_args in set(self.args.values()):
if 'deep_copy_buffers' in inspect.signature(act_on_args.copy).parameters:
copies[act_on_args] = act_on_args.copy(deep_copy_buffers)
else:
warnings.warn(
(
'A new parameter deep_copy_buffers has been added to ActOnArgs.copy(). The '
'classes that inherit from ActOnArgs should support it before Cirq 0.15.'
),
DeprecationWarning,
)
copies[act_on_args] = act_on_args.copy()
for copy in copies.values():
copy._log_of_measurement_results = logs
args = {q: copies[a] for q, a in self.args.items()}
Expand Down
9 changes: 8 additions & 1 deletion cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __init__(self, qubits, logs):
def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]:
return [0] * len(qubits)

def copy(self) -> 'EmptyActOnArgs':
def copy(self) -> 'EmptyActOnArgs': # type: ignore
"""The deep_copy_buffers parameter is omitted to trigger a deprecation warning test."""
return EmptyActOnArgs(
qubits=self.qubits,
logs=self.log_of_measurement_results.copy(),
Expand Down Expand Up @@ -226,6 +227,12 @@ def test_copy_succeeds():
assert copied.qubits == (q0, q1)


def test_copy_deprecation_warning():
args = create_container(qs2, False)
with cirq.testing.assert_deprecated('deep_copy_buffers', deadline='0.15'):
args.copy(False)


def test_merge_succeeds():
args = create_container(qs2, False)
merged = args.create_merged_state()
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def _act_on_fallback_(
) -> bool:
return True

def _on_copy(self, args):
return super()._on_copy(args)


def test_measurements():
args = DummyArgs()
Expand Down Expand Up @@ -89,3 +92,9 @@ def test_transpose_qubits():
args.transpose_to_qubit_order((q0, q2))
with pytest.raises(ValueError, match='Qubits do not match'):
args.transpose_to_qubit_order((q0, q1, q1))


def test_on_copy_has_no_param():
args = DummyArgs()
with cirq.testing.assert_deprecated('deep_copy_buffers', deadline='0.15'):
args.copy(False)
39 changes: 29 additions & 10 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Objects and methods for acting efficiently on a density matrix."""

from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Sequence, Union
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence, Union

import numpy as np

Expand All @@ -36,11 +36,11 @@ class ActOnDensityMatrixArgs(ActOnArgs):
def __init__(
self,
target_tensor: np.ndarray,
available_buffer: List[np.ndarray],
qid_shape: Tuple[int, ...],
prng: np.random.RandomState = None,
log_of_measurement_results: Dict[str, Any] = None,
qubits: Sequence['cirq.Qid'] = None,
available_buffer: Optional[List[np.ndarray]] = None,
qid_shape: Optional[Tuple[int, ...]] = None,
prng: Optional[np.random.RandomState] = None,
log_of_measurement_results: Optional[Dict[str, Any]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
ignore_measurement_results: bool = False,
):
"""Inits ActOnDensityMatrixArgs.
Expand All @@ -65,11 +65,27 @@ def __init__(
will treat measurement as dephasing instead of collapsing
process. This is only applicable to simulators that can
model dephasing.
Raises:
ValueError: The dimension of `target_tensor` is not divisible by 2
and `qid_shape` is not provided.
"""
super().__init__(prng, qubits, log_of_measurement_results, ignore_measurement_results)
self.target_tensor = target_tensor
self.available_buffer = available_buffer
self.qid_shape = qid_shape
if available_buffer is None:
self.available_buffer = [np.empty_like(target_tensor) for _ in range(3)]
else:
self.available_buffer = available_buffer
if qid_shape is None:
target_shape = target_tensor.shape
if len(target_shape) % 2 != 0:
raise ValueError(
'The dimension of target_tensor is not divisible by 2.'
' Require explicit qid_shape.'
)
self.qid_shape = target_shape[: len(target_shape) // 2]
else:
self.qid_shape = qid_shape

def _act_on_fallback_(
self,
Expand Down Expand Up @@ -108,9 +124,12 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
)
return bits

def _on_copy(self, target: 'cirq.ActOnDensityMatrixArgs'):
def _on_copy(self, target: 'cirq.ActOnDensityMatrixArgs', deep_copy_buffers: bool = True):
target.target_tensor = self.target_tensor.copy()
target.available_buffer = [b.copy() for b in self.available_buffer]
if deep_copy_buffers:
target.available_buffer = [b.copy() for b in self.available_buffer]
else:
target.available_buffer = self.available_buffer

def _on_kronecker_product(
self, other: 'cirq.ActOnDensityMatrixArgs', target: 'cirq.ActOnDensityMatrixArgs'
Expand Down
29 changes: 29 additions & 0 deletions cirq-core/cirq/sim/act_on_density_matrix_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,35 @@
import cirq


def test_default_parameter():
qid_shape = (2,)
tensor = cirq.to_valid_density_matrix(
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64
)
args = cirq.ActOnDensityMatrixArgs(target_tensor=tensor)
assert len(args.available_buffer) == 3
for buffer in args.available_buffer:
assert buffer.shape == tensor.shape
assert buffer.dtype == tensor.dtype
assert args.qid_shape == qid_shape


def test_shallow_copy_buffers():
qid_shape = (2,)
tensor = cirq.to_valid_density_matrix(
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64
)
args = cirq.ActOnDensityMatrixArgs(target_tensor=tensor)
copy = args.copy(deep_copy_buffers=False)
assert copy.available_buffer is args.available_buffer


def test_default_parameter_error():
tensor = np.ndarray(shape=(2,))
with pytest.raises(ValueError, match='The dimension of target_tensor is not divisible by 2'):
cirq.ActOnDensityMatrixArgs(target_tensor=tensor)


def test_decomposed_fallback():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
Expand Down
22 changes: 14 additions & 8 deletions cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Objects and methods for acting efficiently on a state vector."""

from typing import Any, Tuple, TYPE_CHECKING, Union, Dict, List, Sequence
from typing import Any, Optional, Tuple, TYPE_CHECKING, Union, Dict, List, Sequence

import numpy as np

Expand All @@ -40,10 +40,10 @@ class ActOnStateVectorArgs(ActOnArgs):
def __init__(
self,
target_tensor: np.ndarray,
available_buffer: np.ndarray,
prng: np.random.RandomState = None,
log_of_measurement_results: Dict[str, Any] = None,
qubits: Sequence['cirq.Qid'] = None,
available_buffer: Optional[np.ndarray] = None,
prng: Optional[np.random.RandomState] = None,
log_of_measurement_results: Optional[Dict[str, Any]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
):
"""Inits ActOnStateVectorArgs.
Expand All @@ -66,7 +66,10 @@ def __init__(
"""
super().__init__(prng, qubits, log_of_measurement_results)
self.target_tensor = target_tensor
self.available_buffer = available_buffer
if available_buffer is None:
self.available_buffer = np.empty_like(target_tensor)
else:
self.available_buffer = available_buffer

def swap_target_tensor_for(self, new_target_tensor: np.ndarray):
"""Gives a new state vector for the system.
Expand Down Expand Up @@ -174,9 +177,12 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
)
return bits

def _on_copy(self, target: 'cirq.ActOnStateVectorArgs'):
def _on_copy(self, target: 'cirq.ActOnStateVectorArgs', deep_copy_buffers: bool = True):
target.target_tensor = self.target_tensor.copy()
target.available_buffer = self.available_buffer.copy()
if deep_copy_buffers:
target.available_buffer = self.available_buffer.copy()
else:
target.available_buffer = self.available_buffer

def _on_kronecker_product(
self, other: 'cirq.ActOnStateVectorArgs', target: 'cirq.ActOnStateVectorArgs'
Expand Down
14 changes: 14 additions & 0 deletions cirq-core/cirq/sim/act_on_state_vector_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
import cirq


def test_default_parameter():
target_tensor = cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64)
args = cirq.ActOnStateVectorArgs(target_tensor)
assert args.available_buffer.shape == target_tensor.shape
assert args.available_buffer.dtype == target_tensor.dtype


def test_shallow_copy_buffers():
target_tensor = cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64)
args = cirq.ActOnStateVectorArgs(target_tensor)
copy = args.copy(deep_copy_buffers=False)
assert copy.available_buffer is args.available_buffer


def test_decomposed_fallback():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Returns the measurement from the tableau."""
return [self.tableau._measure(self.qubit_map[q], self.prng) for q in qubits]

def _on_copy(self, target: 'ActOnCliffordTableauArgs'):
def _on_copy(self, target: 'ActOnCliffordTableauArgs', deep_copy_buffers: bool = True):
target.tableau = self.tableau.copy()

def sample(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Returns the measurement from the stabilizer state form."""
return [self.state._measure(self.qubit_map[q], self.prng) for q in qubits]

def _on_copy(self, target: 'ActOnStabilizerCHFormArgs'):
def _on_copy(self, target: 'ActOnStabilizerCHFormArgs', deep_copy_buffers: bool = True):
target.state = self.state.copy()

def sample(
Expand Down
13 changes: 11 additions & 2 deletions cirq-core/cirq/sim/operation_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,17 @@ def apply_operation(self, op: 'cirq.Operation'):
protocols.act_on(op, self)

@abc.abstractmethod
def copy(self: TSelfTarget) -> TSelfTarget:
"""Copies the object."""
def copy(self: TSelfTarget, deep_copy_buffers: bool = True) -> TSelfTarget:
"""Creates a copy of the object.
Args:
deep_copy_buffers: If True, buffers will also be deep-copied.
Otherwise the copy will share a reference to the original object's
buffers.
Returns:
A copied instance.
"""

@property
@abc.abstractmethod
Expand Down
Loading

0 comments on commit 20b577c

Please sign in to comment.