diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0c725625ba..8940534f88 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,7 +26,7 @@ jobs: - name: Lint with flake8 run: | # ignore formatting, it will be checked by black - export FORMATTING_RULES="E101,E111,E114,E115,E116,E117,E12,E13,E2,E3,E401,E5,E70,W1,W2,W3,W5" + export FORMATTING_RULES="E101,E111,E114,E115,E116,E117,E12,E13,E2,E3,E401,E5,E70,W1,W2,W3,W5" flake8 --ignore=$FORMATTING_RULES . - name: Lint with black run: | diff --git a/docs/requirements.txt b/docs/requirements.txt index 24df6b74be..ff1586e6a6 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,3 +9,4 @@ sphinx_rtd_theme tqdm traitlets>=5.0 jinja2 < 3.1 +pandas diff --git a/examples/requirements.txt b/examples/requirements.txt index 7cbfb2eff2..bd20b71a95 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,3 +1,4 @@ +pandas jupyter jupyter_contrib_nbextensions matplotlib diff --git a/setup.cfg b/setup.cfg index 1645bf2365..366e279b6d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,4 +23,4 @@ zip_safe = True packages = find: install_requires = numpy - scikit-learn>=1.1.0 + scikit-learn>=1.1.0 diff --git a/skcosmo/datasets/__init__.py b/skcosmo/datasets/__init__.py index add589b581..e916109b22 100644 --- a/skcosmo/datasets/__init__.py +++ b/skcosmo/datasets/__init__.py @@ -2,6 +2,12 @@ load_csd_1000r, load_degenerate_CH4_manifold, load_nice_dataset, + load_who_dataset, ) -__all__ = ["load_degenerate_CH4_manifold", "load_csd_1000r", "load_nice_dataset"] +__all__ = [ + "load_degenerate_CH4_manifold", + "load_csd_1000r", + "load_nice_dataset", + "load_who_dataset", +] diff --git a/skcosmo/datasets/_base.py b/skcosmo/datasets/_base.py index ed0ba0286e..694145969d 100644 --- a/skcosmo/datasets/_base.py +++ b/skcosmo/datasets/_base.py @@ -4,7 +4,10 @@ ) import numpy as np -from sklearn.utils import Bunch +from sklearn.utils import ( + Bunch, + check_pandas_support, +) def load_nice_dataset(): @@ -91,3 +94,23 @@ def load_csd_1000r(return_X_y=False): return Bunch(data=data, DESCR=fdescr) else: return raw_data["X"], raw_data["Y"] + + +def load_who_dataset(): + """Load and returns WHO dataset. + Returns + ------- + who_dataset : sklearn.utils.Bunch + Dictionary-like object, with the following attributes: + data : `pandas.core.frame.DataFrame` -- the WHO dataset + as a Pandas dataframe. + DESCR: `str` -- The full description of the dataset. + """ + + module_path = dirname(__file__) + target_filename = join(module_path, "data", "who_dataset.csv") + pd = check_pandas_support("load_who_dataset") + raw_data = pd.read_csv(target_filename) + with open(join(module_path, "descr", "who_dataset.rst")) as rst_file: + fdescr = rst_file.read() + return Bunch(data=raw_data, DESCR=fdescr) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 24615443b3..74ae121da8 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,9 +1,11 @@ import unittest +import numpy as np from skcosmo.datasets import ( load_degenerate_CH4_manifold, load_csd_1000r, load_nice_dataset, + load_who_dataset, ) @@ -58,5 +60,49 @@ def test_load_csd_1000r_access_descr(self): self.csd.DESCR +class WHOTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.size = 24240 + cls.shape = (2020, 12) + cls.value = 5.00977993011475 + try: + import pandas as pd # noqa F401 + + cls.has_pandas = True + cls.who = load_who_dataset() + except ImportError: + cls.has_pandas = False + + def test_load_dataset_without_pandas(self): + """ + Check if the correct exception occurs when pandas isn't present. + """ + if self.has_pandas is False: + with self.assertRaises(ImportError) as cm: + _ = load_who_dataset() + self.assertEqual(str(cm.exception), "load_who_dataset requires pandas.") + + def test_dataset_size_and_shape(self): + """ + Check if the correct number of datapoints are present in the dataset. + Also check if the size of the dataset is correct. + """ + if self.has_pandas is True: + self.assertEqual(self.who["data"].size, self.size) + self.assertEqual(self.who["data"].shape, self.shape) + + def test_datapoint_value(self): + """ + Check if the value of a datapoint at a certain location is correct. + """ + if self.has_pandas is True: + self.assertTrue( + np.allclose( + self.who["data"]["SE.XPD.TOTL.GD.ZS"][1924], self.value, rtol=1e-6 + ) + ) + + if __name__ == "__main__": unittest.main()