Skip to content

Commit

Permalink
Add who dataset, examples and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Feb 24, 2023
1 parent cdb6195 commit c967a65
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Lint with flake8
run: |
# ignore formatting, it will be checked by black
export FORMATTING_RULES="E101,E111,E114,E115,E116,E117,E12,E13,E2,E3,E401,E5,E70,W1,W2,W3,W5"
export FORMATTING_RULES="E101,E111,E114,E115,E116,E117,E12,E13,E2,E3,E401,E5,E70,W1,W2,W3,W5"
flake8 --ignore=$FORMATTING_RULES .
- name: Lint with black
run: |
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ sphinx_rtd_theme
tqdm
traitlets>=5.0
jinja2 < 3.1
pandas
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pandas
jupyter
jupyter_contrib_nbextensions
matplotlib
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ zip_safe = True
packages = find:
install_requires =
numpy
scikit-learn>=1.1.0
scikit-learn>=1.1.0
8 changes: 7 additions & 1 deletion skcosmo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
load_csd_1000r,
load_degenerate_CH4_manifold,
load_nice_dataset,
load_who_dataset,
)

__all__ = ["load_degenerate_CH4_manifold", "load_csd_1000r", "load_nice_dataset"]
__all__ = [
"load_degenerate_CH4_manifold",
"load_csd_1000r",
"load_nice_dataset",
"load_who_dataset",
]
25 changes: 24 additions & 1 deletion skcosmo/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
)

import numpy as np
from sklearn.utils import Bunch
from sklearn.utils import (
Bunch,
check_pandas_support,
)


def load_nice_dataset():
Expand Down Expand Up @@ -91,3 +94,23 @@ def load_csd_1000r(return_X_y=False):
return Bunch(data=data, DESCR=fdescr)
else:
return raw_data["X"], raw_data["Y"]


def load_who_dataset():
"""Load and returns WHO dataset.
Returns
-------
who_dataset : sklearn.utils.Bunch
Dictionary-like object, with the following attributes:
data : `pandas.core.frame.DataFrame` -- the WHO dataset
as a Pandas dataframe.
DESCR: `str` -- The full description of the dataset.
"""

module_path = dirname(__file__)
target_filename = join(module_path, "data", "who_dataset.csv")
pd = check_pandas_support("load_who_dataset")
raw_data = pd.read_csv(target_filename)
with open(join(module_path, "descr", "who_dataset.rst")) as rst_file:
fdescr = rst_file.read()
return Bunch(data=raw_data, DESCR=fdescr)
46 changes: 46 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import unittest
import numpy as np

from skcosmo.datasets import (
load_degenerate_CH4_manifold,
load_csd_1000r,
load_nice_dataset,
load_who_dataset,
)


Expand Down Expand Up @@ -58,5 +60,49 @@ def test_load_csd_1000r_access_descr(self):
self.csd.DESCR


class WHOTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.size = 24240
cls.shape = (2020, 12)
cls.value = 5.00977993011475
try:
import pandas as pd # noqa F401

cls.has_pandas = True
cls.who = load_who_dataset()
except ImportError:
cls.has_pandas = False

def test_load_dataset_without_pandas(self):
"""
Check if the correct exception occurs when pandas isn't present.
"""
if self.has_pandas is False:
with self.assertRaises(ImportError) as cm:
_ = load_who_dataset()
self.assertEqual(str(cm.exception), "load_who_dataset requires pandas.")

def test_dataset_size_and_shape(self):
"""
Check if the correct number of datapoints are present in the dataset.
Also check if the size of the dataset is correct.
"""
if self.has_pandas is True:
self.assertEqual(self.who["data"].size, self.size)
self.assertEqual(self.who["data"].shape, self.shape)

def test_datapoint_value(self):
"""
Check if the value of a datapoint at a certain location is correct.
"""
if self.has_pandas is True:
self.assertTrue(
np.allclose(
self.who["data"]["SE.XPD.TOTL.GD.ZS"][1924], self.value, rtol=1e-6
)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit c967a65

Please sign in to comment.