Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix state vector factorization validation #5076

Merged
merged 1 commit into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,7 @@ def factor_state_vector(
remainder = remainder / np.sum(abs(remainder) ** 2) ** 0.5
if validate:
t2 = state_vector_kronecker_product(extracted, remainder)
axes2 = list(axes) + [i for i in range(t1.ndim) if i not in axes]
t3 = transpose_state_vector_to_axis_order(t2, axes2)
if not np.allclose(t3, t, atol=atol):
if not predicates.allclose_up_to_global_phase(t2, t1, atol=atol):
raise ValueError('The tensor cannot be factored by the requested axes')
return extracted, remainder

Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/state_vector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,10 @@ def test_step_result_bloch_vector():

def test_factor_validation():
args = cirq.Simulator()._create_act_on_args(0, qubits=cirq.LineQubit.range(2))
args.apply_operation(cirq.H(cirq.LineQubit(0)))
args.apply_operation(cirq.H(cirq.LineQubit(0)) ** 0.7)
t = args.create_merged_state().target_tensor
cirq.linalg.transformations.factor_state_vector(t, [0])
cirq.linalg.transformations.factor_state_vector(t, [1], atol=1e-2)
cirq.linalg.transformations.factor_state_vector(t, [1])
args.apply_operation(cirq.CNOT(cirq.LineQubit(0), cirq.LineQubit(1)))
t = args.create_merged_state().target_tensor
with pytest.raises(ValueError, match='factor'):
Expand Down