Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay scikit-learn import until first use #1061

Merged
merged 7 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 17 additions & 25 deletions qiskit_experiments/data_processing/sklearn_discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@

"""Discriminators that wrap SKLearn."""

from typing import Any, List, Dict
from typing import Any, List, Dict, TYPE_CHECKING

from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.warnings import HAS_SKLEARN

try:
if TYPE_CHECKING:
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
QuadraticDiscriminantAnalysis,
)

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class SkLDA(BaseDiscriminator):
"""A wrapper for the SKlearn linear discriminant analysis."""
"""A wrapper for the scikit-learn linear discriminant analysis.

.. note::
This class requires that scikit-learn is installed.
"""

def __init__(self, lda: "LinearDiscriminantAnalysis"):
"""
Expand All @@ -40,11 +40,6 @@ def __init__(self, lda: "LinearDiscriminantAnalysis"):
Raises:
DataProcessorError: if SKlearn could not be imported.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

self._lda = lda
self.attributes = [
"coef_",
Expand Down Expand Up @@ -88,11 +83,10 @@ def config(self) -> Dict[str, Any]:
return {"params": self._lda.get_params(), "attributes": attr_conf}

@classmethod
@HAS_SKLEARN.require_in_call
def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
"""Deserialize from an object."""

if not HAS_SKLEARN:
raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

lda = LinearDiscriminantAnalysis()
lda.set_params(**config["params"])
Expand All @@ -105,7 +99,11 @@ def from_config(cls, config: Dict[str, Any]) -> "SkLDA":


class SkQDA(BaseDiscriminator):
"""A wrapper for the SKlearn quadratic discriminant analysis."""
"""A wrapper for the SKlearn quadratic discriminant analysis.

.. note::
This class requires that scikit-learn is installed.
"""

def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
"""
Expand All @@ -116,11 +114,6 @@ def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
Raises:
DataProcessorError: if SKlearn could not be imported.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

self._qda = qda
self.attributes = [
"coef_",
Expand Down Expand Up @@ -165,11 +158,10 @@ def config(self) -> Dict[str, Any]:
return {"params": self._qda.get_params(), "attributes": attr_conf}

@classmethod
@HAS_SKLEARN.require_in_call
def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
"""Deserialize from an object."""

if not HAS_SKLEARN:
raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

qda = QuadraticDiscriminantAnalysis()
qda.set_params(**config["params"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@

"""Multi state discrimination analysis."""

from typing import List, Tuple
from typing import List, Tuple, TYPE_CHECKING

import matplotlib
import numpy as np

from qiskit.providers.options import Options
from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, ExperimentData
from qiskit_experiments.data_processing import SkQDA
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle
from qiskit_experiments.warnings import HAS_SKLEARN

try:
if TYPE_CHECKING:
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class MultiStateDiscriminationAnalysis(BaseAnalysis):
r"""This class fits a multi-state discriminator to the data.
Expand All @@ -43,22 +39,13 @@ class MultiStateDiscriminationAnalysis(BaseAnalysis):

Here, :math:`d` is the number of levels that were discriminated while :math:`P(i|j)` is the
probability of measuring outcome :math:`i` given that state :math:`j` was prepared.
"""

def __init__(self):
"""Setup the analysis.

Raises:
DataProcessorError: if sklearn is not installed.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

super().__init__()
.. note::
This class requires that scikit-learn is installed.
"""

@classmethod
@HAS_SKLEARN.require_in_call
def _default_options(cls) -> Options:
"""Return default analysis options.

Expand All @@ -76,6 +63,8 @@ def _default_options(cls) -> Options:
)
options.plot = True
options.ax = None
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

options.discriminator = SkQDA(QuadraticDiscriminantAnalysis())
return options

Expand Down
14 changes: 14 additions & 0 deletions qiskit_experiments/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import warnings
from typing import Callable, Optional, Type, Dict

from qiskit.utils.lazy_tester import LazyImportTester


def deprecated_function(
last_version: Optional[str] = None,
Expand Down Expand Up @@ -240,3 +242,15 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


HAS_SKLEARN = LazyImportTester(
{
"sklearn.discriminant_analysis": (
"LinearDiscriminantAnalysis",
"QuadraticDiscriminantAnalysis",
)
},
name="scikit-learn",
install="pip install scikit-learn",
)
9 changes: 9 additions & 0 deletions releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
fixes:
- |
The importing of ``scikit-learn`` was moved from module-level imports
inside of ``try`` blocks to dynamic imports at first usage of the
``scikit-learn`` specific feature. This change should avoid errors in the
installation of ``scikit-learn`` from preventing a user using features of
``qiskit-experiments`` that do not require ``scikit-learn``. See `#1050
<https://github.com/Qiskit/qiskit-experiments/issues/1050>`_.
Loading