From 0979d443333fa6eca3db4578ecd08b21b65166dc Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 26 Oct 2024 01:12:20 +0800 Subject: [PATCH 1/5] refactor: normalize some code; --- pypots/classification/csai/core.py | 22 ----------- pypots/classification/csai/model.py | 57 +++++++++++++++++---------- pypots/imputation/csai/core.py | 5 ++- pypots/imputation/csai/data.py | 57 +++++++++++++++++++-------- pypots/imputation/csai/model.py | 17 ++++---- pypots/imputation/segrnn/core.py | 5 +-- pypots/nn/modules/csai/backbone.py | 11 +++--- pypots/nn/modules/csai/layers.py | 24 ++++++----- pypots/nn/modules/imputeformer/mlp.py | 4 +- pypots/nn/modules/segrnn/backbone.py | 2 - tests/classification/csai.py | 31 ++++++--------- tests/imputation/csai.py | 23 ++++------- 12 files changed, 128 insertions(+), 130 deletions(-) diff --git a/pypots/classification/csai/core.py b/pypots/classification/csai/core.py index 97a1fecc..dbdd025b 100644 --- a/pypots/classification/csai/core.py +++ b/pypots/classification/csai/core.py @@ -11,28 +11,6 @@ from ...nn.modules.csai import BackboneBCSAI -# class DiceBCELoss(nn.Module): -# def __init__(self, weight=None, size_average=True): -# super(DiceBCELoss, self).__init__() -# self.bcelogits = nn.BCEWithLogitsLoss() - -# def forward(self, y_score, y_out, targets, smooth=1): - -# #comment out if your model contains a sigmoid or equivalent activation layer -# # inputs = F.sigmoid(inputs) - -# #flatten label and prediction tensors -# BCE = self.bcelogits(y_out, targets) - -# y_score = y_score.view(-1) -# targets = targets.view(-1) -# intersection = (y_score * targets).sum() -# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth) - -# Dice_BCE = BCE + dice_loss - -# return BCE, Dice_BCE - class _BCSAI(nn.Module): def __init__( diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py index 11c7e117..a504cc1d 100644 --- a/pypots/classification/csai/model.py +++ b/pypots/classification/csai/model.py @@ -6,7 +6,7 @@ # License: BSD-3-Clause from typing import Optional, Union -import numpy as np + import torch from torch.utils.data import DataLoader @@ -60,31 +60,43 @@ class CSAI(BaseNNClassifier): The batch size for training and evaluating the model. epochs : - The number of epochs for training the model. - - dropout : - The dropout rate for the model to prevent overfitting. Default is 0.5. + The number of epochs for training the model. patience : - The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. optimizer : - The optimizer for model training. If not given, will use a default Adam optimizer. + The optimizer for model training. + If not given, will use a default Adam optimizer. num_workers : - The number of subprocesses to use for data loading. 0 means data loading will be in the main process, i.e. there won't be subprocesses. + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on. It can be a string, a :class:torch.device object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : - The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given. + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. model_saving_strategy : - The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. The "all" strategy will save every model after each epoch training. + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. verbose : - Whether to print out the training logs during the training process. + Whether to print out the training logs during the training process. """ @@ -136,6 +148,9 @@ def __init__( self.compute_intervals = compute_intervals self.dropout = dropout self.intervals = None + self.replacement_probabilities = None + self.mean_set = None + self.std_set = None # Initialise empty model self.model = _BCSAI( @@ -230,7 +245,7 @@ def fit( file_type: str = "hdf5", ) -> None: # Create dataset - self.training_set = DatasetForCSAI( + training_set = DatasetForCSAI( data=train_set, file_type=file_type, return_y=True, @@ -239,13 +254,13 @@ def fit( compute_intervals=self.compute_intervals, ) - self.intervals = self.training_set.intervals - self.replacement_probabilities = self.training_set.replacement_probabilities - self.mean_set = self.training_set.mean_set - self.std_set = self.training_set.std_set + self.intervals = training_set.intervals + self.replacement_probabilities = training_set.replacement_probabilities + self.mean_set = training_set.mean_set + self.std_set = training_set.std_set train_loader = DataLoader( - self.training_set, + training_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, @@ -321,15 +336,15 @@ def predict( num_workers=self.num_workers, ) - classificaion_results = [] + classification_results = [] with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) - classificaion_results.append(results["classification_pred"]) + classification_results.append(results["classification_pred"]) - classification = torch.cat(classificaion_results).cpu().detach().numpy() + classification = torch.cat(classification_results).cpu().detach().numpy() result_dict = { "classification": classification, } diff --git a/pypots/imputation/csai/core.py b/pypots/imputation/csai/core.py index 1d59621c..676647f2 100644 --- a/pypots/imputation/csai/core.py +++ b/pypots/imputation/csai/core.py @@ -62,7 +62,10 @@ class _BCSAI(nn.Module): Notes ----- - BCSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data. It computes consistency and reconstruction losses to improve imputation accuracy. During training, the forward and backward reconstructions are combined, and losses are used to update the model. In evaluation mode, the model also outputs original data and indicating masks for further analysis. + CSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data. + It computes consistency and reconstruction losses to improve imputation accuracy. + During training, the forward and backward reconstructions are combined, and losses are used to update the model. + In evaluation mode, the model also outputs original data and indicating masks for further analysis. """ diff --git a/pypots/imputation/csai/data.py b/pypots/imputation/csai/data.py index 26a99a5d..110f1dbd 100644 --- a/pypots/imputation/csai/data.py +++ b/pypots/imputation/csai/data.py @@ -5,15 +5,17 @@ # Created by Linglong Qian, Joseph Arul Raj # License: BSD-3-Clause +import copy from typing import Iterable -from ...data.dataset import BaseDataset +from typing import Union + import numpy as np import torch -from typing import Union -import copy -from ...data.utils import parse_delta from sklearn.preprocessing import StandardScaler +from ...data.dataset import BaseDataset +from ...data.utils import parse_delta + def normalize_csai( data, @@ -22,7 +24,8 @@ def normalize_csai( compute_intervals: bool = False, ): """ - Normalize the data based on the given mean and standard deviation, and optionally compute time intervals between observations. + Normalize the data based on the given mean and standard deviation, + and optionally compute time intervals between observations. Parameters ---------- @@ -33,7 +36,8 @@ def normalize_csai( The mean values for each variable, used for normalization. If empty, means will be computed from the data. std : list of float, optional - The standard deviation values for each variable, used for normalization. If empty, std values will be computed from the data. + The standard deviation values for each variable, used for normalization. + If empty, std values will be computed from the data. compute_intervals : bool, optional, default=False Whether to compute the time intervals between observations for each variable. @@ -47,10 +51,12 @@ def normalize_csai( The mean values for each variable after normalization, either computed from the data or passed as input. std_set : np.ndarray - The standard deviation values for each variable after normalization, either computed from the data or passed as input. + The standard deviation values for each variable after normalization, + either computed from the data or passed as input. intervals_list : dict of int to float, optional - If `compute_intervals` is True, this will return the median time intervals between observations for each variable. + If `compute_intervals` is True, this will return the median time intervals between observations + for each variable. """ # Convert data to numpy array if it is a torch tensor @@ -296,13 +302,23 @@ class DatasetForCSAI(BaseDataset): Parameters ---------- data : - The dataset for model input, which can be either a dictionary or a path string to a data file. If it's a dictionary, `X` should be an array-like structure with shape [n_samples, sequence length (n_steps), n_features], containing the time-series data, and it can have missing values. Optionally, the dictionary can include `y`, an array-like structure with shape [n_samples], representing the labels of `X`. If `data` is a path string, it should point to a data file (e.g., h5 file) that contains key-value pairs like a dictionary, including keys for `X` and possibly `y`. + The dataset for model input, which can be either a dictionary or a path string to a data file. + If it's a dictionary, `X` should be an array-like structure + with shape [n_samples, sequence length (n_steps), n_features], containing the time-series data, + and it can have missing values. Optionally, the dictionary can include `y`, + an array-like structure with shape [n_samples], representing the labels of `X`. + If `data` is a path string, it should point to a data file (e.g., h5 file) that contains key-value pairs like + a dictionary, including keys for `X` and possibly `y`. return_X_ori : - Whether to return the original time-series data (`X_ori`) when fetching data samples, useful for evaluation purposes. + Whether to return the original time-series data (`X_ori`) when fetching data samples, + useful for evaluation purposes. return_y : - Whether to return classification labels in the `__getitem__()` method if they exist in the dataset. If `True`, labels will be included in the returned data samples, which is useful for training classification models. If `False`, the labels won't be returned, suitable for testing or validation stages. + Whether to return classification labels in the `__getitem__()` method if they exist in the dataset. + If `True`, labels will be included in the returned data samples, + which is useful for training classification models. + If `False`, the labels won't be returned, suitable for testing or validation stages. file_type : The type of the data file if `data` is a path string, such as "hdf5". @@ -317,20 +333,29 @@ class DatasetForCSAI(BaseDataset): Whether to compute time intervals between observations for handling irregular time-series data. replacement_probabilities : - Optional precomputed probabilities for sampling missing values. If not provided, they will be calculated during the initialization of the dataset. + Optional precomputed probabilities for sampling missing values. + If not provided, they will be calculated during the initialization of the dataset. normalise_mean : - A list of mean values for normalizing the input features. If not provided, they will be computed during initialization. + A list of mean values for normalizing the input features. + If not provided, they will be computed during initialization. normalise_std : - A list of standard deviation values for normalizing the input features. If not provided, they will be computed during initialization. + A list of standard deviation values for normalizing the input features. + If not provided, they will be computed during initialization. training : - Whether the dataset is used for training. If `False`, it will adjust how data is processed, particularly for evaluation and testing phases. + Whether the dataset is used for training. + If `False`, it will adjust how data is processed, particularly for evaluation and testing phases. Notes ----- - The DatasetForCSAI class is designed for bidirectional imputation of time-series data, handling both forward and backward directions to improve imputation accuracy. It supports on-the-fly data normalization and missing value simulation, making it suitable for training and evaluating deep learning models like CSAI. The class can work with large datasets stored on disk, leveraging lazy-loading to minimize memory usage, and supports both training and testing scenarios, adjusting data handling as needed. + The DatasetForCSAI class is designed for bidirectional imputation of time-series data, + handling both forward and backward directions to improve imputation accuracy. + It supports on-the-fly data normalization and missing value simulation, + making it suitable for training and evaluating deep learning models like CSAI. + The class can work with large datasets stored on disk, leveraging lazy-loading to minimize memory usage, + and supports both training and testing scenarios, adjusting data handling as needed. """ diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index 19c48960..2d0cc27a 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -11,11 +11,9 @@ import torch from torch.utils.data import DataLoader - from .core import _BCSAI from .data import DatasetForCSAI from ..base import BaseNNImputer -from ...data.checking import key_in_data_set from ...optim.adam import Adam from ...optim.base import Optimizer @@ -146,6 +144,9 @@ def __init__( self.step_channels = step_channels self.compute_intervals = compute_intervals self.intervals = None + self.replacement_probabilities = None + self.mean_set = None + self.std_set = None # Initialise model self.model = _BCSAI( @@ -238,16 +239,16 @@ def fit( file_type: str = "hdf5", ) -> None: - self.training_set = DatasetForCSAI( + training_set = DatasetForCSAI( train_set, False, False, file_type, self.removal_percent, self.increase_factor, self.compute_intervals ) - self.intervals = self.training_set.intervals - self.replacement_probabilities = self.training_set.replacement_probabilities - self.mean_set = self.training_set.mean_set - self.std_set = self.training_set.std_set + self.intervals = training_set.intervals + self.replacement_probabilities = training_set.replacement_probabilities + self.mean_set = training_set.mean_set + self.std_set = training_set.std_set training_loader = DataLoader( - self.training_set, + training_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, diff --git a/pypots/imputation/segrnn/core.py b/pypots/imputation/segrnn/core.py index b4978099..9de9cd58 100644 --- a/pypots/imputation/segrnn/core.py +++ b/pypots/imputation/segrnn/core.py @@ -5,13 +5,10 @@ # Created by Shengsheng Lin -from typing import Optional - -from typing import Callable import torch.nn as nn -from ...nn.modules.segrnn import BackboneSegRNN from ...nn.modules.saits import SaitsLoss +from ...nn.modules.segrnn import BackboneSegRNN class _SegRNN(nn.Module): diff --git a/pypots/nn/modules/csai/backbone.py b/pypots/nn/modules/csai/backbone.py index b163a611..41df3a1f 100644 --- a/pypots/nn/modules/csai/backbone.py +++ b/pypots/nn/modules/csai/backbone.py @@ -5,10 +5,11 @@ # Created by Linglong Qian, Joseph Arul Raj # License: BSD-3-Clause +import math + import torch import torch.nn as nn -import torch.nn.functional as F -import math + from .layers import FeatureRegression, Decay, Decay_obs, PositionalEncoding, Conv1dWithInit, TorchTransformerEncoder from ....utils.metrics import calc_mae @@ -91,7 +92,7 @@ class BackboneCSAI(nn.Module): """ def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_df=None): - super(BackboneCSAI, self).__init__() + super().__init__() if medians_df is not None: self.medians_tensor = torch.tensor(list(medians_df.values())).float() @@ -134,7 +135,7 @@ def forward(self, x, mask, deltas, last_obs, h=None): decay_factor = self.weighted_obs(deltas - medians.unsqueeze(1)) - if h == None: + if h is None: data_last_obs = self.input_projection(last_obs.permute(0, 2, 1)).permute(0, 2, 1) data_decay_factor = self.input_projection(decay_factor.permute(0, 2, 1)).permute(0, 2, 1) @@ -191,7 +192,7 @@ def forward(self, x, mask, deltas, last_obs, h=None): class BackboneBCSAI(nn.Module): def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_df=None): - super(BackboneBCSAI, self).__init__() + super().__init__() self.model_f = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df) self.model_b = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df) diff --git a/pypots/nn/modules/csai/layers.py b/pypots/nn/modules/csai/layers.py index 39752262..7d86f034 100644 --- a/pypots/nn/modules/csai/layers.py +++ b/pypots/nn/modules/csai/layers.py @@ -5,22 +5,18 @@ # Created by Joseph Arul Raj # License: BSD-3-Clause +import math + import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from torch.nn.parameter import Parameter -import math -import numpy as np -import os -import copy -import pandas as pd -from torch.nn.modules import TransformerEncoderLayer class FeatureRegression(nn.Module): def __init__(self, input_size): - super(FeatureRegression, self).__init__() + super().__init__() self.build(input_size) def build(self, input_size): @@ -43,7 +39,9 @@ def forward(self, x): class Decay(nn.Module): def __init__(self, input_size, output_size, diag=False): - super(Decay, self).__init__() + super().__init__() + self.W = None + self.b = None self.diag = diag self.build(input_size, output_size) @@ -51,7 +49,7 @@ def build(self, input_size, output_size): self.W = Parameter(torch.Tensor(output_size, input_size)) self.b = Parameter(torch.Tensor(output_size)) - if self.diag == True: + if self.diag: assert input_size == output_size m = torch.eye(input_size, input_size) self.register_buffer("m", m) @@ -64,7 +62,7 @@ def reset_parameters(self): self.b.data.uniform_(-stdv, stdv) def forward(self, d): - if self.diag == True: + if self.diag: gamma = F.relu(F.linear(d, self.W * Variable(self.m), self.b)) else: gamma = F.relu(F.linear(d, self.W, self.b)) @@ -74,7 +72,7 @@ def forward(self, d): class Decay_obs(nn.Module): def __init__(self, input_size, output_size): - super(Decay_obs, self).__init__() + super().__init__() self.linear = nn.Linear(input_size, output_size) def forward(self, delta_diff): @@ -98,7 +96,7 @@ def forward(self, delta_diff): class TorchTransformerEncoder(nn.Module): def __init__(self, heads=8, layers=1, channels=64): - super(TorchTransformerEncoder, self).__init__() + super().__init__() self.encoder_layer = nn.TransformerEncoderLayer( d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu" ) @@ -110,7 +108,7 @@ def forward(self, x): class Conv1dWithInit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size): - super(Conv1dWithInit, self).__init__() + super().__init__() self.conv = nn.Conv1d(in_channels, out_channels, kernel_size) nn.init.kaiming_normal_(self.conv.weight) diff --git a/pypots/nn/modules/imputeformer/mlp.py b/pypots/nn/modules/imputeformer/mlp.py index 12ef62a3..a8e31066 100644 --- a/pypots/nn/modules/imputeformer/mlp.py +++ b/pypots/nn/modules/imputeformer/mlp.py @@ -12,7 +12,7 @@ class Dense(nn.Module): """A simple fully-connected layer.""" def __init__(self, input_size, output_size, dropout=0.0, bias=True): - super(Dense, self).__init__() + super().__init__() self.layer = nn.Sequential( nn.Linear(input_size, output_size, bias=bias), nn.ReLU(), @@ -29,7 +29,7 @@ class MLP(nn.Module): """ def __init__(self, input_size, hidden_size, output_size=None, n_layers=1, dropout=0.0): - super(MLP, self).__init__() + super().__init__() layers = [ Dense( diff --git a/pypots/nn/modules/segrnn/backbone.py b/pypots/nn/modules/segrnn/backbone.py index 5b7f2fda..7d6ff3e8 100644 --- a/pypots/nn/modules/segrnn/backbone.py +++ b/pypots/nn/modules/segrnn/backbone.py @@ -4,8 +4,6 @@ # Created by Shengsheng Lin -from typing import Optional - import torch import torch.nn as nn diff --git a/tests/classification/csai.py b/tests/classification/csai.py index 916c6cf8..17f1028f 100644 --- a/tests/classification/csai.py +++ b/tests/classification/csai.py @@ -72,12 +72,10 @@ def test_0_fit(self): def test_1_classify(self): # Classify test set using the trained CSAI model results = self.csai.classify(TEST_SET) - + # Calculate binary classification metrics - metrics = calc_binary_classification_metrics( - results, DATA["test_y"] - ) - + metrics = calc_binary_classification_metrics(results, DATA["test_y"]) + logger.info( f'CSAI ROC_AUC: {metrics["roc_auc"]}, ' f'PR_AUC: {metrics["pr_auc"]}, ' @@ -85,7 +83,7 @@ def test_1_classify(self): f'Precision: {metrics["precision"]}, ' f'Recall: {metrics["recall"]}' ) - + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" @pytest.mark.xdist_group(name="classification-csai") @@ -98,17 +96,12 @@ def test_2_parameters(self): assert hasattr(self.csai, "best_loss") self.assertNotEqual(self.csai.best_loss, float("inf")) - assert ( - hasattr(self.csai, "best_model_dict") - and self.csai.best_model_dict is not None - ) + assert hasattr(self.csai, "best_model_dict") and self.csai.best_model_dict is not None @pytest.mark.xdist_group(name="classification-csai") def test_3_saving_path(self): # Ensure the root saving directory exists - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" + assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist" # Check if the tensorboard file and model checkpoints exist check_tb_and_model_checkpoints_existence(self.csai) @@ -124,15 +117,13 @@ def test_3_saving_path(self): def test_4_lazy_loading(self): # Fit the CSAI model using lazy-loading datasets from H5 files self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) - + # Perform classification using lazy-loaded data results = self.csai.classify(GENERAL_H5_TEST_SET_PATH) - + # Calculate binary classification metrics - metrics = calc_binary_classification_metrics( - results, DATA["test_y"] - ) - + metrics = calc_binary_classification_metrics(results, DATA["test_y"]) + logger.info( f'Lazy-loading CSAI ROC_AUC: {metrics["roc_auc"]}, ' f'PR_AUC: {metrics["pr_auc"]}, ' @@ -140,7 +131,7 @@ def test_4_lazy_loading(self): f'Precision: {metrics["precision"]}, ' f'Recall: {metrics["recall"]}' ) - + assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5" diff --git a/tests/imputation/csai.py b/tests/imputation/csai.py index 492a0a52..986657e9 100644 --- a/tests/imputation/csai.py +++ b/tests/imputation/csai.py @@ -71,14 +71,10 @@ def test_0_fit(self): def test_1_impute(self): # Impute missing values using the trained CSAI model imputed_X = self.csai.impute(TEST_SET) - assert not np.isnan( - imputed_X - ).any(), "Output still has missing values after running impute()." - + assert not np.isnan(imputed_X).any(), "Output still has missing values after running impute()." + # Calculate mean squared error (MSE) for the test set - test_MSE = calc_mse( - imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"] - ) + test_MSE = calc_mse(imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]) logger.info(f"CSAI test_MSE: {test_MSE}") @pytest.mark.xdist_group(name="imputation-csai") @@ -91,17 +87,12 @@ def test_2_parameters(self): assert hasattr(self.csai, "best_loss") self.assertNotEqual(self.csai.best_loss, float("inf")) - assert ( - hasattr(self.csai, "best_model_dict") - and self.csai.best_model_dict is not None - ) + assert hasattr(self.csai, "best_model_dict") and self.csai.best_model_dict is not None @pytest.mark.xdist_group(name="imputation-csai") def test_3_saving_path(self): # Ensure the root saving directory exists - assert os.path.exists( - self.saving_path - ), f"file {self.saving_path} does not exist" + assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist" # Check if the tensorboard file and model checkpoints exist check_tb_and_model_checkpoints_existence(self.csai) @@ -117,7 +108,7 @@ def test_3_saving_path(self): def test_4_lazy_loading(self): # Fit the CSAI model using lazy-loading datasets from H5 files self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) - + # Perform imputation using lazy-loaded data imputation_results = self.csai.predict(GENERAL_H5_TEST_SET_PATH) assert not np.isnan( @@ -134,4 +125,4 @@ def test_4_lazy_loading(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 87a4168bd9bffb73239f4b3ea097cfc19732c2ac Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 26 Oct 2024 13:20:51 +0800 Subject: [PATCH 2/5] refactor: raise not implemented error for CSAI _fetch_data_from_file; --- pypots/imputation/csai/data.py | 85 ++++++---------------------------- 1 file changed, 14 insertions(+), 71 deletions(-) diff --git a/pypots/imputation/csai/data.py b/pypots/imputation/csai/data.py index 110f1dbd..a72fc322 100644 --- a/pypots/imputation/csai/data.py +++ b/pypots/imputation/csai/data.py @@ -369,14 +369,19 @@ def __init__( increase_factor: float = 0.1, compute_intervals: bool = False, replacement_probabilities=None, - normalise_mean: list = [], - normalise_std: list = [], + normalise_mean=None, + normalise_std=None, training: bool = True, ): super().__init__( data=data, return_X_ori=return_X_ori, return_X_pred=False, return_y=return_y, file_type=file_type ) + if normalise_std is None: + normalise_std = [] + if normalise_mean is None: + normalise_mean = [] + self.removal_percent = removal_percent self.increase_factor = increase_factor self.compute_intervals = compute_intervals @@ -385,6 +390,11 @@ def __init__( self.normalise_std = normalise_std self.training = training + self.normalized_data = None + self.mean_set = None + self.std_set = None + self.intervals = None + if not isinstance(self.data, str): self.normalized_data, self.mean_set, self.std_set, self.intervals = normalize_csai( self.data["X"], @@ -465,73 +475,6 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: } def _fetch_data_from_file(self, idx: int) -> Iterable: - """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples. - Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice. - - Parameters - ---------- - idx : - The index of the sample to be return. - - Returns - ------- - sample : - The collated data sample, a list including all necessary sample info. - """ - - if self.file_handle is None: - self.file_handle = self._open_file_handle() - - X = torch.from_numpy(self.file_handle["X"][idx]) - normalized_data, mean_set, std_set, intervals = normalize_csai( - X, - self.normalise_mean, - self.normalise_std, - self.compute_intervals, - ) - - processed_data, replacement_probabilities = non_uniform_sample( - normalized_data, - self.removal_percent, - self.replacement_probabilities, - self.increase_factor, + raise NotImplementedError( + "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead." ) - forward_X = processed_data["values"] - forward_missing_mask = processed_data["masks"] - backward_X = torch.flip(forward_X, dims=[1]) - backward_missing_mask = torch.flip(forward_missing_mask, dims=[1]) - - X_ori = self.processed_data["evals"] - indicating_mask = self.processed_data["eval_masks"] - - if self.return_y: - y = self.processed_data["labels"] - - sample = [ - torch.tensor(idx), - # for forward - forward_X, - forward_missing_mask, - processed_data["deltas_f"], - processed_data["last_obs_f"], - # for backward - backward_X, - backward_missing_mask, - processed_data["deltas_b"], - processed_data["last_obs_b"], - ] - - if self.return_X_ori: - sample.extend([X_ori, indicating_mask]) - - # if the dataset has labels and is for training, then fetch it from the file - if self.return_y: - sample.append(y) - - return { - "sample": sample, - "replacement_probabilities": replacement_probabilities, - "mean_set": mean_set, - "std_set": std_set, - "intervals": intervals, - } From b6a32803256313144e8cd80d61f985728292bd06 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 27 Oct 2024 20:05:10 +0800 Subject: [PATCH 3/5] fix: the error that dataset cannot be lazy loaded for CSAI; --- pypots/imputation/csai/model.py | 53 ++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index 2d0cc27a..4eaab839 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -6,6 +6,7 @@ # License: BSD-3-Clause from typing import Union, Optional +from venv import logger import numpy as np import torch @@ -14,6 +15,8 @@ from .core import _BCSAI from .data import DatasetForCSAI from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.saving.h5 import load_dict_from_h5 from ...optim.adam import Adam from ...optim.base import Optimizer @@ -164,6 +167,7 @@ def __init__( # set up the optimizer self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) def _assemble_input_for_training(self, data: list, training=True) -> dict: # extract data @@ -239,8 +243,21 @@ def fit( file_type: str = "hdf5", ) -> None: + if isinstance(train_set, str): + logger.warning( + "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. " + "Hence the whole train set will be loaded into memory." + ) + train_set = load_dict_from_h5(train_set) + training_set = DatasetForCSAI( - train_set, False, False, file_type, self.removal_percent, self.increase_factor, self.compute_intervals + train_set, + False, + False, + file_type, + self.removal_percent, + self.increase_factor, + self.compute_intervals, ) self.intervals = training_set.intervals self.replacement_probabilities = training_set.replacement_probabilities @@ -254,7 +271,17 @@ def fit( num_workers=self.num_workers, # collate_fn=collate_fn_bidirectional ) + val_loader = None if val_set is not None: + if isinstance(val_set, str): + logger.warning( + "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. " + "Hence the whole val set will be loaded into memory." + ) + val_set = load_dict_from_h5(val_set) + + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") val_set = DatasetForCSAI( val_set, True, @@ -276,23 +303,6 @@ def fit( # collate_fn=collate_fn_bidirectional ) - # Reset the model - self.model = _BCSAI( - self.n_steps, - self.n_features, - self.rnn_hidden_size, - self.step_channels, - self.consistency_weight, - self.imputation_weight, - self.intervals, - ) - - self._send_model_to_given_device() - self._print_model_size() - - # set up the optimizer - self.optimizer.init_optimizer(self.model.parameters()) - # train the model self._train_model(training_loader, val_loader) self.model.load_state_dict(self.best_model_dict) @@ -308,6 +318,13 @@ def predict( ) -> dict: self.model.eval() + + if isinstance(test_set, str): + logger.warning( + "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. " + "Hence the whole test set will be loaded into memory." + ) + test_set = load_dict_from_h5(test_set) test_set = DatasetForCSAI( test_set, True, From edd144dadf1705a8f1f5d85290d6563fbf32af22 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 27 Oct 2024 22:32:07 +0800 Subject: [PATCH 4/5] fix: lazy loading error for classification CSAI; --- pypots/classification/csai/data.py | 1 + pypots/classification/csai/model.py | 44 +++++++++++++++++------------ pypots/imputation/csai/model.py | 2 +- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/pypots/classification/csai/data.py b/pypots/classification/csai/data.py index cd829882..3b93765c 100644 --- a/pypots/classification/csai/data.py +++ b/pypots/classification/csai/data.py @@ -6,6 +6,7 @@ # License: BSD-3-Clause from typing import Union + from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py index a504cc1d..c65a3724 100644 --- a/pypots/classification/csai/model.py +++ b/pypots/classification/csai/model.py @@ -13,8 +13,11 @@ from .core import _BCSAI from .data import DatasetForCSAI from ..base import BaseNNClassifier +from ...data.checking import key_in_data_set +from ...data.saving.h5 import load_dict_from_h5 from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class CSAI(BaseNNClassifier): @@ -171,6 +174,7 @@ def __init__( # set up the optimizer self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) def _assemble_input_for_training(self, data: list, training=True) -> dict: # extract data @@ -245,6 +249,12 @@ def fit( file_type: str = "hdf5", ) -> None: # Create dataset + if isinstance(train_set, str): + logger.warning( + "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. " + "Hence the whole train set will be loaded into memory." + ) + train_set = load_dict_from_h5(train_set) training_set = DatasetForCSAI( data=train_set, file_type=file_type, @@ -267,6 +277,15 @@ def fit( ) val_loader = None if val_set is not None: + if isinstance(val_set, str): + logger.warning( + "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. " + "Hence the whole val set will be loaded into memory." + ) + val_set = load_dict_from_h5(val_set) + + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") val_set = DatasetForCSAI( data=val_set, file_type=file_type, @@ -284,24 +303,6 @@ def fit( shuffle=False, num_workers=self.num_workers, ) - # Create model - self.model = _BCSAI( - n_steps=self.n_steps, - n_features=self.n_features, - rnn_hidden_size=self.rnn_hidden_size, - imputation_weight=self.imputation_weight, - consistency_weight=self.consistency_weight, - classification_weight=self.classification_weight, - n_classes=self.n_classes, - step_channels=self.step_channels, - dropout=self.dropout, - intervals=self.intervals, - ) - self._send_model_to_given_device() - self._print_model_size() - - # set up the optimizer - self.optimizer.init_optimizer(self.model.parameters()) # train the model self._train_model(train_loader, val_loader) @@ -317,6 +318,13 @@ def predict( ) -> dict: self.model.eval() + + if isinstance(test_set, str): + logger.warning( + "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. " + "Hence the whole test set will be loaded into memory." + ) + test_set = load_dict_from_h5(test_set) test_set = DatasetForCSAI( data=test_set, file_type=file_type, diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index 4eaab839..fe655ea2 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -6,7 +6,6 @@ # License: BSD-3-Clause from typing import Union, Optional -from venv import logger import numpy as np import torch @@ -19,6 +18,7 @@ from ...data.saving.h5 import load_dict_from_h5 from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class CSAI(BaseNNImputer): From 81ea218cfbb91be018dc6d133244ced9757593f3 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 27 Oct 2024 23:49:15 +0800 Subject: [PATCH 5/5] refactor: update CSAI default arguments; --- pypots/classification/csai/model.py | 6 +++--- pypots/imputation/csai/model.py | 6 +++--- tests/classification/csai.py | 4 +--- tests/imputation/csai.py | 4 +--- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py index c65a3724..3419c5bb 100644 --- a/pypots/classification/csai/model.py +++ b/pypots/classification/csai/model.py @@ -116,10 +116,10 @@ def __init__( increase_factor: float, compute_intervals: bool, step_channels: int, - batch_size: int, - epochs: int, dropout: float = 0.5, - patience: Union[int, None] = None, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, optimizer: Optimizer = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index fe655ea2..a579fd2c 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -116,9 +116,9 @@ def __init__( increase_factor: float, compute_intervals: bool, step_channels: int, - batch_size: int, - epochs: int, - patience: Union[int, None] = None, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Union[str, torch.device, list, None] = None, diff --git a/tests/classification/csai.py b/tests/classification/csai.py index 17f1028f..4a2bbf5f 100644 --- a/tests/classification/csai.py +++ b/tests/classification/csai.py @@ -44,7 +44,7 @@ class TestCSAI(unittest.TestCase): n_steps=DATA["n_steps"], n_features=DATA["n_features"], n_classes=DATA["n_classes"], - rnn_hidden_size=32, + rnn_hidden_size=64, imputation_weight=0.7, consistency_weight=0.3, classification_weight=1.0, @@ -52,11 +52,9 @@ class TestCSAI(unittest.TestCase): increase_factor=0.1, compute_intervals=True, step_channels=16, - batch_size=64, epochs=EPOCHS, dropout=0.5, optimizer=optimizer, - num_workers=4, device=DEVICE, saving_path=saving_path, model_saving_strategy="better", diff --git a/tests/imputation/csai.py b/tests/imputation/csai.py index 986657e9..f5c4873b 100644 --- a/tests/imputation/csai.py +++ b/tests/imputation/csai.py @@ -45,17 +45,15 @@ class TestCSAI(unittest.TestCase): csai = CSAI( n_steps=DATA["n_steps"], n_features=DATA["n_features"], - rnn_hidden_size=32, + rnn_hidden_size=64, imputation_weight=0.7, consistency_weight=0.3, removal_percent=10, # Assume we are removing 10% of the data increase_factor=0.1, compute_intervals=True, step_channels=16, - batch_size=64, epochs=EPOCHS, optimizer=optimizer, - num_workers=0, device=DEVICE, saving_path=saving_path, model_saving_strategy="best",