Skip to content

Commit

Permalink
refactor: clean linting issues;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 25, 2024
1 parent 0195cbb commit bf045b4
Show file tree
Hide file tree
Showing 14 changed files with 383 additions and 408 deletions.
2 changes: 1 addition & 1 deletion pypots/classification/csai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@

__all__ = [
"CSAI",
]
]
40 changes: 20 additions & 20 deletions pypots/classification/csai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
# self.bcelogits = nn.BCEWithLogitsLoss()

# def forward(self, y_score, y_out, targets, smooth=1):

# #comment out if your model contains a sigmoid or equivalent activation layer
# # inputs = F.sigmoid(inputs)
# # inputs = F.sigmoid(inputs)

# #flatten label and prediction tensors
# BCE = self.bcelogits(y_out, targets)

Expand All @@ -30,23 +30,23 @@
# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth)

# Dice_BCE = BCE + dice_loss

# return BCE, Dice_BCE


class _BCSAI(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
rnn_hidden_size: int,
imputation_weight: float,
consistency_weight: float,
classification_weight: float,
n_classes: int,
step_channels: int,
dropout: float = 0.5,
intervals=None,
self,
n_steps: int,
n_features: int,
rnn_hidden_size: int,
imputation_weight: float,
consistency_weight: float,
classification_weight: float,
n_classes: int,
step_channels: int,
dropout: float = 0.5,
intervals=None,
):
super().__init__()
self.n_steps = n_steps
Expand Down Expand Up @@ -107,17 +107,17 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["labels"])
# f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float())
# b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float())
classification_loss = (f_classification_loss + b_classification_loss)
classification_loss = f_classification_loss + b_classification_loss

loss = (
self.consistency_weight * consistency_loss +
self.imputation_weight * reconstruction_loss +
self.classification_weight * classification_loss
self.consistency_weight * consistency_loss
+ self.imputation_weight * reconstruction_loss
+ self.classification_weight * classification_loss
)

results["loss"] = loss
results["classification_loss"] = classification_loss
results["f_reconstruction"] = f_reconstruction
results["b_reconstruction"] = b_reconstruction

return results
return results
31 changes: 15 additions & 16 deletions pypots/classification/csai/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@
# License: BSD-3-Clause

from typing import Union
from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation

from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation


class DatasetForCSAI(DatasetForCSAI_Imputation):
def __init__(self,
data: Union[dict, str],
file_type: str = "hdf5",
return_y: bool = True,
removal_percent: float = 0.0,
increase_factor: float = 0.1,
compute_intervals: bool = False,
replacement_probabilities = None,
normalise_mean : list = [],
normalise_std: list = [],
training: bool = True
def __init__(
self,
data: Union[dict, str],
file_type: str = "hdf5",
return_y: bool = True,
removal_percent: float = 0.0,
increase_factor: float = 0.1,
compute_intervals: bool = False,
replacement_probabilities=None,
normalise_mean: list = [],
normalise_std: list = [],
training: bool = True,
):
super().__init__(
data=data,
Expand All @@ -34,6 +34,5 @@ def __init__(self,
replacement_probabilities=replacement_probabilities,
normalise_mean=normalise_mean,
normalise_std=normalise_std,
training=training
)

training=training,
)
131 changes: 59 additions & 72 deletions pypots/classification/csai/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
"""
Expand All @@ -19,7 +18,6 @@


class CSAI(BaseNNClassifier):

"""
The PyTorch implementation of the CSAI model.
Expand Down Expand Up @@ -87,7 +85,7 @@ class CSAI(BaseNNClassifier):
verbose :
Whether to print out the training logs during the training process.
"""

def __init__(
Expand All @@ -99,33 +97,33 @@ def __init__(
consistency_weight: float,
classification_weight: float,
n_classes: int,
removal_percent: int,
increase_factor: float,
compute_intervals: bool,
step_channels:int,
batch_size: int,
epochs: int,
removal_percent: int,
increase_factor: float,
compute_intervals: bool,
step_channels: int,
batch_size: int,
epochs: int,
dropout: float = 0.5,
patience: Union[int, None] = None,
optimizer: Optimizer = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
patience: Union[int, None] = None,
optimizer: Optimizer = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
model_saving_strategy: Union[str, None] = "best",
verbose: bool = True
model_saving_strategy: Union[str, None] = "best",
verbose: bool = True,
):
super().__init__(
n_classes,
batch_size,
epochs,
n_classes,
batch_size,
epochs,
patience,
num_workers,
num_workers,
device,
saving_path,
model_saving_strategy,
saving_path,
model_saving_strategy,
verbose,
)

self.n_steps = n_steps
self.n_features = n_features
self.rnn_hidden_size = rnn_hidden_size
Expand All @@ -138,8 +136,8 @@ def __init__(
self.compute_intervals = compute_intervals
self.dropout = dropout
self.intervals = None
# Initialise empty model

# Initialise empty model
self.model = _BCSAI(
n_steps=self.n_steps,
n_features=self.n_features,
Expand All @@ -161,19 +159,10 @@ def __init__(

def _assemble_input_for_training(self, data: list, training=True) -> dict:
# extract data
sample = data['sample']
(
indices,
X,
missing_mask,
deltas,
last_obs,
back_X,
back_missing_mask,
back_deltas,
back_last_obs,
labels
) = self._send_data_to_given_device(sample)
sample = data["sample"]
(indices, X, missing_mask, deltas, last_obs, back_X, back_missing_mask, back_deltas, back_last_obs, labels) = (
self._send_data_to_given_device(sample)
)

inputs = {
"indices": indices,
Expand All @@ -195,10 +184,10 @@ def _assemble_input_for_training(self, data: list, training=True) -> dict:

def _assemble_input_for_validating(self, data: list) -> dict:
return self._assemble_input_for_training(data)

def _assemble_input_for_testing(self, data: list) -> dict:
# extract data
sample = data['sample']
sample = data["sample"]
(
indices,
X,
Expand Down Expand Up @@ -231,30 +220,30 @@ def _assemble_input_for_testing(self, data: list) -> dict:
# "X_ori": X_ori,
# "indicating_mask": indicating_mask,
}

return inputs

def fit(
self,
train_set,
val_set= None,
file_type: str = "hdf5",
)-> None:
self,
train_set,
val_set=None,
file_type: str = "hdf5",
) -> None:
# Create dataset
self.training_set = DatasetForCSAI(
data=train_set,
file_type=file_type,
return_y=True,
removal_percent=self.removal_percent,
increase_factor=self.increase_factor,
compute_intervals=self.compute_intervals,
)
data=train_set,
file_type=file_type,
return_y=True,
removal_percent=self.removal_percent,
increase_factor=self.increase_factor,
compute_intervals=self.compute_intervals,
)

self.intervals = self.training_set.intervals
self.replacement_probabilities = self.training_set.replacement_probabilities
self.mean_set = self.training_set.mean_set
self.std_set = self.training_set.std_set

train_loader = DataLoader(
self.training_set,
batch_size=self.batch_size,
Expand Down Expand Up @@ -297,7 +286,7 @@ def fit(
self._print_model_size()

# set up the optimizer
self.optimizer.init_optimizer(self.model.parameters())
self.optimizer.init_optimizer(self.model.parameters())

# train the model
self._train_model(train_loader, val_loader)
Expand All @@ -306,13 +295,12 @@ def fit(

self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")


def predict(
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> dict:
self,
test_set: Union[dict, str],
file_type: str = "hdf5",
) -> dict:

self.model.eval()
test_set = DatasetForCSAI(
data=test_set,
Expand All @@ -339,20 +327,19 @@ def predict(
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
classificaion_results.append(results['classification_pred'])


classificaion_results.append(results["classification_pred"])

classification = torch.cat(classificaion_results).cpu().detach().numpy()
result_dict = {
"classification": classification,
}
}
return result_dict

def classify(
self,
test_set,
file_type: str = "hdf5",
):
self,
test_set,
file_type: str = "hdf5",
):

result_dict = self.predict(test_set, file_type)
return result_dict['classification']
return result_dict["classification"]
2 changes: 1 addition & 1 deletion pypots/imputation/csai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@

__all__ = [
"CSAI",
]
]
Loading

0 comments on commit bf045b4

Please sign in to comment.