diff --git a/qiskit/primitives/backend_estimator.py b/qiskit/primitives/backend_estimator.py index eb97a3b95259..6aed3c6f0e32 100644 --- a/qiskit/primitives/backend_estimator.py +++ b/qiskit/primitives/backend_estimator.py @@ -291,7 +291,7 @@ def _run( job = PrimitiveJob( self._call, circuit_indices, observable_indices, parameter_values, **run_options ) - job.submit() + job._submit() return job @staticmethod diff --git a/qiskit/primitives/backend_sampler.py b/qiskit/primitives/backend_sampler.py index 140a3091f34a..1f588921cc79 100644 --- a/qiskit/primitives/backend_sampler.py +++ b/qiskit/primitives/backend_sampler.py @@ -205,5 +205,5 @@ def _run( self._circuits.append(circuit) self._parameters.append(circuit.parameters) job = PrimitiveJob(self._call, circuit_indices, parameter_values, **run_options) - job.submit() + job._submit() return job diff --git a/qiskit/primitives/base/base_primitive_job.py b/qiskit/primitives/base/base_primitive_job.py new file mode 100644 index 000000000000..b7d721c19031 --- /dev/null +++ b/qiskit/primitives/base/base_primitive_job.py @@ -0,0 +1,78 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Primitive job abstract base class +""" + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar, Union + +from ..containers import PrimitiveResult +from .base_result import BasePrimitiveResult + +ResultT = TypeVar("ResultT", bound=Union[BasePrimitiveResult, PrimitiveResult]) +StatusT = TypeVar("StatusT") + + +class BasePrimitiveJob(ABC, Generic[ResultT, StatusT]): + """Primitive job abstract base class.""" + + def __init__(self, job_id: str, **kwargs) -> None: + """Initializes the primitive job. + + Args: + job_id: A unique id in the context of the primitive used to run the job. + kwargs: Any key value metadata to associate with this job. + """ + self._job_id = job_id + self.metadata = kwargs + + def job_id(self) -> str: + """Return a unique id identifying the job.""" + return self._job_id + + @abstractmethod + def result(self) -> ResultT: + """Return the results of the job.""" + raise NotImplementedError("Subclass of BasePrimitiveJob must implement `result` method.") + + @abstractmethod + def status(self) -> StatusT: + """Return the status of the job.""" + raise NotImplementedError("Subclass of BasePrimitiveJob must implement `status` method.") + + @abstractmethod + def done(self) -> bool: + """Return whether the job has successfully run.""" + raise NotImplementedError("Subclass of BasePrimitiveJob must implement `done` method.") + + @abstractmethod + def running(self) -> bool: + """Return whether the job is actively running.""" + raise NotImplementedError("Subclass of BasePrimitiveJob must implement `running` method.") + + @abstractmethod + def cancelled(self) -> bool: + """Return whether the job has been cancelled.""" + raise NotImplementedError("Subclass of BasePrimitiveJob must implement `cancelled` method.") + + @abstractmethod + def in_final_state(self) -> bool: + """Return whether the job is in a final job state such as ``DONE`` or ``ERROR``.""" + raise NotImplementedError( + "Subclass of BasePrimitiveJob must implement `is_final_state` method." + ) + + @abstractmethod + def cancel(self): + """Attempt to cancel the job.""" + raise NotImplementedError("Subclass of BasePrimitiveJob must implement `cancel` method.") diff --git a/qiskit/primitives/estimator.py b/qiskit/primitives/estimator.py index 9edfe35d7eb2..f300d377874d 100644 --- a/qiskit/primitives/estimator.py +++ b/qiskit/primitives/estimator.py @@ -154,5 +154,5 @@ def _run( job = PrimitiveJob( self._call, circuit_indices, observable_indices, parameter_values, **run_options ) - job.submit() + job._submit() return job diff --git a/qiskit/primitives/primitive_job.py b/qiskit/primitives/primitive_job.py index 93924b2a4c23..0a3497b5d2bb 100644 --- a/qiskit/primitives/primitive_job.py +++ b/qiskit/primitives/primitive_job.py @@ -1,6 +1,6 @@ # This code is part of Qiskit. # -# (C) Copyright IBM 2022. +# (C) Copyright IBM 2022, 2024. # # This code is licensed under the Apache License, Version 2.0. You may # obtain a copy of this license in the LICENSE.txt file in the root directory @@ -13,53 +13,47 @@ Job implementation for the reference implementations of Primitives. """ +import time import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Generic, TypeVar, Union +from typing import Callable, Optional -from qiskit.providers import JobError, JobStatus, JobV1 +from qiskit.providers import JobError, JobStatus, JobTimeoutError +from qiskit.providers.jobstatus import JOB_FINAL_STATES +from qiskit.utils.deprecation import deprecate_func -from .base.base_result import BasePrimitiveResult -from .containers import PrimitiveResult +from .base.base_primitive_job import BasePrimitiveJob, ResultT -T = TypeVar("T", bound=Union[BasePrimitiveResult, PrimitiveResult]) - -class PrimitiveJob(JobV1, Generic[T]): +class PrimitiveJob(BasePrimitiveJob[ResultT, JobStatus]): """ - PrimitiveJob class for the reference implemetations of Primitives. + Primitive job class for the reference implementations of Primitives. """ def __init__(self, function, *args, **kwargs): """ Args: - function: a callable function to execute the job. + function: A callable function to execute the job. """ - job_id = str(uuid.uuid4()) - super().__init__(None, job_id) + super().__init__(str(uuid.uuid4())) self._future = None self._function = function self._args = args self._kwargs = kwargs - def submit(self): + def _submit(self): if self._future is not None: - raise JobError("Primitive job has already been submitted.") + raise JobError("Primitive job has been submitted already.") executor = ThreadPoolExecutor(max_workers=1) # pylint: disable=consider-using-with self._future = executor.submit(self._function, *self._args, **self._kwargs) executor.shutdown(wait=False) - def result(self) -> T: - """Return the results of the job.""" + def result(self) -> ResultT: self._check_submitted() return self._future.result() - def cancel(self): - self._check_submitted() - return self._future.cancel() - - def status(self): + def status(self) -> JobStatus: self._check_submitted() if self._future.running(): return JobStatus.RUNNING @@ -71,4 +65,69 @@ def status(self): def _check_submitted(self): if self._future is None: - raise JobError("Job not submitted yet!. You have to .submit() first!") + raise JobError("Primitive Job has not been submitted yet.") + + def cancel(self): + self._check_submitted() + return self._future.cancel() + + def done(self) -> bool: + return self.status() == JobStatus.DONE + + def running(self) -> bool: + return self.status() == JobStatus.RUNNING + + def cancelled(self) -> bool: + return self.status() == JobStatus.CANCELLED + + def in_final_state(self) -> bool: + return self.status() in JOB_FINAL_STATES + + @deprecate_func(since="0.46.0") + def submit(self): + """Submit a job. + + .. deprecated:: 0.46.0 + ``submit`` method is deprecated as of Qiskit 0.46 and will be removed + no earlier than 3 months after the release date. + + """ + self._submit() + + @deprecate_func(since="0.46.0") + def wait_for_final_state( + self, timeout: Optional[float] = None, wait: float = 5, callback: Optional[Callable] = None + ) -> None: + """Poll the job status until it progresses to a final state such as ``DONE`` or ``ERROR``. + + .. deprecated:: 0.46.0 + ``wait_for_final_state`` method is deprecated as of Qiskit 0.46 and will be removed + no earlier than 3 months after the release date. + + Args: + timeout: Seconds to wait for the job. If ``None``, wait indefinitely. + wait: Seconds between queries. + callback: Callback function invoked after each query. + The following positional arguments are provided to the callback function: + + * job_id: Job ID + * job_status: Status of the job from the last query + * job: This BaseJob instance + + Note: different subclass might provide different arguments to + the callback function. + + Raises: + JobTimeoutError: If the job does not reach a final state before the + specified timeout. + """ + start_time = time.time() + status = self.status() + while status not in JOB_FINAL_STATES: + elapsed_time = time.time() - start_time + if timeout is not None and elapsed_time >= timeout: + raise JobTimeoutError(f"Timeout while waiting for job {self.job_id()}.") + if callback: + callback(self.job_id(), status, self) + time.sleep(wait) + status = self.status() diff --git a/qiskit/primitives/sampler.py b/qiskit/primitives/sampler.py index 9ffe42165d96..23a901603bef 100644 --- a/qiskit/primitives/sampler.py +++ b/qiskit/primitives/sampler.py @@ -136,7 +136,7 @@ def _run( self._qargs_list.append(qargs) self._parameters.append(circuit.parameters) job = PrimitiveJob(self._call, circuit_indices, parameter_values, **run_options) - job.submit() + job._submit() return job @staticmethod diff --git a/releasenotes/notes/update-primitive-job-f5c9b31f68c3ec3d.yaml b/releasenotes/notes/update-primitive-job-f5c9b31f68c3ec3d.yaml new file mode 100644 index 000000000000..11d5ee88684a --- /dev/null +++ b/releasenotes/notes/update-primitive-job-f5c9b31f68c3ec3d.yaml @@ -0,0 +1,10 @@ +--- +upgrade: + - | + Added :class:`.BasePrimitiveJob` class as an abstract job class for Primitives + and made :class:`.PrimitiveJob` inherit :class:`.BasePrimitiveJob` + instead of :class:`.JobV1`. +deprecations: + - | + :meth:`.PrimitiveJob.submit` and :meth:`.PrimitiveJob.wait_for_final_state` + are deprecated and will be removed in the future release. diff --git a/test/python/primitives/test_backend_estimator.py b/test/python/primitives/test_backend_estimator.py index 0df15f590728..659765ec89b5 100644 --- a/test/python/primitives/test_backend_estimator.py +++ b/test/python/primitives/test_backend_estimator.py @@ -25,7 +25,6 @@ from qiskit.circuit import QuantumCircuit from qiskit.circuit.library import RealAmplitudes from qiskit.primitives import BackendEstimator, EstimatorResult -from qiskit.providers import JobV1 from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2 from qiskit.providers.fake_provider.fake_backend_v2 import FakeBackendSimple from qiskit.quantum_info import SparsePauliOp @@ -91,7 +90,6 @@ def test_estimator_run(self, backend): # Specify the circuit and observable by indices. # calculate [ ] job = estimator.run([psi1], [hamiltonian1], [theta1]) - self.assertIsInstance(job, JobV1) result = job.result() self.assertIsInstance(result, EstimatorResult) np.testing.assert_allclose(result.values, [1.5555572817900956], rtol=0.5, atol=0.2) diff --git a/test/python/primitives/test_backend_sampler.py b/test/python/primitives/test_backend_sampler.py index b4b8d79e32a3..156f0b48fefc 100644 --- a/test/python/primitives/test_backend_sampler.py +++ b/test/python/primitives/test_backend_sampler.py @@ -25,7 +25,7 @@ from qiskit import QuantumCircuit from qiskit.circuit.library import RealAmplitudes from qiskit.primitives import BackendSampler, SamplerResult -from qiskit.providers import JobStatus, JobV1 +from qiskit.providers import JobStatus from qiskit.providers.fake_provider import FakeNairobi, FakeNairobiV2 from qiskit.providers.basicaer import QasmSimulatorPy from qiskit.test import QiskitTestCase @@ -115,7 +115,6 @@ def test_sampler_run(self, backend): bell = self._circuit[1] sampler = BackendSampler(backend=backend) job = sampler.run(circuits=[bell], shots=1000) - self.assertIsInstance(job, JobV1) result = job.result() self.assertIsInstance(result, SamplerResult) self.assertEqual(result.quasi_dists[0].shots, 1000) diff --git a/test/python/primitives/test_estimator.py b/test/python/primitives/test_estimator.py index 7a190cbe0034..8606819e524b 100644 --- a/test/python/primitives/test_estimator.py +++ b/test/python/primitives/test_estimator.py @@ -23,7 +23,6 @@ from qiskit.primitives import Estimator, EstimatorResult from qiskit.primitives.base import validation from qiskit.primitives.utils import _observable_key -from qiskit.providers import JobV1 from qiskit.quantum_info import Operator, Pauli, PauliList, SparsePauliOp from qiskit.test import QiskitTestCase @@ -68,7 +67,6 @@ def test_estimator_run(self): # Specify the circuit and observable by indices. # calculate [ ] job = estimator.run([psi1], [hamiltonian1], [theta1]) - self.assertIsInstance(job, JobV1) result = job.result() self.assertIsInstance(result, EstimatorResult) np.testing.assert_allclose(result.values, [1.5555572817900956]) diff --git a/test/python/primitives/test_sampler.py b/test/python/primitives/test_sampler.py index fac9a250a9e9..67ea3124df30 100644 --- a/test/python/primitives/test_sampler.py +++ b/test/python/primitives/test_sampler.py @@ -20,7 +20,7 @@ from qiskit.circuit import Parameter from qiskit.circuit.library import RealAmplitudes, UnitaryGate from qiskit.primitives import Sampler, SamplerResult -from qiskit.providers import JobStatus, JobV1 +from qiskit.providers import JobStatus from qiskit.test import QiskitTestCase @@ -90,7 +90,6 @@ def test_sampler_run(self): bell = self._circuit[1] sampler = Sampler() job = sampler.run(circuits=[bell]) - self.assertIsInstance(job, JobV1) result = job.result() self.assertIsInstance(result, SamplerResult) self._compare_probs(result.quasi_dists, self._target[1])