From 74b4a5079bf8664a524e14f75c3ce0549edb5171 Mon Sep 17 00:00:00 2001 From: Will Shanks Date: Tue, 28 Feb 2023 17:51:43 -0500 Subject: [PATCH 1/7] Delay scikit-learn import until first use --- .../data_processing/sklearn_discriminators.py | 347 +++++++++--------- .../multi_state_discrimination_analysis.py | 17 +- qiskit_experiments/warnings.py | 15 + test/data_processing/test_discriminator.py | 245 ++++++------- .../test_multi_state_discrimination.py | 24 ++ 5 files changed, 333 insertions(+), 315 deletions(-) diff --git a/qiskit_experiments/data_processing/sklearn_discriminators.py b/qiskit_experiments/data_processing/sklearn_discriminators.py index ad43e13c7c..e31c9700d0 100644 --- a/qiskit_experiments/data_processing/sklearn_discriminators.py +++ b/qiskit_experiments/data_processing/sklearn_discriminators.py @@ -1,181 +1,166 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2022. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. - -"""Discriminators that wrap SKLearn.""" - -from typing import Any, List, Dict - -from qiskit_experiments.data_processing.discriminator import BaseDiscriminator -from qiskit_experiments.data_processing.exceptions import DataProcessorError - -try: - from sklearn.discriminant_analysis import ( - LinearDiscriminantAnalysis, - QuadraticDiscriminantAnalysis, - ) - - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False - - -class SkLDA(BaseDiscriminator): - """A wrapper for the SKlearn linear discriminant analysis.""" - - def __init__(self, lda: "LinearDiscriminantAnalysis"): - """ - Args: - lda: The sklearn linear discriminant analysis. This may be a trained or an - untrained discriminator. - - Raises: - DataProcessorError: if SKlearn could not be imported. - """ - if not HAS_SKLEARN: - raise DataProcessorError( - f"SKlearn is needed to initialize an {self.__class__.__name__}." - ) - - self._lda = lda - self.attributes = [ - "coef_", - "intercept_", - "covariance_", - "explained_variance_ratio_", - "means_", - "priors_", - "scalings_", - "xbar_", - "classes_", - "n_features_in_", - "feature_names_in_", - ] - - @property - def discriminator(self) -> Any: - """Return then SKLearn object.""" - return self._lda - - def is_trained(self) -> bool: - """Return True if the discriminator has been trained on data.""" - return not getattr(self._lda, "classes_", None) is None - - def predict(self, data: List): - """Wrap the predict method of the LDA.""" - return self._lda.predict(data) - - def fit(self, data: List, labels: List): - """Fit the LDA. - - Args: - data: The independent data. - labels: The labels corresponding to data. - """ - self._lda.fit(data, labels) - - def config(self) -> Dict[str, Any]: - """Return the configuration of the LDA.""" - attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes} - return {"params": self._lda.get_params(), "attributes": attr_conf} - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SkLDA": - """Deserialize from an object.""" - - if not HAS_SKLEARN: - raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.") - - lda = LinearDiscriminantAnalysis() - lda.set_params(**config["params"]) - - for name, value in config["attributes"].items(): - if value is not None: - setattr(lda, name, value) - - return SkLDA(lda) - - -class SkQDA(BaseDiscriminator): - """A wrapper for the SKlearn quadratic discriminant analysis.""" - - def __init__(self, qda: "QuadraticDiscriminantAnalysis"): - """ - Args: - qda: The sklearn quadratic discriminant analysis. This may be a trained or an - untrained discriminator. - - Raises: - DataProcessorError: if SKlearn could not be imported. - """ - if not HAS_SKLEARN: - raise DataProcessorError( - f"SKlearn is needed to initialize an {self.__class__.__name__}." - ) - - self._qda = qda - self.attributes = [ - "coef_", - "intercept_", - "covariance_", - "explained_variance_ratio_", - "means_", - "priors_", - "scalings_", - "xbar_", - "classes_", - "n_features_in_", - "feature_names_in_", - "rotations_", - ] - - @property - def discriminator(self) -> Any: - """Return then SKLearn object.""" - return self._qda - - def is_trained(self) -> bool: - """Return True if the discriminator has been trained on data.""" - return not getattr(self._qda, "classes_", None) is None - - def predict(self, data: List): - """Wrap the predict method of the QDA.""" - return self._qda.predict(data) - - def fit(self, data: List, labels: List): - """Fit the QDA. - - Args: - data: The independent data. - labels: The labels corresponding to data. - """ - self._qda.fit(data, labels) - - def config(self) -> Dict[str, Any]: - """Return the configuration of the QDA.""" - attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes} - return {"params": self._qda.get_params(), "attributes": attr_conf} - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SkQDA": - """Deserialize from an object.""" - - if not HAS_SKLEARN: - raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.") - - qda = QuadraticDiscriminantAnalysis() - qda.set_params(**config["params"]) - - for name, value in config["attributes"].items(): - if value is not None: - setattr(qda, name, value) - - return SkQDA(qda) +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Discriminators that wrap SKLearn.""" + +from typing import Any, List, Dict, TYPE_CHECKING + +from qiskit_experiments.data_processing.discriminator import BaseDiscriminator +from qiskit_experiments.data_processing.exceptions import DataProcessorError +from qiskit_experiments.warnings import HAS_SKLEARN + +if TYPE_CHECKING: + from sklearn.discriminant_analysis import ( + LinearDiscriminantAnalysis, + QuadraticDiscriminantAnalysis, + ) + + +class SkLDA(BaseDiscriminator): + """A wrapper for the SKlearn linear discriminant analysis.""" + + def __init__(self, lda: "LinearDiscriminantAnalysis"): + """ + Args: + lda: The sklearn linear discriminant analysis. This may be a trained or an + untrained discriminator. + + Raises: + DataProcessorError: if SKlearn could not be imported. + """ + self._lda = lda + self.attributes = [ + "coef_", + "intercept_", + "covariance_", + "explained_variance_ratio_", + "means_", + "priors_", + "scalings_", + "xbar_", + "classes_", + "n_features_in_", + "feature_names_in_", + ] + + @property + def discriminator(self) -> Any: + """Return then SKLearn object.""" + return self._lda + + def is_trained(self) -> bool: + """Return True if the discriminator has been trained on data.""" + return not getattr(self._lda, "classes_", None) is None + + def predict(self, data: List): + """Wrap the predict method of the LDA.""" + return self._lda.predict(data) + + def fit(self, data: List, labels: List): + """Fit the LDA. + + Args: + data: The independent data. + labels: The labels corresponding to data. + """ + self._lda.fit(data, labels) + + def config(self) -> Dict[str, Any]: + """Return the configuration of the LDA.""" + attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes} + return {"params": self._lda.get_params(), "attributes": attr_conf} + + @classmethod + @HAS_SKLEARN.require_in_call + def from_config(cls, config: Dict[str, Any]) -> "SkLDA": + """Deserialize from an object.""" + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis + + lda = LinearDiscriminantAnalysis() + lda.set_params(**config["params"]) + + for name, value in config["attributes"].items(): + if value is not None: + setattr(lda, name, value) + + return SkLDA(lda) + + +class SkQDA(BaseDiscriminator): + """A wrapper for the SKlearn quadratic discriminant analysis.""" + + def __init__(self, qda: "QuadraticDiscriminantAnalysis"): + """ + Args: + qda: The sklearn quadratic discriminant analysis. This may be a trained or an + untrained discriminator. + + Raises: + DataProcessorError: if SKlearn could not be imported. + """ + self._qda = qda + self.attributes = [ + "coef_", + "intercept_", + "covariance_", + "explained_variance_ratio_", + "means_", + "priors_", + "scalings_", + "xbar_", + "classes_", + "n_features_in_", + "feature_names_in_", + "rotations_", + ] + + @property + def discriminator(self) -> Any: + """Return then SKLearn object.""" + return self._qda + + def is_trained(self) -> bool: + """Return True if the discriminator has been trained on data.""" + return not getattr(self._qda, "classes_", None) is None + + def predict(self, data: List): + """Wrap the predict method of the QDA.""" + return self._qda.predict(data) + + def fit(self, data: List, labels: List): + """Fit the QDA. + + Args: + data: The independent data. + labels: The labels corresponding to data. + """ + self._qda.fit(data, labels) + + def config(self) -> Dict[str, Any]: + """Return the configuration of the QDA.""" + attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes} + return {"params": self._qda.get_params(), "attributes": attr_conf} + + @classmethod + @HAS_SKLEARN.require_in_call + def from_config(cls, config: Dict[str, Any]) -> "SkQDA": + """Deserialize from an object.""" + from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis + + qda = QuadraticDiscriminantAnalysis() + qda.set_params(**config["params"]) + + for name, value in config["attributes"].items(): + if value is not None: + setattr(qda, name, value) + + return SkQDA(qda) diff --git a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py index be39380b22..46bbab82a8 100644 --- a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py @@ -12,7 +12,7 @@ """Multi state discrimination analysis.""" -from typing import List, Tuple +from typing import List, Tuple, TYPE_CHECKING import matplotlib import numpy as np @@ -22,14 +22,11 @@ from qiskit_experiments.data_processing import SkQDA from qiskit_experiments.data_processing.exceptions import DataProcessorError from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle +from qiskit_experiments.warnings import HAS_SKLEARN -try: +if TYPE_CHECKING: from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False - class MultiStateDiscriminationAnalysis(BaseAnalysis): r"""This class fits a multi-state discriminator to the data. @@ -51,14 +48,10 @@ def __init__(self): Raises: DataProcessorError: if sklearn is not installed. """ - if not HAS_SKLEARN: - raise DataProcessorError( - f"SKlearn is needed to initialize an {self.__class__.__name__}." - ) - super().__init__() @classmethod + @HAS_SKLEARN.require_in_call def _default_options(cls) -> Options: """Return default analysis options. @@ -76,6 +69,8 @@ def _default_options(cls) -> Options: ) options.plot = True options.ax = None + from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis + options.discriminator = SkQDA(QuadraticDiscriminantAnalysis()) return options diff --git a/qiskit_experiments/warnings.py b/qiskit_experiments/warnings.py index 4f9742e5b5..d2b65b1bc7 100644 --- a/qiskit_experiments/warnings.py +++ b/qiskit_experiments/warnings.py @@ -16,6 +16,9 @@ import warnings from typing import Callable, Optional, Type, Dict +from qiskit.exceptions import QiskitError +from qiskit.utils.lazy_tester import LazyImportTester + def deprecated_function( last_version: Optional[str] = None, @@ -240,3 +243,15 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +HAS_SKLEARN = LazyImportTester( + { + "sklearn.discriminant_analysis": ( + "LinearDiscriminantAnalysis", + "QuadraticDiscriminantAnalysis", + ) + }, + name="scikit-learn", + install="pip install scikit-learn", +) diff --git a/test/data_processing/test_discriminator.py b/test/data_processing/test_discriminator.py index dc0fea629a..2118aab5db 100644 --- a/test/data_processing/test_discriminator.py +++ b/test/data_processing/test_discriminator.py @@ -1,123 +1,122 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2022. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. - -"""Tests for the serializable discriminator objects.""" - -from test.base import QiskitExperimentsTestCase -from functools import wraps -from unittest import SkipTest -import numpy as np - -from qiskit_experiments.data_processing import SkLDA, SkQDA - -try: - from sklearn.discriminant_analysis import ( - LinearDiscriminantAnalysis, - QuadraticDiscriminantAnalysis, - ) - - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False - - -def requires_sklearn(func): - """Decorator to check for SKLearn.""" - - @wraps(func) - def wrapper(*args, **kwargs): - if not HAS_SKLEARN: - raise SkipTest("SKLearn is required for test.") - - func(*args, **kwargs) - - return wrapper - - -class TestDiscriminator(QiskitExperimentsTestCase): - """Tests for the discriminator.""" - - @requires_sklearn - def test_lda_serialization(self): - """Test the serialization of a lda.""" - - sk_lda = LinearDiscriminantAnalysis() - sk_lda.fit([[-1, 0], [1, 0], [-1.1, 0], [0.9, 0.1]], [0, 1, 0, 1]) - - self.assertTrue(sk_lda.predict([[1.1, 0]])[0], 1) - - lda = SkLDA(sk_lda) - - self.assertTrue(lda.is_trained()) - self.assertTrue(lda.predict([[1.1, 0]])[0], 1) - - def check_lda(lda1, lda2): - test_data = [[1.1, 0], [0.1, 0], [-2, 0]] - - lda1_y = lda1.predict(test_data) - lda2_y = lda2.predict(test_data) - - if len(lda1_y) != len(lda2_y): - return False - - for idx, y_val1 in enumerate(lda1_y): - if lda2_y[idx] != y_val1: - return False - - for attribute in lda1.attributes: - if not np.allclose( - getattr(lda1.discriminator, attribute, np.array([])), - getattr(lda2.discriminator, attribute, np.array([])), - ): - return False - - return True - - self.assertRoundTripSerializable(lda, check_lda) - - @requires_sklearn - def test_qda_serialization(self): - """Test the serialization of a qda.""" - - sk_qda = QuadraticDiscriminantAnalysis() - sk_qda.fit([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], [0, 0, 0, 1, 1, 1]) - - self.assertTrue(sk_qda.predict([[1.1, 3]])[0], 1) - - qda = SkQDA(sk_qda) - - self.assertTrue(qda.is_trained()) - self.assertTrue(qda.predict([[1.1, 3]])[0], 1) - - def check_qda(qda1, qda2): - test_data = [[1.1, 0], [0.1, 0], [-2, 0]] - - qda1_y = qda1.predict(test_data) - qda2_y = qda2.predict(test_data) - - if len(qda1_y) != len(qda2_y): - return False - - for idx, y_val1 in enumerate(qda1_y): - if qda2_y[idx] != y_val1: - return False - - for attribute in qda1.attributes: - if not np.allclose( - getattr(qda1.discriminator, attribute, np.array([])), - getattr(qda2.discriminator, attribute, np.array([])), - ): - return False - - return True - - self.assertRoundTripSerializable(qda, check_qda) +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Tests for the serializable discriminator objects.""" + +from test.base import QiskitExperimentsTestCase +from functools import wraps +from unittest import SkipTest +import numpy as np + +from qiskit.exceptions import MissingOptionalLibraryError + +from qiskit_experiments.data_processing import SkLDA, SkQDA +from qiskit_experiments.warnings import HAS_SKLEARN + + +def requires_sklearn(func): + """Decorator to check for SKLearn.""" + + @wraps(func) + def wrapper(*args, **kwargs): + try: + HAS_SKLEARN.require_now("SKLearn disciminator testing") + except MissingOptionalLibraryError: + raise SkipTest("SKLearn is required for test.") + + func(*args, **kwargs) + + return wrapper + + +class TestDiscriminator(QiskitExperimentsTestCase): + """Tests for the discriminator.""" + + @requires_sklearn + def test_lda_serialization(self): + """Test the serialization of a lda.""" + + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis + + sk_lda = LinearDiscriminantAnalysis() + sk_lda.fit([[-1, 0], [1, 0], [-1.1, 0], [0.9, 0.1]], [0, 1, 0, 1]) + + self.assertTrue(sk_lda.predict([[1.1, 0]])[0], 1) + + lda = SkLDA(sk_lda) + + self.assertTrue(lda.is_trained()) + self.assertTrue(lda.predict([[1.1, 0]])[0], 1) + + def check_lda(lda1, lda2): + test_data = [[1.1, 0], [0.1, 0], [-2, 0]] + + lda1_y = lda1.predict(test_data) + lda2_y = lda2.predict(test_data) + + if len(lda1_y) != len(lda2_y): + return False + + for idx, y_val1 in enumerate(lda1_y): + if lda2_y[idx] != y_val1: + return False + + for attribute in lda1.attributes: + if not np.allclose( + getattr(lda1.discriminator, attribute, np.array([])), + getattr(lda2.discriminator, attribute, np.array([])), + ): + return False + + return True + + self.assertRoundTripSerializable(lda, check_lda) + + @requires_sklearn + def test_qda_serialization(self): + """Test the serialization of a qda.""" + + from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis + + sk_qda = QuadraticDiscriminantAnalysis() + sk_qda.fit([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], [0, 0, 0, 1, 1, 1]) + + self.assertTrue(sk_qda.predict([[1.1, 3]])[0], 1) + + qda = SkQDA(sk_qda) + + self.assertTrue(qda.is_trained()) + self.assertTrue(qda.predict([[1.1, 3]])[0], 1) + + def check_qda(qda1, qda2): + test_data = [[1.1, 0], [0.1, 0], [-2, 0]] + + qda1_y = qda1.predict(test_data) + qda2_y = qda2.predict(test_data) + + if len(qda1_y) != len(qda2_y): + return False + + for idx, y_val1 in enumerate(qda1_y): + if qda2_y[idx] != y_val1: + return False + + for attribute in qda1.attributes: + if not np.allclose( + getattr(qda1.discriminator, attribute, np.array([])), + getattr(qda2.discriminator, attribute, np.array([])), + ): + return False + + return True + + self.assertRoundTripSerializable(qda, check_qda) diff --git a/test/library/characterization/test_multi_state_discrimination.py b/test/library/characterization/test_multi_state_discrimination.py index 4beb0d4d8d..1440da7b51 100644 --- a/test/library/characterization/test_multi_state_discrimination.py +++ b/test/library/characterization/test_multi_state_discrimination.py @@ -11,13 +11,35 @@ # that they have been altered from the originals. """Test the multi state discrimination experiments.""" +from functools import wraps from test.base import QiskitExperimentsTestCase +from unittest import SkipTest + from ddt import ddt, data from qiskit import pulse +from qiskit.exceptions import MissingOptionalLibraryError + from qiskit_experiments.library import MultiStateDiscrimination from qiskit_experiments.test.pulse_backend import SingleTransmonTestBackend +from qiskit_experiments.warnings import HAS_SKLEARN + + +def requires_sklearn(func): + """Decorator to check for SKLearn.""" + + @wraps(func) + def wrapper(*args, **kwargs): + try: + HAS_SKLEARN.require_now("SKLearn disciminator testing") + except MissingOptionalLibraryError: + raise SkipTest("SKLearn is required for test.") + + func(*args, **kwargs) + + return wrapper + @ddt class TestMultiStateDiscrimination(QiskitExperimentsTestCase): @@ -52,6 +74,7 @@ def setUp(self): self.schedules = {"x12": x12} @data(2, 3) + @requires_sklearn def test_circuit_generation(self, n_states): """Test the experiment circuit generation""" exp = MultiStateDiscrimination( @@ -63,6 +86,7 @@ def test_circuit_generation(self, n_states): self.assertEqual(exp.circuits()[-1].metadata["label"], n_states - 1) @data(2, 3) + @requires_sklearn def test_discrimination_analysis(self, n_states): """Test the discrimination analysis""" exp = MultiStateDiscrimination( From 575efd1edf0dffd1621f26206e1f58f43b2a40ff Mon Sep 17 00:00:00 2001 From: Will Shanks Date: Wed, 1 Mar 2023 14:04:38 -0500 Subject: [PATCH 2/7] Address pylint warnings --- .../data_processing/sklearn_discriminators.py | 1 - .../analysis/multi_state_discrimination_analysis.py | 9 --------- qiskit_experiments/warnings.py | 1 - test/data_processing/test_discriminator.py | 4 ++-- .../characterization/test_multi_state_discrimination.py | 4 ++-- 5 files changed, 4 insertions(+), 15 deletions(-) diff --git a/qiskit_experiments/data_processing/sklearn_discriminators.py b/qiskit_experiments/data_processing/sklearn_discriminators.py index e31c9700d0..ce97f1e4d9 100644 --- a/qiskit_experiments/data_processing/sklearn_discriminators.py +++ b/qiskit_experiments/data_processing/sklearn_discriminators.py @@ -15,7 +15,6 @@ from typing import Any, List, Dict, TYPE_CHECKING from qiskit_experiments.data_processing.discriminator import BaseDiscriminator -from qiskit_experiments.data_processing.exceptions import DataProcessorError from qiskit_experiments.warnings import HAS_SKLEARN if TYPE_CHECKING: diff --git a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py index 46bbab82a8..dcafd655c4 100644 --- a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py @@ -20,7 +20,6 @@ from qiskit.providers.options import Options from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, ExperimentData from qiskit_experiments.data_processing import SkQDA -from qiskit_experiments.data_processing.exceptions import DataProcessorError from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle from qiskit_experiments.warnings import HAS_SKLEARN @@ -42,14 +41,6 @@ class MultiStateDiscriminationAnalysis(BaseAnalysis): probability of measuring outcome :math:`i` given that state :math:`j` was prepared. """ - def __init__(self): - """Setup the analysis. - - Raises: - DataProcessorError: if sklearn is not installed. - """ - super().__init__() - @classmethod @HAS_SKLEARN.require_in_call def _default_options(cls) -> Options: diff --git a/qiskit_experiments/warnings.py b/qiskit_experiments/warnings.py index d2b65b1bc7..3758ac2764 100644 --- a/qiskit_experiments/warnings.py +++ b/qiskit_experiments/warnings.py @@ -16,7 +16,6 @@ import warnings from typing import Callable, Optional, Type, Dict -from qiskit.exceptions import QiskitError from qiskit.utils.lazy_tester import LazyImportTester diff --git a/test/data_processing/test_discriminator.py b/test/data_processing/test_discriminator.py index 2118aab5db..9d1a49fdc7 100644 --- a/test/data_processing/test_discriminator.py +++ b/test/data_processing/test_discriminator.py @@ -30,8 +30,8 @@ def requires_sklearn(func): def wrapper(*args, **kwargs): try: HAS_SKLEARN.require_now("SKLearn disciminator testing") - except MissingOptionalLibraryError: - raise SkipTest("SKLearn is required for test.") + except MissingOptionalLibraryError as exc: + raise SkipTest("SKLearn is required for test.") from exc func(*args, **kwargs) diff --git a/test/library/characterization/test_multi_state_discrimination.py b/test/library/characterization/test_multi_state_discrimination.py index 1440da7b51..a99b978dab 100644 --- a/test/library/characterization/test_multi_state_discrimination.py +++ b/test/library/characterization/test_multi_state_discrimination.py @@ -33,8 +33,8 @@ def requires_sklearn(func): def wrapper(*args, **kwargs): try: HAS_SKLEARN.require_now("SKLearn disciminator testing") - except MissingOptionalLibraryError: - raise SkipTest("SKLearn is required for test.") + except MissingOptionalLibraryError as exc: + raise SkipTest("SKLearn is required for test.") from exc func(*args, **kwargs) From 7d9d95789d7f9fb4c382dca29b1d4efdc07b7c28 Mon Sep 17 00:00:00 2001 From: Will Shanks Date: Thu, 2 Mar 2023 14:42:01 -0500 Subject: [PATCH 3/7] Add test of warning about optional scikit-learn dependency --- test/framework/test_warnings.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/framework/test_warnings.py b/test/framework/test_warnings.py index 45ab7c90de..0f45fb4959 100644 --- a/test/framework/test_warnings.py +++ b/test/framework/test_warnings.py @@ -13,6 +13,9 @@ # pylint: disable=unused-argument, unused-variable """Test warning helper.""" +import subprocess +import sys +import textwrap from test.base import QiskitExperimentsTestCase from qiskit_experiments.framework import BaseExperiment @@ -86,3 +89,35 @@ def __init__(self, physical_qubits): with self.assertWarns(DeprecationWarning): instance = OldExperiment(qubit=0) self.assertEqual(instance._physical_qubits, (0,)) + + def test_warn_sklearn(self): + """Test that a suggestion to import scikit-learn is given when appropriate""" + script = """ + import sys + sys.modules["sklearn"] = None + import qiskit_experiments + print("qiskit_experiments imported!") + from qiskit_experiments.data_processing.sklearn_discriminators import SkLDA + SkLDA.from_config({}) + """ + script = textwrap.dedent(script) + + proc = subprocess.run( + [sys.executable, "-c", script], check=False, text=True, capture_output=True + ) + + self.assertTrue( + proc.stdout.startswith("qiskit_experiments imported!"), + msg="Failed to import qiskit_experiments without sklearn", + ) + + self.assertNotEqual( + proc.returncode, + 0, + msg="scikit-learn usage did not error without scikit-learn available", + ) + self.assertTrue( + "qiskit.exceptions.MissingOptionalLibraryError" in proc.stderr + and "scikit-learn" in proc.stderr, + msg="scikit-learn import guard did not run on scikit-learn usage", + ) From 0a6b980452c65a04e297f860d049441808241329 Mon Sep 17 00:00:00 2001 From: Will Shanks Date: Thu, 2 Mar 2023 15:13:50 -0500 Subject: [PATCH 4/7] Add release note --- releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml diff --git a/releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml b/releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml new file mode 100644 index 0000000000..56029aa414 --- /dev/null +++ b/releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml @@ -0,0 +1,9 @@ +--- +fixes: + - | + The importing of ``scikit-learn`` was moved from module-level imports + inside of ``try`` blocks to dynamic imports at first usage of the + ``scikit-learn`` specific feature. This change should avoid errors in the + installation of ``scikit-learn`` from preventing a user using features of + ``qiskit-experiments`` that do not require ``scikit-learn``. See `#1050 + `_. From c826d0fb2afae7daa1cd5b14ec649c4d28ac0954 Mon Sep 17 00:00:00 2001 From: Will Shanks Date: Fri, 3 Mar 2023 12:43:37 -0500 Subject: [PATCH 5/7] Apply suggestions from code review Co-authored-by: Helena Zhang --- test/data_processing/test_discriminator.py | 2 +- .../library/characterization/test_multi_state_discrimination.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/data_processing/test_discriminator.py b/test/data_processing/test_discriminator.py index 9d1a49fdc7..7ce095ff7c 100644 --- a/test/data_processing/test_discriminator.py +++ b/test/data_processing/test_discriminator.py @@ -29,7 +29,7 @@ def requires_sklearn(func): @wraps(func) def wrapper(*args, **kwargs): try: - HAS_SKLEARN.require_now("SKLearn disciminator testing") + HAS_SKLEARN.require_now("SKLearn discriminator testing") except MissingOptionalLibraryError as exc: raise SkipTest("SKLearn is required for test.") from exc diff --git a/test/library/characterization/test_multi_state_discrimination.py b/test/library/characterization/test_multi_state_discrimination.py index a99b978dab..508299cd75 100644 --- a/test/library/characterization/test_multi_state_discrimination.py +++ b/test/library/characterization/test_multi_state_discrimination.py @@ -32,7 +32,7 @@ def requires_sklearn(func): @wraps(func) def wrapper(*args, **kwargs): try: - HAS_SKLEARN.require_now("SKLearn disciminator testing") + HAS_SKLEARN.require_now("SKLearn discriminator testing") except MissingOptionalLibraryError as exc: raise SkipTest("SKLearn is required for test.") from exc From fc3f5f97a1a9f59da880633104c38a945f4246ff Mon Sep 17 00:00:00 2001 From: Will Shanks Date: Fri, 3 Mar 2023 14:12:57 -0500 Subject: [PATCH 6/7] Make sklearn import test sensitive to any error in sklearn import --- test/framework/test_warnings.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/framework/test_warnings.py b/test/framework/test_warnings.py index 0f45fb4959..365176ffae 100644 --- a/test/framework/test_warnings.py +++ b/test/framework/test_warnings.py @@ -93,10 +93,20 @@ def __init__(self, physical_qubits): def test_warn_sklearn(self): """Test that a suggestion to import scikit-learn is given when appropriate""" script = """ - import sys - sys.modules["sklearn"] = None + import builtins + disallowed_imports = {"sklearn"} + old_import = builtins.__import__ + def guarded_import(name, *args, **kwargs): + if name in disallowed_imports: + raise import_error(f"Import of {name} not allowed!") + return old_import(name, *args, **kwargs) + builtins.__import__ = guarded_import + # Raise Exception on imports so that ImportError can't be caught + import_error = Exception import qiskit_experiments print("qiskit_experiments imported!") + # Raise ImportError so the guard can catch it + import_error = ImportError from qiskit_experiments.data_processing.sklearn_discriminators import SkLDA SkLDA.from_config({}) """ From 249c4274851c2fb7811bd5b62fc12b0f07e126ea Mon Sep 17 00:00:00 2001 From: Will Shanks Date: Fri, 3 Mar 2023 15:08:17 -0500 Subject: [PATCH 7/7] Document classes that require scikit-learn --- .../data_processing/sklearn_discriminators.py | 338 +++++++++--------- .../multi_state_discrimination_analysis.py | 3 + 2 files changed, 176 insertions(+), 165 deletions(-) diff --git a/qiskit_experiments/data_processing/sklearn_discriminators.py b/qiskit_experiments/data_processing/sklearn_discriminators.py index ce97f1e4d9..49d3072004 100644 --- a/qiskit_experiments/data_processing/sklearn_discriminators.py +++ b/qiskit_experiments/data_processing/sklearn_discriminators.py @@ -1,165 +1,173 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2022. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. - -"""Discriminators that wrap SKLearn.""" - -from typing import Any, List, Dict, TYPE_CHECKING - -from qiskit_experiments.data_processing.discriminator import BaseDiscriminator -from qiskit_experiments.warnings import HAS_SKLEARN - -if TYPE_CHECKING: - from sklearn.discriminant_analysis import ( - LinearDiscriminantAnalysis, - QuadraticDiscriminantAnalysis, - ) - - -class SkLDA(BaseDiscriminator): - """A wrapper for the SKlearn linear discriminant analysis.""" - - def __init__(self, lda: "LinearDiscriminantAnalysis"): - """ - Args: - lda: The sklearn linear discriminant analysis. This may be a trained or an - untrained discriminator. - - Raises: - DataProcessorError: if SKlearn could not be imported. - """ - self._lda = lda - self.attributes = [ - "coef_", - "intercept_", - "covariance_", - "explained_variance_ratio_", - "means_", - "priors_", - "scalings_", - "xbar_", - "classes_", - "n_features_in_", - "feature_names_in_", - ] - - @property - def discriminator(self) -> Any: - """Return then SKLearn object.""" - return self._lda - - def is_trained(self) -> bool: - """Return True if the discriminator has been trained on data.""" - return not getattr(self._lda, "classes_", None) is None - - def predict(self, data: List): - """Wrap the predict method of the LDA.""" - return self._lda.predict(data) - - def fit(self, data: List, labels: List): - """Fit the LDA. - - Args: - data: The independent data. - labels: The labels corresponding to data. - """ - self._lda.fit(data, labels) - - def config(self) -> Dict[str, Any]: - """Return the configuration of the LDA.""" - attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes} - return {"params": self._lda.get_params(), "attributes": attr_conf} - - @classmethod - @HAS_SKLEARN.require_in_call - def from_config(cls, config: Dict[str, Any]) -> "SkLDA": - """Deserialize from an object.""" - from sklearn.discriminant_analysis import LinearDiscriminantAnalysis - - lda = LinearDiscriminantAnalysis() - lda.set_params(**config["params"]) - - for name, value in config["attributes"].items(): - if value is not None: - setattr(lda, name, value) - - return SkLDA(lda) - - -class SkQDA(BaseDiscriminator): - """A wrapper for the SKlearn quadratic discriminant analysis.""" - - def __init__(self, qda: "QuadraticDiscriminantAnalysis"): - """ - Args: - qda: The sklearn quadratic discriminant analysis. This may be a trained or an - untrained discriminator. - - Raises: - DataProcessorError: if SKlearn could not be imported. - """ - self._qda = qda - self.attributes = [ - "coef_", - "intercept_", - "covariance_", - "explained_variance_ratio_", - "means_", - "priors_", - "scalings_", - "xbar_", - "classes_", - "n_features_in_", - "feature_names_in_", - "rotations_", - ] - - @property - def discriminator(self) -> Any: - """Return then SKLearn object.""" - return self._qda - - def is_trained(self) -> bool: - """Return True if the discriminator has been trained on data.""" - return not getattr(self._qda, "classes_", None) is None - - def predict(self, data: List): - """Wrap the predict method of the QDA.""" - return self._qda.predict(data) - - def fit(self, data: List, labels: List): - """Fit the QDA. - - Args: - data: The independent data. - labels: The labels corresponding to data. - """ - self._qda.fit(data, labels) - - def config(self) -> Dict[str, Any]: - """Return the configuration of the QDA.""" - attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes} - return {"params": self._qda.get_params(), "attributes": attr_conf} - - @classmethod - @HAS_SKLEARN.require_in_call - def from_config(cls, config: Dict[str, Any]) -> "SkQDA": - """Deserialize from an object.""" - from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis - - qda = QuadraticDiscriminantAnalysis() - qda.set_params(**config["params"]) - - for name, value in config["attributes"].items(): - if value is not None: - setattr(qda, name, value) - - return SkQDA(qda) +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Discriminators that wrap SKLearn.""" + +from typing import Any, List, Dict, TYPE_CHECKING + +from qiskit_experiments.data_processing.discriminator import BaseDiscriminator +from qiskit_experiments.warnings import HAS_SKLEARN + +if TYPE_CHECKING: + from sklearn.discriminant_analysis import ( + LinearDiscriminantAnalysis, + QuadraticDiscriminantAnalysis, + ) + + +class SkLDA(BaseDiscriminator): + """A wrapper for the scikit-learn linear discriminant analysis. + + .. note:: + This class requires that scikit-learn is installed. + """ + + def __init__(self, lda: "LinearDiscriminantAnalysis"): + """ + Args: + lda: The sklearn linear discriminant analysis. This may be a trained or an + untrained discriminator. + + Raises: + DataProcessorError: if SKlearn could not be imported. + """ + self._lda = lda + self.attributes = [ + "coef_", + "intercept_", + "covariance_", + "explained_variance_ratio_", + "means_", + "priors_", + "scalings_", + "xbar_", + "classes_", + "n_features_in_", + "feature_names_in_", + ] + + @property + def discriminator(self) -> Any: + """Return then SKLearn object.""" + return self._lda + + def is_trained(self) -> bool: + """Return True if the discriminator has been trained on data.""" + return not getattr(self._lda, "classes_", None) is None + + def predict(self, data: List): + """Wrap the predict method of the LDA.""" + return self._lda.predict(data) + + def fit(self, data: List, labels: List): + """Fit the LDA. + + Args: + data: The independent data. + labels: The labels corresponding to data. + """ + self._lda.fit(data, labels) + + def config(self) -> Dict[str, Any]: + """Return the configuration of the LDA.""" + attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes} + return {"params": self._lda.get_params(), "attributes": attr_conf} + + @classmethod + @HAS_SKLEARN.require_in_call + def from_config(cls, config: Dict[str, Any]) -> "SkLDA": + """Deserialize from an object.""" + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis + + lda = LinearDiscriminantAnalysis() + lda.set_params(**config["params"]) + + for name, value in config["attributes"].items(): + if value is not None: + setattr(lda, name, value) + + return SkLDA(lda) + + +class SkQDA(BaseDiscriminator): + """A wrapper for the SKlearn quadratic discriminant analysis. + + .. note:: + This class requires that scikit-learn is installed. + """ + + def __init__(self, qda: "QuadraticDiscriminantAnalysis"): + """ + Args: + qda: The sklearn quadratic discriminant analysis. This may be a trained or an + untrained discriminator. + + Raises: + DataProcessorError: if SKlearn could not be imported. + """ + self._qda = qda + self.attributes = [ + "coef_", + "intercept_", + "covariance_", + "explained_variance_ratio_", + "means_", + "priors_", + "scalings_", + "xbar_", + "classes_", + "n_features_in_", + "feature_names_in_", + "rotations_", + ] + + @property + def discriminator(self) -> Any: + """Return then SKLearn object.""" + return self._qda + + def is_trained(self) -> bool: + """Return True if the discriminator has been trained on data.""" + return not getattr(self._qda, "classes_", None) is None + + def predict(self, data: List): + """Wrap the predict method of the QDA.""" + return self._qda.predict(data) + + def fit(self, data: List, labels: List): + """Fit the QDA. + + Args: + data: The independent data. + labels: The labels corresponding to data. + """ + self._qda.fit(data, labels) + + def config(self) -> Dict[str, Any]: + """Return the configuration of the QDA.""" + attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes} + return {"params": self._qda.get_params(), "attributes": attr_conf} + + @classmethod + @HAS_SKLEARN.require_in_call + def from_config(cls, config: Dict[str, Any]) -> "SkQDA": + """Deserialize from an object.""" + from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis + + qda = QuadraticDiscriminantAnalysis() + qda.set_params(**config["params"]) + + for name, value in config["attributes"].items(): + if value is not None: + setattr(qda, name, value) + + return SkQDA(qda) diff --git a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py index dcafd655c4..94b9bd990a 100644 --- a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py @@ -39,6 +39,9 @@ class MultiStateDiscriminationAnalysis(BaseAnalysis): Here, :math:`d` is the number of levels that were discriminated while :math:`P(i|j)` is the probability of measuring outcome :math:`i` given that state :math:`j` was prepared. + + .. note:: + This class requires that scikit-learn is installed. """ @classmethod