From 0b93e27dffc86293f41ec8ffc2df121519457364 Mon Sep 17 00:00:00 2001 From: Shef Date: Tue, 19 Mar 2024 15:20:39 -0400 Subject: [PATCH] Update ClassicalSimulator to confirm to simulation abstraction (#6432) --- cirq-core/cirq/sim/classical_simulator.py | 298 +++++++++++++----- .../cirq/sim/classical_simulator_test.py | 97 ++++++ 2 files changed, 308 insertions(+), 87 deletions(-) diff --git a/cirq-core/cirq/sim/classical_simulator.py b/cirq-core/cirq/sim/classical_simulator.py index 515d1869e23..a5287637bfc 100644 --- a/cirq-core/cirq/sim/classical_simulator.py +++ b/cirq-core/cirq/sim/classical_simulator.py @@ -12,96 +12,220 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict -from collections import defaultdict -from cirq.sim.simulator import SimulatesSamples -from cirq import ops, protocols -from cirq.study.resolver import ParamResolver -from cirq.circuits.circuit import AbstractCircuit -from cirq.ops.raw_types import Qid + +from typing import Dict, Generic, Any, Sequence, List, Optional, Union, TYPE_CHECKING +from copy import deepcopy, copy +from cirq import ops, qis +from cirq.value import big_endian_int_to_bits +from cirq import sim +from cirq.sim.simulation_state import TSimulationState, SimulationState import numpy as np +if TYPE_CHECKING: + import cirq + -def _is_identity(op: ops.Operation) -> bool: - if isinstance(op.gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)): - return op.gate.exponent % 2 == 0 +def _is_identity(action) -> bool: + """Check if the given action is equivalent to an identity.""" + gate = action.gate if isinstance(action, ops.Operation) else action + if isinstance(gate, (ops.XPowGate, ops.CXPowGate, ops.CCXPowGate, ops.SwapPowGate)): + return gate.exponent % 2 == 0 return False -class ClassicalStateSimulator(SimulatesSamples): - """A simulator that accepts only gates with classical counterparts. - - This simulator evolves a single state, using only gates that output a single state for each - input state. The simulator runs in linear time, at the cost of not supporting superposition. - It can be used to estimate costs and simulate circuits for simple non-quantum algorithms using - many more qubits than fully capable quantum simulators. - - The supported gates are: - - cirq.X - - cirq.CNOT - - cirq.SWAP - - cirq.TOFFOLI - - cirq.measure - - Args: - circuit: The circuit to simulate. - param_resolver: Parameters to run with the program. - repetitions: Number of times to repeat the run. It is expected that - this is validated greater than zero before calling this method. - - Returns: - A dictionary mapping measurement keys to measurement results. - - Raises: - ValueError: If - - one of the gates is not an X, CNOT, SWAP, TOFFOLI or a measurement. - - A measurement key is used for measurements on different numbers of qubits. - """ - - def _run( - self, circuit: AbstractCircuit, param_resolver: ParamResolver, repetitions: int - ) -> Dict[str, np.ndarray]: - results_dict: Dict[str, np.ndarray] = {} - values_dict: Dict[Qid, int] = defaultdict(int) - param_resolver = param_resolver or ParamResolver({}) - resolved_circuit = protocols.resolve_parameters(circuit, param_resolver) - - for moment in resolved_circuit: - for op in moment: - if _is_identity(op): - continue - if op.gate == ops.X: - (q,) = op.qubits - values_dict[q] ^= 1 - elif op.gate == ops.CNOT: - c, q = op.qubits - values_dict[q] ^= values_dict[c] - elif op.gate == ops.SWAP: - a, b = op.qubits - values_dict[a], values_dict[b] = values_dict[b], values_dict[a] - elif op.gate == ops.TOFFOLI: - c1, c2, q = op.qubits - values_dict[q] ^= values_dict[c1] & values_dict[c2] - elif protocols.is_measurement(op): - measurement_values = np.array( - [[[values_dict[q] for q in op.qubits]]] * repetitions, dtype=np.uint8 - ) - key = op.gate.key # type: ignore - if key in results_dict: - if op._num_qubits_() != results_dict[key].shape[-1]: - raise ValueError( - f'Measurement shape {len(measurement_values)} does not match ' - f'{results_dict[key].shape[-1]} in {key}.' - ) - results_dict[key] = np.concatenate( - (results_dict[key], measurement_values), axis=1 - ) - else: - results_dict[key] = measurement_values - else: - raise ValueError( - f'{op} is not one of cirq.X, cirq.CNOT, cirq.SWAP, ' - 'cirq.CCNOT, or a measurement' - ) - - return results_dict +class ClassicalBasisState(qis.QuantumStateRepresentation): + """Represents a classical basis state for efficient state evolution.""" + + def __init__(self, initial_state: Union[List[int], np.ndarray]): + """Initializes the ClassicalBasisState object. + + Args: + initial_state: The initial state in the computational basis. + """ + self.basis = initial_state + + def copy(self, deep_copy_buffers: bool = True) -> 'ClassicalBasisState': + """Creates a copy of the ClassicalBasisState object. + + Args: + deep_copy_buffers: Whether to deep copy the internal buffers. + Returns: + A copy of the ClassicalBasisState object. + """ + return ClassicalBasisState( + initial_state=deepcopy(self.basis) if deep_copy_buffers else copy(self.basis) + ) + + def measure( + self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None + ) -> List[int]: + """Measures the density matrix. + + Args: + axes: The axes to measure. + seed: The random number seed to use. + Returns: + The measurements in order. + """ + return [self.basis[i] for i in axes] + + +class ClassicalBasisSimState(SimulationState[ClassicalBasisState]): + """Represents the state of a quantum simulation using classical basis states.""" + + def __init__( + self, + initial_state: Union[int, List[int]] = 0, + qubits: Optional[Sequence['cirq.Qid']] = None, + classical_data: Optional['cirq.ClassicalDataStore'] = None, + ): + """Initializes the ClassicalBasisSimState object. + + Args: + qubits: The qubits to simulate. + initial_state: The initial state for the simulation. + classical_data: The classical data container for the simulation. + + Raises: + ValueError: If qubits not provided and initial_state is int. + If initial_state is not an int, List[int], or np.ndarray. + + An initial_state value of type integer is parsed in big endian order. + """ + if isinstance(initial_state, int): + if qubits is None: + raise ValueError('qubits must be provided if initial_state is not List[int]') + state = ClassicalBasisState( + big_endian_int_to_bits(initial_state, bit_count=len(qubits)) + ) + elif isinstance(initial_state, (list, np.ndarray)): + state = ClassicalBasisState(initial_state) + else: + raise ValueError('initial_state must be an int or List[int] or np.ndarray') + super().__init__(state=state, qubits=qubits, classical_data=classical_data) + + def _act_on_fallback_(self, action, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True): + """Acts on the state with a given operation. + + Args: + action: The operation to apply. + qubits: The qubits to apply the operation to. + allow_decompose: Whether to allow decomposition of the operation. + + Returns: + True if the operation was applied successfully. + + Raises: + ValueError: If initial_state shape for type np.ndarray is not equal to 1. + If gate is not one of X, CNOT, SWAP, CCNOT, or a measurement. + """ + if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1: + raise ValueError('initial_state shape for type np.ndarray is not equal to 1') + gate = action.gate if isinstance(action, ops.Operation) else action + mapped_qubits = [self.qubit_map[i] for i in qubits] + if _is_identity(gate): + pass + elif gate == ops.X: + (q,) = mapped_qubits + self._state.basis[q] ^= 1 + elif gate == ops.CNOT: + c, q = mapped_qubits + self._state.basis[q] ^= self._state.basis[c] + elif gate == ops.SWAP: + a, b = mapped_qubits + self._state.basis[a], self._state.basis[b] = self._state.basis[b], self._state.basis[a] + elif gate == ops.TOFFOLI: + c1, c2, q = mapped_qubits + self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2] + else: + raise ValueError(f'{gate} is not one of X, CNOT, SWAP, CCNOT, or a measurement') + return True + + +class ClassicalStateStepResult( + sim.StepResultBase['ClassicalBasisSimState'], Generic[TSimulationState] +): + """The step result provided by `ClassicalStateSimulator.simulate_moment_steps`.""" + + +class ClassicalStateTrialResult( + sim.SimulationTrialResultBase['ClassicalBasisSimState'], Generic[TSimulationState] +): + """The trial result provided by `ClassicalStateSimulator.simulate`.""" + + +class ClassicalStateSimulator( + sim.SimulatorBase[ + ClassicalStateStepResult['ClassicalBasisSimState'], + ClassicalStateTrialResult['ClassicalBasisSimState'], + 'ClassicalBasisSimState', + ], + Generic[TSimulationState], +): + """A simulator that accepts only gates with classical counterparts.""" + + def __init__( + self, *, noise: 'cirq.NOISE_MODEL_LIKE' = None, split_untangled_states: bool = False + ): + """Initializes a ClassicalStateSimulator. + + Args: + noise: The noise model used by the simulator. + split_untangled_states: Whether to run the simulation as a product state. + + Raises: + ValueError: If noise_model is not None. + """ + if noise is not None: + raise ValueError(f'{noise=} is not supported') + super().__init__(noise=noise, split_untangled_states=split_untangled_states) + + def _create_simulator_trial_result( + self, + params: 'cirq.ParamResolver', + measurements: Dict[str, np.ndarray], + final_simulator_state: 'cirq.SimulationStateBase[ClassicalBasisSimState]', + ) -> 'ClassicalStateTrialResult[ClassicalBasisSimState]': + """Creates a trial result for the simulator. + + Args: + params: The parameter resolver for the simulation. + measurements: The measurement results. + final_simulator_state: The final state of the simulator. + Returns: + A trial result for the simulator. + """ + return ClassicalStateTrialResult( + params, measurements, final_simulator_state=final_simulator_state + ) + + def _create_step_result( + self, sim_state: 'cirq.SimulationStateBase[ClassicalBasisSimState]' + ) -> 'ClassicalStateStepResult[ClassicalBasisSimState]': + """Creates a step result for the simulator. + + Args: + sim_state: The current state of the simulator. + Returns: + A step result for the simulator. + """ + return ClassicalStateStepResult(sim_state) + + def _create_partial_simulation_state( + self, + initial_state: Any, + qubits: Sequence['cirq.Qid'], + classical_data: 'cirq.ClassicalDataStore', + ) -> 'ClassicalBasisSimState': + """Creates a partial simulation state for the simulator. + + Args: + initial_state: The initial state for the simulation. + qubits: The qubits associated with the state. + classical_data: The shared classical data container for this simulation. + Returns: + A partial simulation state. + """ + return ClassicalBasisSimState( + initial_state=initial_state, qubits=qubits, classical_data=classical_data + ) diff --git a/cirq-core/cirq/sim/classical_simulator_test.py b/cirq-core/cirq/sim/classical_simulator_test.py index d67fe911d1f..3cf8c170bd8 100644 --- a/cirq-core/cirq/sim/classical_simulator_test.py +++ b/cirq-core/cirq/sim/classical_simulator_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np import pytest import cirq @@ -205,3 +206,99 @@ def test_compatible_measurement(): sim = cirq.ClassicalStateSimulator() res = sim.run(c, repetitions=3).records np.testing.assert_equal(res['key'], np.array([[[0, 0], [1, 1]]] * 3, dtype=np.uint8)) + + +def test_simulate_sweeps_param_resolver(): + q0, q1 = cirq.LineQubit.range(2) + simulator = cirq.ClassicalStateSimulator() + for b0 in [0, 1]: + for b1 in [0, 1]: + circuit = cirq.Circuit( + (cirq.X ** sympy.Symbol('b0'))(q0), (cirq.X ** sympy.Symbol('b1'))(q1) + ) + params = [ + cirq.ParamResolver({'b0': b0, 'b1': b1}), + cirq.ParamResolver({'b0': b1, 'b1': b0}), + ] + results = simulator.simulate_sweep(circuit, params=params) + + assert results[0].params == params[0] + assert results[1].params == params[1] + + +def test_create_partial_simulation_state_from_int_with_no_qubits(): + sim = cirq.ClassicalStateSimulator() + initial_state = 5 + qs = None + classical_data = cirq.value.ClassicalDataDictionaryStore() + with pytest.raises(ValueError): + sim._create_partial_simulation_state( + initial_state=initial_state, qubits=qs, classical_data=classical_data + ) + + +def test_create_partial_simulation_state_from_invalid_state(): + sim = cirq.ClassicalStateSimulator() + initial_state = None + qs = cirq.LineQubit.range(2) + classical_data = cirq.value.ClassicalDataDictionaryStore() + with pytest.raises(ValueError): + sim._create_partial_simulation_state( + initial_state=initial_state, qubits=qs, classical_data=classical_data + ) + + +def test_create_partial_simulation_state_from_int(): + sim = cirq.ClassicalStateSimulator() + initial_state = 15 + qs = cirq.LineQubit.range(4) + classical_data = cirq.value.ClassicalDataDictionaryStore() + expected_result = [1, 1, 1, 1] + result = sim._create_partial_simulation_state( + initial_state=initial_state, qubits=qs, classical_data=classical_data + )._state.basis + assert result == expected_result + + +def test_create_valid_partial_simulation_state_from_list(): + sim = cirq.ClassicalStateSimulator() + initial_state = [1, 1, 1, 1] + qs = cirq.LineQubit.range(4) + classical_data = cirq.value.ClassicalDataDictionaryStore() + expected_result = [1, 1, 1, 1] + result = sim._create_partial_simulation_state( + initial_state=initial_state, qubits=qs, classical_data=classical_data + )._state.basis + assert result == expected_result + + +def test_create_valid_partial_simulation_state_from_np(): + sim = cirq.ClassicalStateSimulator() + initial_state = np.array([1, 1]) + qs = cirq.LineQubit.range(2) + classical_data = cirq.value.ClassicalDataDictionaryStore() + sim_state = sim._create_partial_simulation_state( + initial_state=initial_state, qubits=qs, classical_data=classical_data + ) + sim_state._act_on_fallback_(action=cirq.CX, qubits=qs) + result = sim_state._state.basis + expected_result = np.array([1, 0]) + np.testing.assert_equal(result, expected_result) + + +def test_create_invalid_partial_simulation_state_from_np(): + initial_state = np.array([[1, 1], [1, 1]]) + qs = cirq.LineQubit.range(2) + classical_data = cirq.value.ClassicalDataDictionaryStore() + sim = cirq.ClassicalStateSimulator() + sim_state = sim._create_partial_simulation_state( + initial_state=initial_state, qubits=qs, classical_data=classical_data + ) + with pytest.raises(ValueError): + sim_state._act_on_fallback_(action=cirq.CX, qubits=qs) + + +def test_noise_model(): + noise_model = cirq.NoiseModel.from_noise_model_like(cirq.depolarize(p=0.01)) + with pytest.raises(ValueError): + cirq.ClassicalStateSimulator(noise=noise_model)