Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CSAI cannot accept dataset files for lazy loading #545

Merged
merged 5 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions pypots/classification/csai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,6 @@

from ...nn.modules.csai import BackboneBCSAI

# class DiceBCELoss(nn.Module):
# def __init__(self, weight=None, size_average=True):
# super(DiceBCELoss, self).__init__()
# self.bcelogits = nn.BCEWithLogitsLoss()

# def forward(self, y_score, y_out, targets, smooth=1):

# #comment out if your model contains a sigmoid or equivalent activation layer
# # inputs = F.sigmoid(inputs)

# #flatten label and prediction tensors
# BCE = self.bcelogits(y_out, targets)

# y_score = y_score.view(-1)
# targets = targets.view(-1)
# intersection = (y_score * targets).sum()
# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth)

# Dice_BCE = BCE + dice_loss

# return BCE, Dice_BCE


class _BCSAI(nn.Module):
def __init__(
Expand Down
1 change: 1 addition & 0 deletions pypots/classification/csai/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# License: BSD-3-Clause

from typing import Union

from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation


Expand Down
107 changes: 65 additions & 42 deletions pypots/classification/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
# License: BSD-3-Clause

from typing import Optional, Union
import numpy as np

import torch
from torch.utils.data import DataLoader

from .core import _BCSAI
from .data import DatasetForCSAI
from ..base import BaseNNClassifier
from ...data.checking import key_in_data_set
from ...data.saving.h5 import load_dict_from_h5
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class CSAI(BaseNNClassifier):
Expand Down Expand Up @@ -60,31 +63,43 @@ class CSAI(BaseNNClassifier):
The batch size for training and evaluating the model.

epochs :
The number of epochs for training the model.

dropout :
The dropout rate for the model to prevent overfitting. Default is 0.5.
The number of epochs for training the model.

patience :
The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping.
The patience for the early-stopping mechanism. Given a positive integer, the training process will be
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.

optimizer :
The optimizer for model training. If not given, will use a default Adam optimizer.
The optimizer for model training.
If not given, will use a default Adam optimizer.

num_workers :
The number of subprocesses to use for data loading. 0 means data loading will be in the main process, i.e. there won't be subprocesses.
The number of subprocesses to use for data loading.
`0` means data loading will be in the main process, i.e. there won't be subprocesses.

device :
The device for the model to run on. It can be a string, a :class:torch.device object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

saving_path :
The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given.
The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
training into a tensorboard file). Will not save if not given.

model_saving_strategy :
The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. The "all" strategy will save every model after each epoch training.
The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
No model will be saved when it is set as None.
The "best" strategy will only automatically save the best model after the training finished.
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
The "all" strategy will save every model after each epoch training.

verbose :
Whether to print out the training logs during the training process.
Whether to print out the training logs during the training process.

"""

Expand All @@ -101,10 +116,10 @@ def __init__(
increase_factor: float,
compute_intervals: bool,
step_channels: int,
batch_size: int,
epochs: int,
dropout: float = 0.5,
patience: Union[int, None] = None,
batch_size: int = 32,
epochs: int = 100,
patience: Optional[int] = None,
optimizer: Optimizer = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down Expand Up @@ -136,6 +151,9 @@ def __init__(
self.compute_intervals = compute_intervals
self.dropout = dropout
self.intervals = None
self.replacement_probabilities = None
self.mean_set = None
self.std_set = None

# Initialise empty model
self.model = _BCSAI(
Expand All @@ -156,6 +174,7 @@ def __init__(

# set up the optimizer
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: list, training=True) -> dict:
# extract data
Expand Down Expand Up @@ -230,7 +249,13 @@ def fit(
file_type: str = "hdf5",
) -> None:
# Create dataset
self.training_set = DatasetForCSAI(
if isinstance(train_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole train set will be loaded into memory."
)
train_set = load_dict_from_h5(train_set)
training_set = DatasetForCSAI(
data=train_set,
file_type=file_type,
return_y=True,
Expand All @@ -239,19 +264,28 @@ def fit(
compute_intervals=self.compute_intervals,
)

self.intervals = self.training_set.intervals
self.replacement_probabilities = self.training_set.replacement_probabilities
self.mean_set = self.training_set.mean_set
self.std_set = self.training_set.std_set
self.intervals = training_set.intervals
self.replacement_probabilities = training_set.replacement_probabilities
self.mean_set = training_set.mean_set
self.std_set = training_set.std_set

train_loader = DataLoader(
self.training_set,
training_set,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
val_loader = None
if val_set is not None:
if isinstance(val_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole val set will be loaded into memory."
)
val_set = load_dict_from_h5(val_set)

if not key_in_data_set("X_ori", val_set):
raise ValueError("val_set must contain 'X_ori' for model validation.")
val_set = DatasetForCSAI(
data=val_set,
file_type=file_type,
Expand All @@ -269,24 +303,6 @@ def fit(
shuffle=False,
num_workers=self.num_workers,
)
# Create model
self.model = _BCSAI(
n_steps=self.n_steps,
n_features=self.n_features,
rnn_hidden_size=self.rnn_hidden_size,
imputation_weight=self.imputation_weight,
consistency_weight=self.consistency_weight,
classification_weight=self.classification_weight,
n_classes=self.n_classes,
step_channels=self.step_channels,
dropout=self.dropout,
intervals=self.intervals,
)
self._send_model_to_given_device()
self._print_model_size()

# set up the optimizer
self.optimizer.init_optimizer(self.model.parameters())

# train the model
self._train_model(train_loader, val_loader)
Expand All @@ -302,6 +318,13 @@ def predict(
) -> dict:

self.model.eval()

if isinstance(test_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole test set will be loaded into memory."
)
test_set = load_dict_from_h5(test_set)
test_set = DatasetForCSAI(
data=test_set,
file_type=file_type,
Expand All @@ -321,15 +344,15 @@ def predict(
num_workers=self.num_workers,
)

classificaion_results = []
classification_results = []

with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
classificaion_results.append(results["classification_pred"])
classification_results.append(results["classification_pred"])

classification = torch.cat(classificaion_results).cpu().detach().numpy()
classification = torch.cat(classification_results).cpu().detach().numpy()
result_dict = {
"classification": classification,
}
Expand Down
5 changes: 4 additions & 1 deletion pypots/imputation/csai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ class _BCSAI(nn.Module):

Notes
-----
BCSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data. It computes consistency and reconstruction losses to improve imputation accuracy. During training, the forward and backward reconstructions are combined, and losses are used to update the model. In evaluation mode, the model also outputs original data and indicating masks for further analysis.
CSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data.
It computes consistency and reconstruction losses to improve imputation accuracy.
During training, the forward and backward reconstructions are combined, and losses are used to update the model.
In evaluation mode, the model also outputs original data and indicating masks for further analysis.

"""

Expand Down
Loading
Loading