diff --git a/nfl_data_py/__init__.py b/nfl_data_py/__init__.py
index ab6c076..c6a2ecc 100644
--- a/nfl_data_py/__init__.py
+++ b/nfl_data_py/__init__.py
@@ -1,15 +1,15 @@
 name = 'nfl_data_py'
 
-import datetime
 import os
 import logging
-from concurrent.futures import ThreadPoolExecutor, as_completed
+import datetime
 from warnings import warn
+from typing import Iterable
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
-import appdirs
 import numpy
 import pandas
-from typing import Iterable
+import appdirs
 
 # module level doc string
 __doc__ = """
@@ -735,52 +735,51 @@ def import_ids(columns=None, ids=None):
     """Import mapping table of ids for most major data providers
 
     Args:
-        columns (List[str]): list of columns to return
-        ids (List[str]): list of specific ids to return
+        columns (Iterable[str]): extra columns to return; when empty or omitted,
+            every non-id column is returned
+        ids (Iterable[str]): id providers to return, named without the '_id'
+            suffix (e.g. 'espn'); when empty or omitted, every id column is returned
 
     Returns:
         DataFrame
+
+    Raises:
+        ValueError: if columns or ids is a string or is not iterable, or if
+            ids names a provider with no matching '<name>_id' column
     """
-
-    # create list of id options
-    avail_ids = ['mfl_id', 'sportradar_id', 'fantasypros_id', 'gsis_id', 'pff_id',
-                 'sleeper_id', 'nfl_id', 'espn_id', 'yahoo_id', 'fleaflicker_id',
-                 'cbs_id', 'rotowire_id', 'rotoworld_id', 'ktc_id', 'pfr_id',
-                 'cfbref_id', 'stats_id', 'stats_global_id', 'fantasy_data_id']
-    avail_sites = [x[:-3] for x in avail_ids]
-
-    # check variable types
-    if columns is None:
-        columns = []
-
-    if ids is None:
-        ids = []
 
-    if not isinstance(columns, list):
-        raise ValueError('columns variable must be list.')
-
-    if not isinstance(ids, list):
-        raise ValueError('ids variable must be list.')
-
-    # confirm id is in table
-    if False in [x in avail_sites for x in ids]:
-        raise ValueError('ids variable can only contain ' + ', '.join(avail_sites))
+    # str/bytes are Iterable too, but unpacking one would yield single
+    # characters rather than names, so they are rejected explicitly
+    if columns is None:
+        columns = []
+    elif isinstance(columns, (str, bytes)) or not isinstance(columns, Iterable):
+        raise ValueError('columns argument must be an iterable of strings.')
+
+    if ids is None:
+        ids = []
+    elif isinstance(ids, (str, bytes)) or not isinstance(ids, Iterable):
+        raise ValueError('ids argument must be an iterable of strings.')
+
+    # materialize once so one-shot iterables (e.g. generators) are supported
+    columns = list(columns)
+    ret_ids = [x + '_id' for x in ids]
 
-    # import data
-    df = pandas.read_csv(r'https://raw.githubusercontent.com/dynastyprocess/data/master/files/db_playerids.csv')
+    df = pandas.read_csv("https://raw.githubusercontent.com/dynastyprocess/data/master/files/db_playerids.csv")
 
-    rem_cols = [x for x in df.columns if x not in avail_ids]
-    tgt_ids = [x + '_id' for x in ids]
-
-    # filter df to just specified columns
-    if len(columns) > 0 and len(ids) > 0:
-        df = df[set(tgt_ids + columns)]
-    elif len(columns) > 0 and len(ids) == 0:
-        df = df[set(avail_ids + columns)]
-    elif len(columns) == 0 and len(ids) > 0:
-        df = df[set(tgt_ids + rem_cols)]
+    id_cols = [c for c in df.columns if c.endswith('_id')]
+    non_id_cols = [c for c in df.columns if not c.endswith('_id')]
 
-    return df
+    # an unknown provider would otherwise surface as an opaque pandas KeyError
+    if any(c not in id_cols for c in ret_ids):
+        raise ValueError('ids argument can only contain '
+                         + ', '.join(c[:-3] for c in id_cols))
+
+    # filter df to just specified ids + columns; dict.fromkeys de-duplicates
+    # while keeping a stable column order, which a set would not
+    ret_columns = list(dict.fromkeys([*(ret_ids or id_cols),
+                                      *(columns or non_id_cols)]))
+
+    return df[ret_columns]
 
 
 def import_contracts():
diff --git a/nfl_data_py/tests/nfl_test.py b/nfl_data_py/tests/nfl_test.py
index 763a886..980ad3f 100644
--- a/nfl_data_py/tests/nfl_test.py
+++ b/nfl_data_py/tests/nfl_test.py
@@ -167,6 +167,37 @@ def test_is_df_with_data(self):
         s = nfl.import_ids()
         self.assertEqual(True, isinstance(s, pd.DataFrame))
         self.assertTrue(len(s) > 0)
+
+    def test_import_using_ids(self):
+        ids = ["espn", "yahoo", "gsis"]
+        s = nfl.import_ids(ids=ids)
+        self.assertTrue(all(f"{site}_id" in s.columns for site in ids))
+
+    def test_import_using_columns(self):
+        ret_columns = ["name", "birthdate", "college"]
+        not_ret_columns = ["draft_year", "db_season", "team"]
+        s = nfl.import_ids(columns=ret_columns)
+        self.assertTrue(all(column in s.columns for column in ret_columns))
+        self.assertTrue(all(column not in s.columns for column in not_ret_columns))
+
+    def test_import_using_ids_and_columns(self):
+        ret_ids = ["espn", "yahoo", "gsis"]
+        ret_columns = ["name", "birthdate", "college"]
+        # provider names carry no "_id" suffix here: the assertions append it
+        not_ret_ids = ["cfbref", "pff", "pfr"]
+        not_ret_columns = ["draft_year", "db_season", "team"]
+        s = nfl.import_ids(columns=ret_columns, ids=ret_ids)
+        self.assertTrue(all(column in s.columns for column in ret_columns))
+        self.assertTrue(all(column not in s.columns for column in not_ret_columns))
+        self.assertTrue(all(f"{site}_id" in s.columns for site in ret_ids))
+        self.assertTrue(all(f"{site}_id" not in s.columns for site in not_ret_ids))
+
+    def test_import_rejects_string_arguments(self):
+        # validated before any download, so this needs no network access
+        with self.assertRaises(ValueError):
+            nfl.import_ids(columns="name")
+        with self.assertRaises(ValueError):
+            nfl.import_ids(ids="espn")
 
 class test_ngs(TestCase):
     def test_is_df_with_data(self):