Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix str/repr explosion in separated states #4518

Merged
merged 16 commits into from
Dec 21, 2021
9 changes: 9 additions & 0 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ def sample(
seed=seed,
)

def __repr__(self) -> str:
return (
'cirq.ActOnDensityMatrixArgs('
f'target_tensor={self.target_tensor!r},'
f' qid_shape={self.qid_shape!r},'
f' qubits={self.qubits!r},'
f' log_of_measurement_results={self.log_of_measurement_results!r}'
)


def _strat_apply_channel_to_state(
action: Any, args: ActOnDensityMatrixArgs, qubits: Sequence['cirq.Qid']
Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,14 @@ def sample(
seed=seed,
)

def __repr__(self) -> str:
return (
'cirq.ActOnStateVectorArgs('
f'target_tensor={self.target_tensor!r},'
f' qubits={self.qubits!r},'
f' log_of_measurement_results={self.log_of_measurement_results!r}'
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
)


def _strat_act_on_state_vector_from_apply_unitary(
unitary_value: Any,
Expand Down
28 changes: 24 additions & 4 deletions cirq-core/cirq/sim/density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simulator for density matrices that simulates noisy quantum circuits."""
from typing import Any, Dict, TYPE_CHECKING, Tuple, Union, Sequence, Optional, List
from typing import Any, Dict, TYPE_CHECKING, Tuple, Union, Sequence, Optional, List, cast

import numpy as np

Expand Down Expand Up @@ -362,6 +362,9 @@ def density_matrix(self, copy=True):
self._density_matrix = np.reshape(matrix, (size, size))
return self._density_matrix.copy() if copy else self._density_matrix

def __repr__(self) -> str:
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
return f'cirq.DensityMatrixStepResult(sim_state={self._sim_state!r}, dtype={self._dtype!r}'


@value.value_equality(unhashable=True)
class DensityMatrixSimulatorState:
Expand All @@ -382,7 +385,7 @@ def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def _value_equality_values_(self) -> Any:
return (self.density_matrix.tolist(), self.qubit_map)
return self.density_matrix.tolist(), self.qubit_map

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -453,13 +456,30 @@ def final_density_matrix(self):

def _value_equality_values_(self) -> Any:
measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())}
return (self.params, measurements, self._final_simulator_state)
return self.params, measurements, self._final_simulator_state

def __str__(self) -> str:
samples = super().__str__()
return f'measurements: {samples}\nfinal density matrix:\n{self.final_density_matrix}'
substates = self._substates
if substates is None:
return f'measurements: {samples}\nfinal density matrix:\n{self.final_density_matrix}'
ret = f'measurements: {samples}'
for substate in substates:
substate = cast(act_on_density_matrix_args.ActOnDensityMatrixArgs, substate)
tensor = substate.target_tensor
size = np.prod([tensor.shape[i] for i in range(tensor.ndim // 2)], dtype=np.int64)
dm = tensor.reshape((size, size))
label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:'
ret += f'\n\n{label}\nfinal density matrix:\n{dm}'
return ret

def __repr__(self) -> str:
if self._final_step_result:
return (
'cirq.DensityMatrixTrialResult('
f'params={self.params!r}, measurements={self.measurements!r}, '
f'final_step_result={self._final_step_result!r})'
)
return (
'cirq.DensityMatrixTrialResult('
f'params={self.params!r}, measurements={self.measurements!r}, '
Expand Down
79 changes: 69 additions & 10 deletions cirq-core/cirq/sim/density_matrix_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,10 +1066,15 @@ def test_density_matrix_trial_result_qid_shape():

def test_density_matrix_trial_result_repr():
q0 = cirq.LineQubit(0)
final_step_result = mock.Mock(cirq.StepResult)
final_step_result._simulator_state.return_value = cirq.DensityMatrixSimulatorState(
density_matrix=np.ones((2, 2)) * 0.5, qubit_map={q0: 0}
args = cirq.ActOnDensityMatrixArgs(
target_tensor=np.ones((2, 2)) * 0.5,
available_buffer=[],
qid_shape=(2,),
prng=np.random.RandomState(0),
log_of_measurement_results={},
qubits=[q0],
)
final_step_result = cirq.DensityMatrixStepResult(args, cirq.DensityMatrixSimulator())
assert (
repr(
cirq.DensityMatrixTrialResult(
Expand All @@ -1081,10 +1086,13 @@ def test_density_matrix_trial_result_repr():
== "cirq.DensityMatrixTrialResult("
"params=cirq.ParamResolver({'s': 1}), "
"measurements={'m': array([[1]])}, "
"final_simulator_state=cirq.DensityMatrixSimulatorState("
"density_matrix=np.array([[0.5, 0.5], [0.5, 0.5]]), "
"qubit_map={cirq.LineQubit(0): 0}))"
""
"final_step_result=cirq.DensityMatrixStepResult("
"sim_state=cirq.ActOnDensityMatrixArgs("
"target_tensor=array([[0.5, 0.5],\n [0.5, 0.5]]), "
"qid_shape=(2,), "
"qubits=(cirq.LineQubit(0),), "
"log_of_measurement_results={}, "
"dtype=<class 'numpy.complex64'>)"
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
)


Expand Down Expand Up @@ -1531,7 +1539,7 @@ def test_density_matrices_same_with_or_without_split_untangled_states():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H(q0), cirq.CX.on(q0, q1), cirq.reset(q1))
result1 = sim.simulate(circuit).final_density_matrix
sim = cirq.DensityMatrixSimulator(split_untangled_states=True)
sim = cirq.DensityMatrixSimulator()
result2 = sim.simulate(circuit).final_density_matrix
assert np.allclose(result1, result2)

Expand All @@ -1548,12 +1556,63 @@ def test_large_untangled_okay():
_ = cirq.DensityMatrixSimulator(split_untangled_states=False).simulate(circuit)

# Validate a simulation run
result = cirq.DensityMatrixSimulator(split_untangled_states=True).simulate(circuit)
result = cirq.DensityMatrixSimulator().simulate(circuit)
assert set(result._final_step_result._qubits) == set(cirq.LineQubit.range(59))
# _ = result.final_density_matrix hangs (as expected)

# Validate a trial run and sampling
result = cirq.DensityMatrixSimulator(split_untangled_states=True).run(circuit, repetitions=1000)
result = cirq.DensityMatrixSimulator().run(circuit, repetitions=1000)
assert len(result.measurements) == 59
assert len(result.measurements['0']) == 1000
assert (result.measurements['0'] == np.full(1000, 1)).all()


def test_separated_states_str_does_not_merge():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0),
cirq.measure(q1),
cirq.X(q0),
)

result = cirq.DensityMatrixSimulator().simulate(circuit)
assert (
str(result)
== """measurements: 0=0 1=0

qubits: (cirq.LineQubit(0),)
final density matrix:
[[0.+0.j 0.+0.j]
[0.+0.j 1.+0.j]]

qubits: (cirq.LineQubit(1),)
final density matrix:
[[1.+0.j 0.+0.j]
[0.+0.j 0.+0.j]]

phase:
final density matrix:
[[1.+0.j]]"""
)


def test_unseparated_states_str():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0),
cirq.measure(q1),
cirq.X(q0),
)

result = cirq.DensityMatrixSimulator(split_untangled_states=False).simulate(circuit)
assert (
str(result)
== """measurements: 0=0 1=0

qubits: (cirq.LineQubit(0), cirq.LineQubit(1))
final density matrix:
[[0.+0.j 0.+0.j 0.+0.j 0.+0.j]
[0.+0.j 0.+0.j 0.+0.j 0.+0.j]
[0.+0.j 0.+0.j 1.+0.j 0.+0.j]
[0.+0.j 0.+0.j 0.+0.j 0.+0.j]]"""
)
18 changes: 17 additions & 1 deletion cirq-core/cirq/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,22 @@ def _final_simulator_state(self):
self._final_simulator_state_cache = self._final_step_result._simulator_state()
return self._final_simulator_state_cache

@property
def _substates(self) -> Optional[Sequence['cirq.ActOnArgs']]:
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
if self._final_step_result is None or not hasattr(self._final_step_result, '_sim_state'):
return None
sim_state = self._final_step_result._sim_state # type: ignore
state = sim_state # type: cirq.OperationTarget[cirq.ActOnArgs]
substates = dict() # type: Dict[cirq.ActOnArgs, int]
for q in state.qubits:
substates[state[q]] = 0
# Add the global phase if it exists
try:
substates[state[None]] = 0
except IndexError:
pass
return tuple(substates.keys())

def __repr__(self) -> str:
return (
f'cirq.SimulationTrialResult(params={self.params!r}, '
Expand Down Expand Up @@ -885,7 +901,7 @@ def _repr_pretty_(self, p: Any, cycle: bool) -> None:

def _value_equality_values_(self) -> Any:
measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())}
return (self.params, measurements, self._final_simulator_state)
return self.params, measurements, self._final_simulator_state

@property
def qubit_map(self) -> Dict[ops.Qid, int]:
Expand Down
44 changes: 44 additions & 0 deletions cirq-core/cirq/sim/sparse_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,3 +1332,47 @@ def test_noise_model():
result = simulator.run(circuit, repetitions=100)

assert 40 <= sum(result.measurements['0'])[0] < 60


def test_separated_states_str_does_not_merge():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0),
cirq.measure(q1),
cirq.H(q0),
cirq.GlobalPhaseOperation(0 + 1j),
)

result = cirq.Simulator().simulate(circuit)
assert (
str(result)
== """measurements: 0=0 1=0

qubits: (cirq.LineQubit(0),)
output vector: 0.707|0⟩ + 0.707|1⟩

qubits: (cirq.LineQubit(1),)
output vector: |0⟩

phase:
output vector: 1j|⟩"""
)


def test_unseparated_states_str():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0),
cirq.measure(q1),
cirq.H(q0),
cirq.GlobalPhaseOperation(0 + 1j),
)

result = cirq.Simulator(split_untangled_states=False).simulate(circuit)
assert (
str(result)
== """measurements: 0=0 1=0

qubits: (cirq.LineQubit(0), cirq.LineQubit(1))
output vector: 0.707j|00⟩ + 0.707j|10⟩"""
)
44 changes: 35 additions & 9 deletions cirq-core/cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
TypeVar,
Type,
Optional,
cast,
)

import numpy as np

from cirq import ops, study, value
from cirq import ops, study, value, qis
from cirq.sim import simulator, state_vector, simulator_base
from cirq.sim.act_on_state_vector_args import ActOnStateVectorArgs

Expand Down Expand Up @@ -121,6 +122,9 @@ def _simulator_state(self) -> 'StateVectorSimulatorState':
"""
raise NotImplementedError()

def __repr__(self) -> str:
return f'cirq.StateVectorStepResult(sim_state={self._sim_state!r}'


@value.value_equality(unhashable=True)
class StateVectorSimulatorState:
Expand All @@ -140,7 +144,7 @@ def __repr__(self) -> str:
)

def _value_equality_values_(self) -> Any:
return (self.state_vector.tolist(), self.qubit_map)
return self.state_vector.tolist(), self.qubit_map


@value.value_equality(unhashable=True)
Expand Down Expand Up @@ -201,16 +205,32 @@ def state_vector(self):

def _value_equality_values_(self):
measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())}
return (self.params, measurements, self._final_simulator_state)
return self.params, measurements, self._final_simulator_state

def __str__(self) -> str:
samples = super().__str__()
final = self.state_vector()
if len([1 for e in final if abs(e) > 0.001]) < 16:
state_vector = self.dirac_notation(3)
else:
state_vector = str(final)
return f'measurements: {samples}\noutput vector: {state_vector}'
substates = self._substates
if substates is None:
final = self.state_vector()
if len([1 for e in final if abs(e) > 0.001]) < 16:
state_vector = self.dirac_notation(3)
else:
state_vector = str(final)
return f'measurements: {samples}\noutput vector: {state_vector}'
ret = f'measurements: {samples}'
for substate in substates:
substate = cast(ActOnStateVectorArgs, substate)
final = substate.target_tensor
shape = final.shape
size = np.prod(shape, dtype=np.int64)
final = final.reshape(size)
if len([1 for e in final if abs(e) > 0.001]) < 16:
state_vector = qis.dirac_notation(final, 3, shape)
else:
state_vector = str(final)
label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:'
ret += f'\n\n{label}\noutput vector: {state_vector}'
return ret

def _repr_pretty_(self, p: Any, cycle: bool) -> None:
"""Text output in Jupyter."""
Expand All @@ -221,6 +241,12 @@ def _repr_pretty_(self, p: Any, cycle: bool) -> None:
p.text(str(self))

def __repr__(self) -> str:
if self._final_step_result:
return (
'cirq.StateVectorTrialResult('
f'measurements={self.measurements!r}, '
f'final_step_result={self._final_step_result!r})'
)
return (
f'cirq.StateVectorTrialResult(params={self.params!r}, '
f'measurements={self.measurements!r}, '
Expand Down
Loading