Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic general estimator creation #293

Merged
merged 17 commits into from
Dec 26, 2023
337 changes: 336 additions & 1 deletion packages/core/quri_parts/core/estimator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@

from abc import abstractproperty
from collections.abc import Iterable, Sequence
from typing import Callable, Optional, Protocol, TypeVar, Union, cast, overload
from dataclasses import dataclass
from typing import Callable, Generic, Optional, Protocol, TypeVar, Union, cast, overload

from typing_extensions import TypeAlias

from quri_parts.core.operator import Operator, PauliLabel
from quri_parts.core.state import (
CircuitQuantumState,
ParametricCircuitQuantumState,
ParametricQuantumStateT,
ParametricQuantumStateVector,
QuantumStateT,
QuantumStateVector,
)

Expand Down Expand Up @@ -329,3 +332,335 @@ def parametric_estimator(
return estimator(bound_kets, bound_bras, weights)

return parametric_estimator


@overload
def create_concurrent_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[_StateT],
) -> ConcurrentParametricQuantumEstimator[_ParametricStateT]:
...


@overload
def create_concurrent_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[CircuitQuantumState],
) -> ConcurrentParametricQuantumEstimator[ParametricCircuitQuantumState]:
...


@overload
def create_concurrent_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[QuantumStateVector],
) -> ConcurrentParametricQuantumEstimator[ParametricQuantumStateVector]:
...


def create_concurrent_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[_StateT],
) -> ConcurrentParametricQuantumEstimator[_ParametricStateT]:
"""Creates a concurrent parametric estimator from a concurrent
estimator."""

def concurrent_parametric_estimator(
operator: Estimatable,
state: _ParametricStateT,
seq_of_params: Sequence[Sequence[float]],
) -> Iterable[Estimate[complex]]:
bound_states = cast(
Sequence[_StateT], [state.bind_parameters(param) for param in seq_of_params]
)
return concurrent_estimator([operator], bound_states)

return concurrent_parametric_estimator


@overload
def create_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[_StateT],
) -> ParametricQuantumEstimator[_ParametricStateT]:
...


@overload
def create_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[CircuitQuantumState],
) -> ParametricQuantumEstimator[ParametricCircuitQuantumState]:
...


@overload
def create_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[QuantumStateVector],
) -> ParametricQuantumEstimator[ParametricQuantumStateVector]:
...


def create_parametric_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[_StateT],
) -> ParametricQuantumEstimator[_ParametricStateT]:
"""Creates a parametric estimator from a concurrent estimator."""

def parametric_estimator(
operator: Estimatable,
state: _ParametricStateT,
params: Sequence[float],
) -> Estimate[complex]:
bound_states = cast(_StateT, state.bind_parameters(params))
estimate = concurrent_estimator([operator], [bound_states])
return next(iter(estimate))

return parametric_estimator


def create_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[_StateT],
) -> QuantumEstimator[_StateT]:
"""Creates an estimator from a concurrent estimator."""

def estimator(
operator: Estimatable,
state: _StateT,
) -> Estimate[complex]:
return next(iter(concurrent_estimator([operator], [state])))

return estimator


def create_concurrent_estimator_from_estimator(
estimator: QuantumEstimator[_StateT],
) -> ConcurrentQuantumEstimator[_StateT]:
"""Creates a concurrent estimator from an estimator."""

def concurrent_estimator(
operators: Sequence[Estimatable],
states: Sequence[_StateT],
) -> Sequence[Estimate[complex]]:
num_ops = len(operators)
num_states = len(states)

if num_ops == 0:
raise ValueError("No operator specified.")

if num_states == 0:
raise ValueError("No state specified.")

if num_ops > 1 and num_states > 1 and num_ops != num_states:
raise ValueError(
f"Number of operators ({num_ops}) does not match"
f"number of states ({num_states})."
)

if num_states == 1:
states = [next(iter(states))] * num_ops

if num_ops == 1:
operators = [next(iter(operators))] * num_states

return [estimator(op, state) for op, state in zip(operators, states)]

return concurrent_estimator


@dataclass
class GeneralQuantumEstimator(Generic[QuantumStateT, ParametricQuantumStateT]):
r"""A callable dataclass that holds :class:`QuantumEstimator`,
:class:`ConcurrentQuantumEstimator`, :class:`ParametricQuantumEstimator`,
or :class:`ConcurrentParametricEstimator`. When it is used as a callable function,
it allows generic inputs for expectation value estimation. The allowed inputs for
using it as a callable function are:

- Act as :class:`QuantumEstimator`:
- Estimatable, QuantumStateT -> Estimate
- Act as :class:`ConcurrentQuantumEstimator`:
- Estimatable, [QuantumStateT, ...] -> [Estimate, ...]
- [Estimatable], [QuantumStateT, ...] -> [Estimate, ...]
- [Estimatable, ...], QuantumStateT -> [Estimate, ...]
- [Estimatable, ...], [QuantumStateT] -> [Estimate, ...]
- [Estimatable, ...], [QuantumStateT, ...] -> [Estimate, ...]
- Act as :class:`ParametricQuantumEstimator`:
- Estimatable, ParametricQuantumStateT, [float, ...] -> Estimate
- Act as :class:`ConcurrentParametricQuantumEstimator`:
- Estimatable, ParametricQuantumStateT, [[float, ...], ...] -> [Estimate, ...]

When a :class:`GeneralEstimator` is called directly with one of the combinations
above, it needs to parse the input arguments to figure out which of
:class:`QuantumEstimator`, :class:`ConcurrentQuantumEstimator`,
:class:`ParametricQuantumEstimator`, or :class:`ConcurrentParametricEstimator`
is required to perform the estimation. To avoid such performance penalty, you may
retrieve the desired estimator as a property directly.
"""

estimator: QuantumEstimator[QuantumStateT]
concurrent_estimator: ConcurrentQuantumEstimator[QuantumStateT]
parametric_estimator: ParametricQuantumEstimator[ParametricQuantumStateT]
concurrent_parametric_estimator: ConcurrentParametricQuantumEstimator[
ParametricQuantumStateT
]

@overload
def __call__(self, op: Estimatable, state: QuantumStateT) -> Estimate[complex]:
"""A :class:`QuantumEstimator`"""
...

@overload
def __call__(
self, op: Sequence[Estimatable], state: Sequence[QuantumStateT]
) -> Iterable[Estimate[complex]]:
"""A :class:`ConcurrentQuantumEstimator`"""
...

@overload
def __call__(
self, op: Estimatable, state: Sequence[QuantumStateT]
) -> Iterable[Estimate[complex]]:
"""A :class:`ConcurrentQuantumEstimator`"""
...

@overload
def __call__(
self, op: Sequence[Estimatable], state: QuantumStateT
) -> Iterable[Estimate[complex]]:
"""A :class:`ConcurrentQuantumEstimator`"""
...

@overload
def __call__(
self,
op: Estimatable,
state: ParametricQuantumStateT,
param: Sequence[float],
) -> Estimate[complex]:
"""A :class:`ParametricQuantumEstimator`"""
...

@overload
def __call__(
self,
op: Estimatable,
state: ParametricQuantumStateT,
param: Sequence[Sequence[float]],
) -> Iterable[Estimate[complex]]:
"""A :class:`ConcurrentParametricQuantumEstimator`"""
...

def __call__(
self,
op: Union[Estimatable, Sequence[Estimatable]],
state: Union[QuantumStateT, Sequence[QuantumStateT], ParametricQuantumStateT],
param: Optional[Union[Sequence[float], Sequence[Sequence[float]]]] = None,
) -> Union[Estimate[complex], Iterable[Estimate[complex]]]:
if param is None:
if isinstance(op, Operator) or isinstance(op, PauliLabel):
if isinstance(state, Sequence):
return self.concurrent_estimator([op], state)
state = cast(QuantumStateT, state)
return self.estimator(op, state)

if isinstance(state, Sequence):
return self.concurrent_estimator(op, state)
state = cast(QuantumStateT, state)
return self.concurrent_estimator(op, [state])

assert not isinstance(state, Sequence)
assert isinstance(op, Operator) or isinstance(op, PauliLabel)

state = cast(ParametricQuantumStateT, state)
if isinstance(param[0], Sequence):
param = cast(Sequence[Sequence[float]], param)
return self.concurrent_parametric_estimator(op, state, param)
param = cast(Sequence[float], param)
return self.parametric_estimator(op, state, param)


@overload
def create_general_estimator_from_estimator(
estimator: QuantumEstimator[CircuitQuantumState],
) -> GeneralQuantumEstimator[CircuitQuantumState, ParametricCircuitQuantumState]:
...


@overload
def create_general_estimator_from_estimator(
estimator: QuantumEstimator[QuantumStateVector],
) -> GeneralQuantumEstimator[QuantumStateVector, ParametricQuantumStateVector]:
...


def create_general_estimator_from_estimator(
estimator: QuantumEstimator[QuantumStateT],
) -> GeneralQuantumEstimator[QuantumStateT, ParametricQuantumStateT]:
"""Creates a :class:`GeneralEstimator` from a :class:`QuantumEstimator`.

Note:
- The concurrencies of the :class:`ConcurrentQuantumEstimaror` and
`ConcurrentParametricQuantumEstimaror` will be set to 1 when a
:class:`GeneralEstimator` is created with this function.
- When circuit conversion is involved in the estimator execution, the
parametric estimator created from this function will bind the parameter
first, and then convert the bound circuit every time the patametric estimator
is called.
"""
concurrent_estimator = create_concurrent_estimator_from_estimator(estimator)
parametric_estimator: ParametricQuantumEstimator[
ParametricQuantumStateT
] = create_parametric_estimator_from_concurrent_estimator(concurrent_estimator)

concurrent_parametric_estimator: ConcurrentParametricQuantumEstimator[
ParametricQuantumStateT
] = create_concurrent_parametric_estimator_from_concurrent_estimator(
concurrent_estimator
)
general_estimator = GeneralQuantumEstimator(
estimator,
concurrent_estimator,
parametric_estimator,
concurrent_parametric_estimator,
)

return general_estimator


@overload
def create_general_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[CircuitQuantumState],
) -> GeneralQuantumEstimator[CircuitQuantumState, ParametricCircuitQuantumState]:
...


@overload
def create_general_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[QuantumStateVector],
) -> GeneralQuantumEstimator[QuantumStateVector, ParametricQuantumStateVector]:
...


def create_general_estimator_from_concurrent_estimator(
concurrent_estimator: ConcurrentQuantumEstimator[QuantumStateT],
) -> GeneralQuantumEstimator[QuantumStateT, ParametricQuantumStateT]:
"""Creates a :class:`GeneralEstimator` from a
:class:`ConcurrentQuantumEstimator`.

Note:
- When circuit conversion is involved in the estimator execution, the
parametric estimator created from this function will bind the parameter
first, and then convert the bound circuit every time the patametric estimator
is called.
"""
estimator = create_estimator_from_concurrent_estimator(concurrent_estimator)
parametric_estimator: ParametricQuantumEstimator[
ParametricQuantumStateT
] = create_parametric_estimator_from_concurrent_estimator(concurrent_estimator)

concurrent_parametric_estimator: ConcurrentParametricQuantumEstimator[
ParametricQuantumStateT
] = create_concurrent_parametric_estimator_from_concurrent_estimator(
concurrent_estimator
)
general_estimator = GeneralQuantumEstimator(
estimator,
concurrent_estimator,
parametric_estimator,
concurrent_parametric_estimator,
)

return general_estimator
Loading
Loading