Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caching grouped results and refactor sampling estimator #302

Merged
merged 23 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions packages/core/quri_parts/core/estimator/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@
concurrent_sampling_estimate,
create_sampling_concurrent_estimator,
create_sampling_estimator,
get_estimate_from_sampling_result,
sampling_estimate,
)
from .estimator_helpers import (
CircuitShotPairPreparationFunction,
distribute_shots_among_pauli_sets,
get_sampling_circuits_and_shots,
)
from .overlap_estimator import (
create_sampling_overlap_estimator,
create_sampling_overlap_weighted_sum_estimator,
Expand Down Expand Up @@ -59,14 +65,20 @@
"general_pauli_sum_expectation_estimator",
"general_pauli_covariance_estimator",
"general_pauli_sum_sample_variance",
"get_estimate_from_sampling_result",
"trivial_pauli_expectation_estimator",
"trivial_pauli_covariance_estimator",
"sampling_estimate",
"create_sampling_estimator",
"create_fixed_operator_sampling_esimator",
"create_fixed_operator_sampling_concurrent_esimator",
"concurrent_sampling_estimate",
"create_sampling_concurrent_estimator",
"sampling_overlap_estimate",
"create_sampling_overlap_estimator",
"sampling_overlap_weighted_sum_estimate",
"create_sampling_overlap_weighted_sum_estimator",
"CircuitShotPairPreparationFunction",
"get_sampling_circuits_and_shots",
"distribute_shots_among_pauli_sets",
]
77 changes: 46 additions & 31 deletions packages/core/quri_parts/core/estimator/sampling/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
general_pauli_sum_sample_variance,
)
from quri_parts.core.measurement import (
CommutablePauliSetMeasurement,
CommutablePauliSetMeasurementFactory,
PauliReconstructorFactory,
)
Expand All @@ -35,6 +36,12 @@
)
from quri_parts.core.state import CircuitQuantumState

from .estimator_helpers import (
CircuitShotPairPreparationFunction,
distribute_shots_among_pauli_sets,
get_sampling_circuits_and_shots,
)


class _Estimate:
def __init__(
Expand Down Expand Up @@ -82,13 +89,27 @@ class _ConstEstimate:
error: float = 0.0


def get_estimate_from_sampling_result(
op: Operator,
measurement_groups: Iterable[CommutablePauliSetMeasurement],
const: complex,
sampling_counts: Iterable[MeasurementCounts],
) -> Estimate[complex]:
"""Converts sampling counts into the estimation of the operator's
expectation value."""
pauli_sets = tuple(m.pauli_set for m in measurement_groups)
pauli_recs = tuple(m.pauli_reconstructor_factory for m in measurement_groups)
return _Estimate(op, const, pauli_sets, pauli_recs, tuple(sampling_counts))


def sampling_estimate(
op: Estimatable,
state: CircuitQuantumState,
total_shots: int,
sampler: ConcurrentSampler,
measurement_factory: CommutablePauliSetMeasurementFactory,
shots_allocator: PauliSamplingShotsAllocator,
circuit_shot_pair_prep_fn: CircuitShotPairPreparationFunction = get_sampling_circuits_and_shots, # noqa: E501
) -> Estimate[complex]:
"""Estimate expectation value of a given operator with a given state by
sampling measurement.
Expand All @@ -104,7 +125,10 @@ def sampling_estimate(
a measurement scheme for Pauli operators constituting the original operator.
shots_allocator: A function that allocates the total shots to Pauli groups to
be measured.

circuit_shot_pair_prep_fn: A :class:`~CircuitShotPairPreparationFunction` that
prepares the set of circuits to perform measurement with. It is default to
a function that concatenates the measurement circuits after the state
preparation circuit.
Returns:
The estimated value (can be accessed with :attr:`.value`) with standard error
of estimation (can be accessed with :attr:`.error`).
Expand All @@ -118,36 +142,16 @@ def sampling_estimate(
if len(op) == 1 and PAULI_IDENTITY in op:
return _ConstEstimate(op[PAULI_IDENTITY])

# If there is a standalone Identity group then eliminate, else set const 0.
const: complex = 0.0
measurements = []
for m in measurement_factory(op):
if m.pauli_set == {PAULI_IDENTITY}:
const = op[PAULI_IDENTITY]
else:
measurements.append(m)

pauli_sets = tuple(m.pauli_set for m in measurements)
shot_allocs = shots_allocator(op, pauli_sets, total_shots)
shots_map = {pauli_set: n_shots for pauli_set, n_shots in shot_allocs}

# Eliminate pauli sets which are allocated no shots
measurement_circuit_shots = [
(m, state.circuit + m.measurement_circuit, shots_map[m.pauli_set])
for m in measurements
if shots_map[m.pauli_set] > 0
]

circuit_and_shots = [
(circuit, shots) for _, circuit, shots in measurement_circuit_shots
]
sampling_counts = sampler(circuit_and_shots)
const = op.constant
measurements = measurement_factory(op)
measurements = [m for m in measurements if m.pauli_set != {PAULI_IDENTITY}]

pauli_sets = tuple(m.pauli_set for m, _, _ in measurement_circuit_shots)
pauli_recs = tuple(
m.pauli_reconstructor_factory for m, _, _ in measurement_circuit_shots
shots_map = distribute_shots_among_pauli_sets(
op, measurements, shots_allocator, total_shots
)
return _Estimate(op, const, pauli_sets, pauli_recs, tuple(sampling_counts))
circuit_and_shots = circuit_shot_pair_prep_fn(state, measurements, shots_map)
sampling_counts = sampler(circuit_and_shots)
return get_estimate_from_sampling_result(op, measurements, const, sampling_counts)


def create_sampling_estimator(
Expand Down Expand Up @@ -185,6 +189,7 @@ def concurrent_sampling_estimate(
sampler: ConcurrentSampler,
measurement_factory: CommutablePauliSetMeasurementFactory,
shots_allocator: PauliSamplingShotsAllocator,
circuit_shot_pair_prep_fn: CircuitShotPairPreparationFunction = get_sampling_circuits_and_shots, # noqa: E501
) -> Iterable[Estimate[complex]]:
"""Estimate expectation value of given operators with given states by
sampling measurement.
Expand All @@ -200,7 +205,10 @@ def concurrent_sampling_estimate(
a measurement scheme for Pauli operators constituting the original operator.
shots_allocator: A function that allocates the total shots to Pauli groups to
be measured.

circuit_shot_pair_prep_fn: A :class:`~CircuitShotPairPreparationFunction` that
prepares the set of circuits to perform measurement with. It is default to
a function that concatenates the measurement circuits after the state
preparation circuit.
Returns:
The estimated values (can be accessed with :attr:`.value`) with standard errors
of estimation (can be accessed with :attr:`.error`).
Expand All @@ -224,9 +232,16 @@ def concurrent_sampling_estimate(
states = [next(iter(states))] * num_ops
if num_ops == 1:
operators = [next(iter(operators))] * num_states

return [
sampling_estimate(
op, state, total_shots, sampler, measurement_factory, shots_allocator
op,
state,
total_shots,
sampler,
measurement_factory,
shots_allocator,
circuit_shot_pair_prep_fn,
)
for op, state in zip(operators, states)
]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable
from typing import Callable

from typing_extensions import TypeAlias

from quri_parts.circuit import NonParametricQuantumCircuit
from quri_parts.core.measurement import CommutablePauliSetMeasurement
from quri_parts.core.operator import CommutablePauliSet, Operator
from quri_parts.core.sampling import PauliSamplingShotsAllocator
from quri_parts.core.state import CircuitQuantumState

#: A function that returns the sequence of (circuit, shot) pairs for performing
#: sampling estimation on the given state. The default operation is that
#: it concatenates the measurement circuits determined by the grouping scheme
#: to the circuit that prepares the state. This is done with the
#: `circuit_shot_pairs_preparation_fn` below. Users may customize this function if
#: additional circuit operations needs to be done other than simple concatenation.
CircuitShotPairPreparationFunction: TypeAlias = Callable[
[
CircuitQuantumState,
Iterable[CommutablePauliSetMeasurement],
dict[CommutablePauliSet, int],
],
Iterable[tuple[NonParametricQuantumCircuit, int]],
]


def distribute_shots_among_pauli_sets(
operator: Operator,
measurement_groups: Iterable[CommutablePauliSetMeasurement],
shots_allocator: PauliSamplingShotsAllocator,
total_shots: int,
) -> dict[CommutablePauliSet, int]:
"""Distribute shots to each commuting pauli sets.

Args:
operator: The operator to be measured.
measurement_groups: Sequence of :class:`~CommutablePauliSetMeasurement` that
corresponds to the grouping result of the operator.
shot_allocator: A function that allocates the total shots to Pauli groups to
be measured.
total_shots: Total number of shots available for sampling measurements.
"""
pauli_sets = {m.pauli_set for m in measurement_groups}
shot_allocs = shots_allocator(operator, pauli_sets, total_shots)
return {pauli_set: n_shots for pauli_set, n_shots in shot_allocs}


def get_sampling_circuits_and_shots(
state: CircuitQuantumState,
measurement_groups: Iterable[CommutablePauliSetMeasurement],
shots_map: dict[CommutablePauliSet, int],
) -> Iterable[tuple[NonParametricQuantumCircuit, int]]:
"""Sets up the (circuit, shot) pairs for performing sampling estimation.
The circuit is given by the measurement circuit concatenated after the
circuit held inside the state.

Args:
state: The state on which the expectation value is estimated.
measurement_groups: Sequence of :class:`~CommutablePauliSetMeasurement` that
corresponds to the grouping result of the operator.
shots_map: A dictionary whose key is the commuting pauli set and the value is
the shot count assigned to the commuting pauli set.
"""
return [
(
state.circuit + m.measurement_circuit,
n_shots,
)
for m in measurement_groups
if (n_shots := shots_map[m.pauli_set]) > 0
]
57 changes: 57 additions & 0 deletions packages/core/quri_parts/core/measurement/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Union

from quri_parts.core.operator import Operator, PauliLabel

from .bitwise_commuting_pauli import (
bitwise_commuting_pauli_measurement,
bitwise_commuting_pauli_measurement_circuit,
Expand Down Expand Up @@ -39,6 +43,58 @@
#: commutable Pauli operators and returns measurement schemes for them.
CommutablePauliSetMeasurementFactory = CommutablePauliSetMeasurementFactory


class CachedMeasurementFactory:
"""A class decorator that converts a
:class:`CommutablePauliSetMeasurementFactory` to a new
:class:`CommutablePauliSetMeasurementFactory` runs the same grouping
algorithm but caches grouping result for later usage.

Example:
>>> cached_measurement_factory = CachedMeasuremetFactory(
... bitwise_commuting_pauli_measurement
... )
>>> operator = Operator({
... pauli_label("X0 Y1"): 1,
... pauli_label("X0 Z2"): 2,
... pauli_label("Y0 Z2"): 3,
... PAULI_IDENTITY: 4
... })
>>> cached_measurement_factory(operator)
"""

def __init__(
self, measurement_factory: CommutablePauliSetMeasurementFactory
) -> None:
self._measurement_factory = measurement_factory
self._cache: dict[
frozenset[tuple[PauliLabel, complex]],
Iterable[CommutablePauliSetMeasurement],
] = {}

def __call__(
self, paulis: Union[Operator, Iterable[PauliLabel]]
) -> Iterable[CommutablePauliSetMeasurement]:
if not isinstance(paulis, Operator):
paulis = Operator({p: 1 + 0j for p in paulis})

op_key = frozenset(paulis.items())
if op_key in self._cache:
return self._cache[op_key]
groups = self._measurement_factory(paulis)
self._cache[op_key] = groups
return groups

@property
def cached_groups(
self,
) -> dict[
frozenset[tuple[PauliLabel, complex]],
Iterable[CommutablePauliSetMeasurement],
]:
return self._cache.copy()


__all__ = [
"PauliMeasurementCircuitGeneration",
"PauliReconstructor",
Expand All @@ -50,4 +106,5 @@
"bitwise_pauli_reconstructor_factory",
"bitwise_commuting_pauli_measurement",
"individual_pauli_measurement",
"CachedMeasuremetFactory",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from quri_parts.circuit import QuantumCircuit
from quri_parts.core.estimator.sampling import (
distribute_shots_among_pauli_sets,
get_sampling_circuits_and_shots,
)
from quri_parts.core.measurement import bitwise_commuting_pauli_measurement
from quri_parts.core.operator import CommutablePauliSet, Operator, pauli_label
from quri_parts.core.sampling.shots_allocator import (
create_equipartition_shots_allocator,
)
from quri_parts.core.state import GeneralCircuitQuantumState


def test_distribute_shots_among_pauli_sets() -> None:
operator = Operator({pauli_label("X0 X1"): 1, pauli_label("Y0 X1"): 1})
groups = bitwise_commuting_pauli_measurement(operator)
shots_allocator = create_equipartition_shots_allocator()
expected_distribution = {
frozenset({pauli_label("X0 X1")}): 500,
frozenset({pauli_label("Y0 X1")}): 500,
}

distribution = distribute_shots_among_pauli_sets(
operator, groups, shots_allocator, total_shots=1000
)
assert distribution == expected_distribution


def test_get_sampling_circuits_and_shots() -> None:
circuit = QuantumCircuit(2)
circuit.add_H_gate(0)
circuit.add_CNOT_gate(0, 1)
state = GeneralCircuitQuantumState(2, circuit)

operator = Operator({pauli_label("X0 X1"): 1, pauli_label("Y0 X1"): 1})
groups = bitwise_commuting_pauli_measurement(operator)
distribution: dict[CommutablePauliSet, int] = {
frozenset({pauli_label("X0 X1")}): 500,
frozenset({pauli_label("Y0 X1")}): 500,
}

pairs = get_sampling_circuits_and_shots(state, groups, distribution)
assert len(list(pairs)) == 2
for p, g in zip(pairs, groups):
assert p == (circuit.combine(g.measurement_circuit), 500)
Loading
Loading