Skip to content

Commit

Permalink
Allow specifying initial state vector in DensityMatrixSimulator (#5223)
Browse files Browse the repository at this point in the history
This changes how ActOnDensityMatrixArgs is constructed to allow specifying the initial state as a state vector or state tensor, or as a density matrix or density tensor. Some of this could perhaps be moved into `cirq.to_valid_density_matrix` if people think that is a better place. Currently `to_valid_density_matrix` only handles 1D state vectors or 2D density matrices, not 2x2x..2 tensors in either case, but if we have the qid_shape we can tell handle these unambiguously.

Fixes #3958
  • Loading branch information
maffoo authored Apr 8, 2022
1 parent 3ca4d5a commit 5dbde9e
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 5 deletions.
10 changes: 7 additions & 3 deletions cirq-core/cirq/qis/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,9 +968,13 @@ def to_valid_density_matrix(
ValueError if the density_matrix_rep is not valid.
"""
qid_shape = _qid_shape_from_args(num_qubits, qid_shape)
if isinstance(density_matrix_rep, np.ndarray) and density_matrix_rep.ndim == 2:
validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol)
return density_matrix_rep
if isinstance(density_matrix_rep, np.ndarray):
N = np.prod(qid_shape, dtype=np.int64)
if len(qid_shape) > 1 and density_matrix_rep.shape == qid_shape * 2:
density_matrix_rep = density_matrix_rep.reshape((N, N))
if density_matrix_rep.shape == (N, N):
validate_density_matrix(density_matrix_rep, qid_shape=qid_shape, dtype=dtype, atol=atol)
return density_matrix_rep

state_vector = to_valid_state_vector(
density_matrix_rep, len(qid_shape), qid_shape=qid_shape, dtype=dtype
Expand Down
27 changes: 26 additions & 1 deletion cirq-core/cirq/qis/states_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,14 +607,29 @@ def test_to_valid_density_matrix_from_density_matrix():
assert_valid_density_matrix(np.diag([0.2, 0.8, 0, 0]), qid_shape=(4,))


def test_to_valid_density_matrix_from_density_matrix_tensor():
np.testing.assert_almost_equal(
cirq.to_valid_density_matrix(
cirq.one_hot(shape=(2, 2, 2, 2, 2, 2), dtype=np.complex64), num_qubits=3
),
cirq.one_hot(shape=(8, 8), dtype=np.complex64),
)
np.testing.assert_almost_equal(
cirq.to_valid_density_matrix(
cirq.one_hot(shape=(2, 3, 4, 2, 3, 4), dtype=np.complex64), qid_shape=(2, 3, 4)
),
cirq.one_hot(shape=(24, 24), dtype=np.complex64),
)


def test_to_valid_density_matrix_not_square():
with pytest.raises(ValueError, match='shape'):
cirq.to_valid_density_matrix(np.array([[1], [0]]), num_qubits=1)


def test_to_valid_density_matrix_size_mismatch_num_qubits():
with pytest.raises(ValueError, match='shape'):
cirq.to_valid_density_matrix(np.array([[1, 0], [0, 0]]), num_qubits=2)
cirq.to_valid_density_matrix(np.array([[[1, 0], [0, 0]], [[0, 0], [0, 0]]]), num_qubits=2)
with pytest.raises(ValueError, match='shape'):
cirq.to_valid_density_matrix(np.eye(4) / 4.0, num_qubits=1)

Expand Down Expand Up @@ -690,6 +705,16 @@ def test_to_valid_density_matrix_from_state_vector():
)


def test_to_valid_density_matrix_from_state_vector_tensor():
np.testing.assert_almost_equal(
cirq.to_valid_density_matrix(
density_matrix_rep=np.array(np.full((2, 2), 0.5), dtype=np.complex64),
num_qubits=2,
),
0.25 * np.ones((4, 4)),
)


def test_to_valid_density_matrix_from_state_invalid_state():
with pytest.raises(ValueError, match="Invalid quantum state"):
cirq.to_valid_density_matrix(np.array([1, 0, 0]), num_qubits=2)
Expand Down
6 changes: 5 additions & 1 deletion cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def create(
).reshape(qid_shape * 2)
else:
if qid_shape is not None:
density_matrix = initial_state.reshape(qid_shape * 2)
if dtype and initial_state.dtype != dtype:
initial_state = initial_state.astype(dtype)
density_matrix = qis.to_valid_density_matrix(
initial_state, len(qid_shape), qid_shape=qid_shape, dtype=dtype
).reshape(qid_shape * 2)
else:
density_matrix = initial_state
if np.may_share_memory(density_matrix, initial_state):
Expand Down
47 changes: 47 additions & 0 deletions cirq-core/cirq/sim/act_on_density_matrix_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,50 @@ def test_with_qubits():
def test_qid_shape_error():
with pytest.raises(ValueError, match="qid_shape must be provided"):
cirq.sim.act_on_density_matrix_args._BufferedDensityMatrix.create(initial_state=0)


def test_initial_state_vector():
qubits = cirq.LineQubit.range(3)
args = cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((8,), 1 / np.sqrt(8)), dtype=np.complex64
)
assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2)

args2 = cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((2, 2, 2), 1 / np.sqrt(8)), dtype=np.complex64
)
assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2)


def test_initial_state_matrix():
qubits = cirq.LineQubit.range(3)
args = cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((8, 8), 1 / 8), dtype=np.complex64
)
assert args.target_tensor.shape == (2, 2, 2, 2, 2, 2)

args2 = cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((2, 2, 2, 2, 2, 2), 1 / 8), dtype=np.complex64
)
assert args2.target_tensor.shape == (2, 2, 2, 2, 2, 2)


def test_initial_state_bad_shape():
qubits = cirq.LineQubit.range(3)
with pytest.raises(ValueError, match="Invalid quantum state"):
cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((4,), 1 / 2), dtype=np.complex64
)
with pytest.raises(ValueError, match="Invalid quantum state"):
cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((2, 2), 1 / 2), dtype=np.complex64
)

with pytest.raises(ValueError, match="Invalid quantum state"):
cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((4, 4), 1 / 4), dtype=np.complex64
)
with pytest.raises(ValueError, match="Invalid quantum state"):
cirq.ActOnDensityMatrixArgs(
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
)

0 comments on commit 5dbde9e

Please sign in to comment.