From 8dffbf17864aa6e03c6c0155cae67772b2476f82 Mon Sep 17 00:00:00 2001 From: Jiaxuan Date: Mon, 10 Jan 2022 17:42:01 -0800 Subject: [PATCH] Generic `kumo_loader` function that can points to snowflake/s3/local dataset (#156) * Add a generic `kumo_loader` function that can points to snowflake/s3/local dataset * Clean up code, switch test env from snowflake to s3 to save costs from CI * lint * change test data location to local --- benchmark/train/configs/financial.yaml | 7 +- .../train/configs/financial_regression.yaml | 1 - .../train/configs/imdb_classification.yaml | 1 - kumo/config/config.py | 34 +++--- kumo/custom/loader/financial.py | 49 --------- kumo/train/loader.py | 104 +++++++++++++++++- test/train/configs/financial.yaml | 21 ++-- test/train/test_trainer.py | 11 -- 8 files changed, 135 insertions(+), 93 deletions(-) delete mode 100644 kumo/custom/loader/financial.py diff --git a/benchmark/train/configs/financial.yaml b/benchmark/train/configs/financial.yaml index 2641f339..cc71fe7e 100644 --- a/benchmark/train/configs/financial.yaml +++ b/benchmark/train/configs/financial.yaml @@ -6,9 +6,10 @@ snowflake: warehouse: WH_XS database: kumo dataset: - format: snowflake - location: local - name: Financial + format: csv + data_dir: 's3://kumo-datasets' + metadata_dir: 'test/csv_data' + name: FINANCIAL target_table: LOAN target_column: STATUS task: node diff --git a/benchmark/train/configs/financial_regression.yaml b/benchmark/train/configs/financial_regression.yaml index d7b8903a..c9368379 100644 --- a/benchmark/train/configs/financial_regression.yaml +++ b/benchmark/train/configs/financial_regression.yaml @@ -7,7 +7,6 @@ snowflake: database: kumo dataset: format: snowflake - location: local name: Financial target_table: LOAN target_column: AMOUNT diff --git a/benchmark/train/configs/imdb_classification.yaml b/benchmark/train/configs/imdb_classification.yaml index 9d0e390d..64986b8e 100644 --- a/benchmark/train/configs/imdb_classification.yaml +++ b/benchmark/train/configs/imdb_classification.yaml @@ -7,7 +7,6 @@ snowflake: database: kumo dataset: format: snowflake - location: local name: IMDB target_table: U2BASE target_column: RATING diff --git a/kumo/config/config.py b/kumo/config/config.py index 990daa63..4722d122 100644 --- a/kumo/config/config.py +++ b/kumo/config/config.py @@ -20,31 +20,37 @@ def set_cfg(cfg): :return: configuration use by the experiment. ''' - # Set defaults + # Change defaults from PyG GraphGym version cfg.model.type = 'heterognn' cfg.gnn.head = 'node' cfg.dataset.name = 'Financial' cfg.gnn.dim_emb = 16 + + # Overwrite GraphGym scheduler + # (might improve default optimizer after more training experiences + # on databases) + cfg.optim.scheduler = 'none' + # ----------------------------------------------------------------------- # - # Snowflake options + # New options in Kumo # ----------------------------------------------------------------------- # - cfg.snowflake = CN() + cfg.snowflake = CN() # Account name cfg.snowflake.account = 'xva19026' - # User name cfg.snowflake.user = '' - # Password cfg.snowflake.password = '' - # Warehouse name cfg.snowflake.warehouse = 'WH_XS' - # Database name cfg.snowflake.database = 'kumo' + # directory for dataset + cfg.dataset.data_dir = '' + # directory for dataset metadata + cfg.dataset.metadata_dir = 'test/csv_data' # Default to random split cfg.dataset.split_type = 'random' # If split_type == 'column', split by the values of the column: @@ -52,6 +58,11 @@ def set_cfg(cfg): # the highest values of this column will be in the test split. # Restriction: the split column has to be in the prediction target table. cfg.dataset.split_column = None + # TODO: Duplicate `label_table` and `label_column` + cfg.dataset.target_table = cfg.dataset.label_table + cfg.dataset.target_column = cfg.dataset.label_column + # Tables where shallow embeddings are included for feature augmentation + cfg.dataset.augment_table = [] # early stopping configs cfg.optim.early_stopping = False @@ -59,14 +70,5 @@ def set_cfg(cfg): # if None, set patience = total num epochs / 10 cfg.optim.patience = None - # Overwrite GraphGym scheduler - # (might improve default optimizer after more training experiences - # on databases) - cfg.optim.scheduler = 'none' - - # TODO: Duplicate `label_table` and `label_column` - cfg.dataset.target_table = cfg.dataset.label_table - cfg.dataset.target_column = cfg.dataset.label_column - set_cfg(cfg) diff --git a/kumo/custom/loader/financial.py b/kumo/custom/loader/financial.py deleted file mode 100644 index a85b49f3..00000000 --- a/kumo/custom/loader/financial.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import os.path as osp - -from torch_geometric.graphgym.register import register_loader - -from kumo.store import Store -from kumo.scan import DatabaseMetadata, DatabaseStats -from kumo.connector import SnowflakeConnector, CSVConnector - - -def load_financial(format, name, dataset_dir, target_table_name, - target_column_name, split_column_name=None): - if name != "Financial": - return None - - root_dir = osp.join(osp.dirname(osp.realpath(__file__)), "..", "..", "..") - root_dir = osp.join(root_dir, "test", "csv_data", "FINANCIAL") - dbmeta = DatabaseMetadata.load(osp.join(root_dir, "metadata.yml")) - dbmeta.set_target(target_table_name, target_column_name) - if split_column_name is not None: - dbmeta.set_split(target_table_name, split_column_name) - - if format == "snowflake": - connector = SnowflakeConnector( - account=os.getenv("SNOWFLAKE_ACCOUNT"), - user=os.getenv("SNOWFLAKE_USER"), - password=os.getenv("SNOWFLAKE_PASSWORD"), warehouse="WH_XS", - database="KUMO", schema=name) - elif format == "s3": - s3_root_dir = "s3://kumo-datasets/financial/csv" - connector = CSVConnector(s3_root_dir, na_values="?") - else: - raise ValueError("Unrecognized database format: {}".format(format)) - - dbstats = DatabaseStats.from_connector(connector, dbmeta) - dbstats.print_summary() - dbaugment = [] - - data = Store.from_connector(connector, dbmeta, dbstats, dbaugment) - - # TODO: Temporary work around (merge targets for LOAN.STATUS) - if target_table_name == "LOAN" and target_column_name == "STATUS": - data['LOAN'].y[data["LOAN"].y == 2] = 0 - data["LOAN"].y[data["LOAN"].y == 3] = 1 - - return data - - -register_loader("financial", load_financial) diff --git a/kumo/train/loader.py b/kumo/train/loader.py index d44c62aa..626f0877 100644 --- a/kumo/train/loader.py +++ b/kumo/train/loader.py @@ -1,4 +1,5 @@ import logging +import os import os.path as osp import pandas as pd import torch @@ -13,6 +14,8 @@ from kumo.store import Store from torch_geometric.graphgym.loader import index2mask from torch_geometric.data import InMemoryDataset +from kumo.scan import DatabaseMetadata, DatabaseStats +from kumo.connector import SnowflakeConnector, CSVConnector def preprocess_dataset(data: Store): @@ -28,7 +31,99 @@ def preprocess_dataset(data: Store): for key in data.metadata()[0]: if "x" not in data[key]: data[key].x = torch.ones(data[key].num_nodes, 1) - # todo: necessary data transformation + return data + + +def kumo_loader(cfg: CfgNode): + """ + Load dataset following Kumo support APIs + Args: + cfg (CfgNode): Global config object + + Returns: Kumo Store object + + """ + + # Load dataset metadata + metadata_dir = cfg.dataset.metadata_dir + assert metadata_dir is not None, \ + "cfg.dataset.metadata_dir is required in yaml config" + if not osp.isabs(metadata_dir): + metadata_dir = osp.join(osp.dirname(osp.realpath(__file__)), "..", + "..", metadata_dir) + if cfg.dataset.name not in metadata_dir: + metadata_dir = osp.join(metadata_dir, cfg.dataset.name) + if ".yml" not in metadata_dir and ".yaml" not in metadata_dir: + metadata_dir = osp.join(metadata_dir, "metadata.yml") # default + dbmeta = DatabaseMetadata.load(metadata_dir) + dbmeta.set_target(cfg.dataset.target_table, cfg.dataset.target_column) + if cfg.dataset.split_column is not None: + dbmeta.set_split(cfg.dataset.target_table, cfg.dataset.split_column) + + # Load dataset + if cfg.dataset.format == "snowflake": + account = cfg.snowflake.account if cfg.snowflake.account is not None \ + else os.getenv("SNOWFLAKE_ACCOUNT") + user = cfg.snowflake.user if cfg.snowflake.user is not None \ + else os.getenv("SNOWFLAKE_USER") + password = cfg.snowflake.password \ + if cfg.snowflake.password is not None \ + else os.getenv("SNOWFLAKE_PASSWORD") + warehouse = cfg.snowflake.warehouse \ + if cfg.snowflake.warehouse is not None \ + else os.getenv("SNOWFLAKE_WAREHOUSE") + database = cfg.snowflake.database \ + if cfg.snowflake.database is not None \ + else os.getenv("SNOWFLAKE_DATABASE") + assert account is not None, \ + "SNOWFLAKE_ACCOUNT required in environment variable or yaml config" + assert user is not None, \ + "SNOWFLAKE_USER required in environment variable or yaml config" + assert password is not None, \ + "SNOWFLAKE_PASSWORD required in " \ + "environment variable or yaml config" + assert warehouse is not None, \ + "SNOWFLAKE_WAREHOUSE required in " \ + "environment variable or yaml config" + assert database is not None, \ + "SNOWFLAKE_DATABASE required in " \ + "environment variable or yaml config" + connector = SnowflakeConnector(account=account, user=user, + password=password, warehouse=warehouse, + database=database, + schema=cfg.dataset.name) + elif cfg.dataset.format == "csv": + data_dir = cfg.dataset.data_dir + assert data_dir is not None, "cfg.dataset.data_dir is required" + if "s3:" in data_dir: + if cfg.dataset.name not in data_dir: + data_dir = osp.join(data_dir, cfg.dataset.name.lower(), "csv") + elif not osp.isabs(data_dir): + data_dir = osp.join(osp.dirname(osp.realpath(__file__)), "..", + "..", data_dir) + if cfg.dataset.name not in data_dir: + data_dir = osp.join(data_dir, cfg.dataset.name) + else: + raise ValueError("{} not found".format(cfg.dataset.data_dir)) + + connector = CSVConnector(data_dir, na_values="?") + else: + raise ValueError("Unrecognized database format: {}".format( + cfg.dataset.format)) + + dbstats = DatabaseStats.from_connector(connector, dbmeta) + dbstats.print_summary() + + data = Store.from_connector(connector, dbmeta, dbstats, + cfg.dataset.augment_table) + + # todo: Temporary work around for financial prediction task + # todo: in order to compare results with baselines + if cfg.dataset.target_table == "LOAN" and \ + cfg.dataset.target_column == "STATUS": + data['LOAN'].y[data["LOAN"].y == 2] = 0 + data["LOAN"].y[data["LOAN"].y == 3] = 1 + return data @@ -47,13 +142,16 @@ def load_dataset(cfg: CfgNode, **kwargs): dataset_dir = osp.join(cfg.dataset.dir, name) target_table = cfg.dataset.target_table target_column = cfg.dataset.target_column - # Try to load customized data format + # First try to load with any customized data loader for func in register.loader_dict.values(): dataset = func(format, name, dataset_dir, target_table, target_column, cfg.dataset.split_column) if dataset is not None: return dataset - if format == "PyG": + # Then try to load with standard Kumo data loaders + if format == "snowflake" or format == "csv": + dataset = kumo_loader(cfg) + elif format == "PyG": try: dataset = getattr(pyg_dataset, name)(dataset_dir, diff --git a/test/train/configs/financial.yaml b/test/train/configs/financial.yaml index d91e5b23..010b1c33 100644 --- a/test/train/configs/financial.yaml +++ b/test/train/configs/financial.yaml @@ -6,40 +6,43 @@ snowflake: warehouse: WH_XS database: kumo dataset: - format: snowflake - location: local - name: Financial + format: csv + data_dir: 'test/csv_data' + metadata_dir: 'test/csv_data' + name: FINANCIAL target_table: LOAN target_column: STATUS task: node task_type: classification split: [0.8, 0.1, 0.1] + split_mode: random + split_column: DATE # only needed when split_mode = column encoder: True encoder_name: db encoder_bn: True train: mode: db_fast - sampler: neighbor + sampler: full_batch neighbor_sizes: [10,10,10,10] batch_size: 512 eval_period: 20 ckpt_period: 100 val: - sampler: neighbor + sampler: full_batch model: type: heterognn gnn: - layers_pre_mp: 1 + layers_pre_mp: 2 layers_mp: 3 layers_post_mp: 1 dim_inner: 64 layer_type: SAGEConv stage_type: stack - batchnorm: True + batchnorm: False act: prelu dropout: 0.0 agg: mean optim: optimizer: adam - base_lr: 0.01 - max_epoch: 200 + base_lr: 0.001 + max_epoch: 50 diff --git a/test/train/test_trainer.py b/test/train/test_trainer.py index 2130f317..fff8001e 100644 --- a/test/train/test_trainer.py +++ b/test/train/test_trainer.py @@ -1,5 +1,4 @@ import os.path as osp -from os import environ from collections import namedtuple @@ -21,16 +20,6 @@ def test_loader(): Args = namedtuple("Args", ["cfg_file", "repeat", "opts"]) args = Args(osp.join(root, "configs", "financial.yaml"), 1, []) - if ( - environ.get("SNOWFLAKE_ACCOUNT") is None - or environ.get("SNOWFLAKE_USER") is None - or environ.get("SNOWFLAKE_PASSWORD") is None - ): - raise Exception( - "Set Snowflake env (SNOWFLAKE_ACCOUNT, " - "SNOWFLAKE_USER, SNOWFLAKE_PASSWORD)" - ) - load_cfg(cfg, args) dump_cfg(cfg) # Repeat for different random seeds