Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BasePrimitiveJob class and deprecate PrimitiveJob.submit method #11552

Merged
merged 4 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qiskit/primitives/backend_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _run(
job = PrimitiveJob(
self._call, circuit_indices, observable_indices, parameter_values, **run_options
)
job.submit()
job._submit()
return job

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion qiskit/primitives/backend_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,5 +205,5 @@ def _run(
self._circuits.append(circuit)
self._parameters.append(circuit.parameters)
job = PrimitiveJob(self._call, circuit_indices, parameter_values, **run_options)
job.submit()
job._submit()
return job
78 changes: 78 additions & 0 deletions qiskit/primitives/base/base_primitive_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Primitive job abstract base class
"""

from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Union

from ..containers import PrimitiveResult
from .base_result import BasePrimitiveResult

Result = TypeVar("Result", bound=Union[BasePrimitiveResult, PrimitiveResult])
Status = TypeVar("Status")
ihincks marked this conversation as resolved.
Show resolved Hide resolved


class BasePrimitiveJob(ABC, Generic[Result, Status]):
"""Primitive job abstract base class."""

def __init__(self, job_id: str, **kwargs) -> None:
"""Initializes the primitive job.

Args:
job_id: a unique id in the context of the primitive used to run the job.
t-imamichi marked this conversation as resolved.
Show resolved Hide resolved
kwargs: Any key value metadata to associate with this job.
"""
self._job_id = job_id
self.metadata = kwargs

def job_id(self) -> str:
"""Return a unique id identifying the job."""
return self._job_id

@abstractmethod
def result(self) -> Result:
"""Return the results of the job."""
raise NotImplementedError("Subclass of BasePrimitiveJob must implement `result` method.")

@abstractmethod
def status(self) -> Status:
"""Return the status of the job."""
raise NotImplementedError("Subclass of BasePrimitiveJob must implement `status` method.")

@abstractmethod
def done(self) -> bool:
"""Return whether the job has successfully run."""
raise NotImplementedError("Subclass of BasePrimitiveJob must implement `done` method.")

@abstractmethod
def running(self) -> bool:
"""Return whether the job is actively running."""
raise NotImplementedError("Subclass of BasePrimitiveJob must implement `running` method.")

@abstractmethod
def cancelled(self) -> bool:
"""Return whether the job has been cancelled."""
raise NotImplementedError("Subclass of BasePrimitiveJob must implement `cancelled` method.")

@abstractmethod
def in_final_state(self) -> bool:
"""Return whether the job is in a final job state such as ``DONE`` or ``ERROR``."""
raise NotImplementedError(
"Subclass of BasePrimitiveJob must implement `is_final_state` method."
)

@abstractmethod
def cancel(self):
"""Attempt to cancel the job."""
raise NotImplementedError("Subclass of BasePrimitiveJob must implement `cancel` method.")
2 changes: 1 addition & 1 deletion qiskit/primitives/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,5 @@ def _run(
job = PrimitiveJob(
self._call, circuit_indices, observable_indices, parameter_values, **run_options
)
job.submit()
job._submit()
return job
99 changes: 81 additions & 18 deletions qiskit/primitives/primitive_job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
# (C) Copyright IBM 2022, 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
Expand All @@ -13,53 +13,51 @@
Job implementation for the reference implementations of Primitives.
"""

import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Generic, TypeVar, Union
from typing import Callable, Generic, Optional, TypeVar, Union

from qiskit.providers import JobError, JobStatus, JobV1
from qiskit.providers import JobError, JobStatus, JobTimeoutError
from qiskit.providers.jobstatus import JOB_FINAL_STATES
from qiskit.utils.deprecation import deprecate_func

from .base.base_primitive_job import BasePrimitiveJob
from .base.base_result import BasePrimitiveResult
from .containers import PrimitiveResult

T = TypeVar("T", bound=Union[BasePrimitiveResult, PrimitiveResult])
Result = TypeVar("Result", bound=Union[BasePrimitiveResult, PrimitiveResult])
t-imamichi marked this conversation as resolved.
Show resolved Hide resolved


class PrimitiveJob(JobV1, Generic[T]):
class PrimitiveJob(BasePrimitiveJob[Result, JobStatus], Generic[Result]):
t-imamichi marked this conversation as resolved.
Show resolved Hide resolved
"""
PrimitiveJob class for the reference implemetations of Primitives.
Primitive job class for the reference implementations of Primitives.
"""

def __init__(self, function, *args, **kwargs):
"""
Args:
function: a callable function to execute the job.
"""
job_id = str(uuid.uuid4())
super().__init__(None, job_id)
super().__init__(str(uuid.uuid4()))
self._future = None
self._function = function
self._args = args
self._kwargs = kwargs

def submit(self):
def _submit(self):
if self._future is not None:
raise JobError("Primitive job has already been submitted.")
raise JobError("Primitive job has been submitted already.")

executor = ThreadPoolExecutor(max_workers=1) # pylint: disable=consider-using-with
self._future = executor.submit(self._function, *self._args, **self._kwargs)
executor.shutdown(wait=False)

def result(self) -> T:
"""Return the results of the job."""
def result(self) -> Result:
self._check_submitted()
return self._future.result()

def cancel(self):
self._check_submitted()
return self._future.cancel()

def status(self):
def status(self) -> JobStatus:
self._check_submitted()
if self._future.running():
return JobStatus.RUNNING
Expand All @@ -71,4 +69,69 @@ def status(self):

def _check_submitted(self):
if self._future is None:
raise JobError("Job not submitted yet!. You have to .submit() first!")
raise JobError("Primitive Job has not been submitted yet.")

def cancel(self):
self._check_submitted()
return self._future.cancel()

def done(self) -> bool:
return self.status() == JobStatus.DONE

def running(self) -> bool:
return self.status() == JobStatus.RUNNING

def cancelled(self) -> bool:
return self.status() == JobStatus.CANCELLED

def in_final_state(self) -> bool:
return self.status() in JOB_FINAL_STATES

@deprecate_func(since="0.46.0")
def submit(self):
"""Submit a job.

.. deprecated:: 0.46.0
``submit`` method is deprecated as of Qiskit 0.46 and will be removed
no earlier than 3 months after the release date.

"""
self._submit()

@deprecate_func(since="0.46.0")
def wait_for_final_state(
self, timeout: Optional[float] = None, wait: float = 5, callback: Optional[Callable] = None
) -> None:
"""Poll the job status until it progresses to a final state such as ``DONE`` or ``ERROR``.

.. deprecated:: 0.46.0
``wait_for_final_state`` method is deprecated as of Qiskit 0.46 and will be removed
no earlier than 3 months after the release date.

Args:
timeout: Seconds to wait for the job. If ``None``, wait indefinitely.
wait: Seconds between queries.
callback: Callback function invoked after each query.
The following positional arguments are provided to the callback function:

* job_id: Job ID
* job_status: Status of the job from the last query
* job: This BaseJob instance

Note: different subclass might provide different arguments to
the callback function.

Raises:
JobTimeoutError: If the job does not reach a final state before the
specified timeout.
"""
start_time = time.time()
status = self.status()
while status not in JOB_FINAL_STATES:
elapsed_time = time.time() - start_time
if timeout is not None and elapsed_time >= timeout:
raise JobTimeoutError(f"Timeout while waiting for job {self.job_id()}.")
if callback:
callback(self.job_id(), status, self)
time.sleep(wait)
status = self.status()
2 changes: 1 addition & 1 deletion qiskit/primitives/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _run(
self._qargs_list.append(qargs)
self._parameters.append(circuit.parameters)
job = PrimitiveJob(self._call, circuit_indices, parameter_values, **run_options)
job.submit()
job._submit()
return job

@staticmethod
Expand Down
10 changes: 10 additions & 0 deletions releasenotes/notes/update-primitive-job-f5c9b31f68c3ec3d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
upgrade:
- |
Added :class:`.BasePrimitiveJob` class as an abstract job class for Primitives
and made :class:`.PrimitiveJob` inherit :class:`.BasePrimitiveJob`
instead of :class:`.JobV1`.
deprecations:
- |
:meth:`.PrimitiveJob.submit` and :meth:`.PrimitiveJob.wait_for_final_state`
are deprecated and will be removed in the future release.
2 changes: 0 additions & 2 deletions test/python/primitives/test_backend_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from qiskit.circuit import QuantumCircuit
from qiskit.circuit.library import RealAmplitudes
from qiskit.primitives import BackendEstimator, EstimatorResult
from qiskit.providers import JobV1
from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2
from qiskit.providers.fake_provider.fake_backend_v2 import FakeBackendSimple
from qiskit.quantum_info import SparsePauliOp
Expand Down Expand Up @@ -91,7 +90,6 @@ def test_estimator_run(self, backend):
# Specify the circuit and observable by indices.
# calculate [ <psi1(theta1)|H1|psi1(theta1)> ]
job = estimator.run([psi1], [hamiltonian1], [theta1])
self.assertIsInstance(job, JobV1)
result = job.result()
self.assertIsInstance(result, EstimatorResult)
np.testing.assert_allclose(result.values, [1.5555572817900956], rtol=0.5, atol=0.2)
Expand Down
3 changes: 1 addition & 2 deletions test/python/primitives/test_backend_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from qiskit import QuantumCircuit
from qiskit.circuit.library import RealAmplitudes
from qiskit.primitives import BackendSampler, SamplerResult
from qiskit.providers import JobStatus, JobV1
from qiskit.providers import JobStatus
from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2
from qiskit.providers.basicaer import QasmSimulatorPy
from qiskit.test import QiskitTestCase
Expand Down Expand Up @@ -115,7 +115,6 @@ def test_sampler_run(self, backend):
bell = self._circuit[1]
sampler = BackendSampler(backend=backend)
job = sampler.run(circuits=[bell], shots=1000)
self.assertIsInstance(job, JobV1)
result = job.result()
self.assertIsInstance(result, SamplerResult)
self.assertEqual(result.quasi_dists[0].shots, 1000)
Expand Down
2 changes: 0 additions & 2 deletions test/python/primitives/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from qiskit.primitives import Estimator, EstimatorResult
from qiskit.primitives.base import validation
from qiskit.primitives.utils import _observable_key
from qiskit.providers import JobV1
from qiskit.quantum_info import Operator, Pauli, PauliList, SparsePauliOp
from qiskit.test import QiskitTestCase

Expand Down Expand Up @@ -68,7 +67,6 @@ def test_estimator_run(self):
# Specify the circuit and observable by indices.
# calculate [ <psi1(theta1)|H1|psi1(theta1)> ]
job = estimator.run([psi1], [hamiltonian1], [theta1])
self.assertIsInstance(job, JobV1)
result = job.result()
self.assertIsInstance(result, EstimatorResult)
np.testing.assert_allclose(result.values, [1.5555572817900956])
Expand Down
3 changes: 1 addition & 2 deletions test/python/primitives/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from qiskit.circuit import Parameter
from qiskit.circuit.library import RealAmplitudes, UnitaryGate
from qiskit.primitives import Sampler, SamplerResult
from qiskit.providers import JobStatus, JobV1
from qiskit.providers import JobStatus
from qiskit.test import QiskitTestCase


Expand Down Expand Up @@ -90,7 +90,6 @@ def test_sampler_run(self):
bell = self._circuit[1]
sampler = Sampler()
job = sampler.run(circuits=[bell])
self.assertIsInstance(job, JobV1)
result = job.result()
self.assertIsInstance(result, SamplerResult)
self._compare_probs(result.quasi_dists, self._target[1])
Expand Down