Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General sampling estimator #321

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions packages/core/quri_parts/core/estimator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,8 @@ class GeneralQuantumEstimator(Generic[_StateT, _ParametricStateT]):
- Act as :class:`ConcurrentParametricQuantumEstimator`:
- Estimatable, _ParametricStateT, [[float, ...], ...] -> [Estimate, ...]

When a :class:`GeneralEstimator` is called directly with one of the combinations
above, it needs to parse the input arguments to figure out which of
When a :class:`GeneralQuantumEstimator` is called directly with one of the
combinations above, it needs to parse the input arguments to figure out which of
:class:`QuantumEstimator`, :class:`ConcurrentQuantumEstimator`,
:class:`ParametricQuantumEstimator`, or :class:`ConcurrentParametricEstimator`
is required to perform the estimation. To avoid such performance penalty, you may
Expand Down Expand Up @@ -597,12 +597,13 @@ def create_general_estimator_from_estimator(
def create_general_estimator_from_estimator(
estimator: QuantumEstimator[_StateT],
) -> GeneralQuantumEstimator[_StateT, _ParametricStateT]:
"""Creates a :class:`GeneralEstimator` from a :class:`QuantumEstimator`.
"""Creates a :class:`GeneralQuantumEstimator` from a
:class:`QuantumEstimator`.

Note:
- The concurrencies of the :class:`ConcurrentQuantumEstimaror` and
`ConcurrentParametricQuantumEstimaror` will be set to 1 when a
:class:`GeneralEstimator` is created with this function.
:class:`GeneralQuantumEstimator` is created with this function.
- When circuit conversion is involved in the estimator execution, the
parametric estimator created from this function will bind the parameter
first, and then convert the bound circuit every time the patametric estimator
Expand Down Expand Up @@ -645,7 +646,7 @@ def create_general_estimator_from_concurrent_estimator(
def create_general_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[_StateT],
) -> GeneralQuantumEstimator[_StateT, _ParametricStateT]:
"""Creates a :class:`GeneralEstimator` from a
"""Creates a :class:`GeneralQuantumEstimator` from a
:class:`ConcurrentQuantumEstimator`.

Note:
Expand Down
2 changes: 2 additions & 0 deletions packages/core/quri_parts/core/estimator/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .estimator import (
concurrent_sampling_estimate,
create_general_sampling_estimator,
create_sampling_concurrent_estimator,
create_sampling_estimator,
get_estimate_from_sampling_result,
Expand Down Expand Up @@ -81,4 +82,5 @@
"CircuitShotPairPreparationFunction",
"get_sampling_circuits_and_shots",
"distribute_shots_among_pauli_sets",
"create_general_sampling_estimator",
]
27 changes: 26 additions & 1 deletion packages/core/quri_parts/core/estimator/sampling/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
ConcurrentQuantumEstimator,
Estimatable,
Estimate,
GeneralQuantumEstimator,
QuantumEstimator,
create_general_estimator_from_estimator,
)
from quri_parts.core.estimator.sampling.pauli import (
general_pauli_sum_expectation_estimator,
Expand All @@ -35,7 +37,7 @@
MeasurementCounts,
PauliSamplingShotsAllocator,
)
from quri_parts.core.state import CircuitQuantumState
from quri_parts.core.state import CircuitQuantumState, ParametricCircuitQuantumState

from .estimator_helpers import (
CircuitShotPairPreparationFunction,
Expand Down Expand Up @@ -286,3 +288,26 @@ def estimator(
)

return estimator


def create_general_sampling_estimator(
total_shots: int,
sampler: ConcurrentSampler,
measurement_factory: CommutablePauliSetMeasurementFactory,
shots_allocator: PauliSamplingShotsAllocator,
) -> GeneralQuantumEstimator[CircuitQuantumState, ParametricCircuitQuantumState]:
"""Creates a :class:`GeneralQuantumEstimator` that performs sampling
estimation.

Args:
total_shots: Total number of shots available for sampling measurements.
sampler: A Sampler that actually performs the sampling measurements.
measurement_factory: A function that performs Pauli grouping and returns
a measurement scheme for Pauli operators constituting the original operator.
shots_allocator: A function that allocates the total shots to Pauli groups to
be measured.
"""
sampling_estimator = create_sampling_estimator(
total_shots, sampler, measurement_factory, shots_allocator
)
return create_general_estimator_from_estimator(sampling_estimator)
121 changes: 119 additions & 2 deletions packages/core/tests/core/estimator/sampling/test_sampling_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from collections.abc import Collection, Iterable
from math import sqrt
from typing import Any, Union, cast
from unittest.mock import Mock

import numpy as np
import pytest

from quri_parts.circuit import H, NonParametricQuantumCircuit, QuantumCircuit, X
from quri_parts.circuit import (
H,
NonParametricQuantumCircuit,
QuantumCircuit,
UnboundParametricQuantumCircuit,
X,
)
from quri_parts.core.estimator import Estimate
from quri_parts.core.estimator.sampling import (
concurrent_sampling_estimate,
create_general_sampling_estimator,
create_sampling_concurrent_estimator,
create_sampling_estimator,
get_estimate_from_sampling_result,
Expand All @@ -44,7 +53,11 @@
from quri_parts.core.sampling.shots_allocator import (
create_equipartition_shots_allocator,
)
from quri_parts.core.state import CircuitQuantumState, ComputationalBasisState
from quri_parts.core.state import (
CircuitQuantumState,
ComputationalBasisState,
ParametricCircuitQuantumState,
)

n_qubits = 3

Expand Down Expand Up @@ -510,3 +523,107 @@ def test_sampling_concurrent_estimator(self) -> None:
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8


class GeneralSamplingEstimator(unittest.TestCase):
def setUp(self) -> None:
s = mock_sampler()
self.general_estimator = create_general_sampling_estimator(
total_shots(),
s,
bitwise_commuting_pauli_measurement,
allocator,
)

def test_general_quantum_estimator(self) -> None:
estimate = self.general_estimator(operator(), initial_state())
assert_sample(estimate)

def test_concurrent_estimate(self) -> None:
estimates = self.general_estimator(
operator(),
[initial_state(), ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])

estimates = self.general_estimator(
[operator()],
[initial_state(), ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])

estimates = self.general_estimator(
[operator(), pauli_label("Z0")],
ComputationalBasisState(3, bits=0b001),
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8

estimates = self.general_estimator(
[operator(), pauli_label("Z0")],
[ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8

estimates = self.general_estimator(
[operator(), pauli_label("Z0")],
[initial_state(), ComputationalBasisState(3, bits=0b001)],
)

estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert estimate_list[1].value == (1 - 1 + 2 - 4) / 8

def test_parametric_estimate(self) -> None:
circuit = UnboundParametricQuantumCircuit(n_qubits)
circuit.add_X_gate(0)
circuit.add_ParametricRX_gate(0)
circuit.add_ParametricRY_gate(1)
circuit.add_ParametricRZ_gate(2)

state = ParametricCircuitQuantumState(n_qubits, circuit)

estimate = self.general_estimator(operator(), state, [0, 1, 2])
assert_sample(estimate)

estimate = self.general_estimator(operator(), state, np.array([0, 1, 2]))
assert_sample(estimate)

def test_concurrent_parametric_estimate(self) -> None:
circuit = UnboundParametricQuantumCircuit(n_qubits)
circuit.add_X_gate(0)
circuit.add_ParametricRX_gate(0)
circuit.add_ParametricRY_gate(1)
circuit.add_ParametricRZ_gate(2)

state = ParametricCircuitQuantumState(n_qubits, circuit)

estimates = self.general_estimator(operator(), state, [[0, 1, 2], [4, 5, 6]])
estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])

estimates = self.general_estimator(
operator(), state, np.array([[0, 1, 2], [4, 5, 6]])
)
estimate_list = list(estimates)
assert len(estimate_list) == 2
assert_sample(estimate_list[0])
assert_sample(estimate_list[1])
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_with_vector(self) -> None:
assert np.isclose(estimate_list[1].value, 12 * np.sqrt(2))


class TestGeneralEstimator(unittest.TestCase):
class TestGeneralQuantumEstimator(unittest.TestCase):
def setUp(self) -> None:
self.op_0 = PAULI_IDENTITY
self.op_1 = Operator({pauli_label("X0"): 1, pauli_label("Y0"): 1})
Expand Down
Loading