diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py index d4b5ebfa..3fbe79d2 100644 --- a/pypots/imputation/brits/modules/core.py +++ b/pypots/imputation/brits/modules/core.py @@ -172,7 +172,7 @@ def impute( estimations = torch.cat(estimations, dim=1) imputed_data = masks * values + (1 - masks) * estimations - return imputed_data, hidden_states, reconstruction_loss + return imputed_data, estimations, hidden_states, reconstruction_loss def forward(self, inputs: dict, direction: str = "forward") -> dict: """Forward processing of the NN module. @@ -190,7 +190,7 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict: A dictionary includes all results. """ - imputed_data, hidden_state, reconstruction_loss = self.impute(inputs, direction) + imputed_data, estimations, hidden_state, reconstruction_loss = self.impute(inputs, direction) # for each iteration, reconstruction_loss increases its value for 3 times reconstruction_loss /= self.n_steps * 3 @@ -200,6 +200,7 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict: ), # single direction, has no consistency loss "reconstruction_loss": reconstruction_loss, "imputed_data": imputed_data, + "reconstructed_data": estimations, "final_hidden_state": hidden_state, } return ret_dict @@ -304,6 +305,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: ret_b = self._reverse(self.rits_b(inputs, "backward")) imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 + reconstructed_data = (ret_f["reconstructed_data"] + ret_b["reconstructed_data"]) / 2 results = { "imputed_data": imputed_data, @@ -323,5 +325,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # `loss` is always the item for backward propagating to update the model results["loss"] = loss + results['reconstructed_data'] = reconstructed_data + results['f_reconstructed_data'] = ret_f['reconstructed_data'] + results['b_reconstructed_data'] = ret_b['reconstructed_data'] return results diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index 69989a8b..40220a7b 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -242,17 +242,14 @@ def _train_model( try: training_step = 0 - epoch_train_loss_G_collector = [] - epoch_train_loss_D_collector = [] for epoch in range(1, self.epochs + 1): self.model.train() + step_train_loss_G_collector = [] + step_train_loss_D_collector = [] for idx, data in enumerate(training_loader): training_step += 1 inputs = self._assemble_input_for_training(data) - step_train_loss_G_collector = [] - step_train_loss_D_collector = [] - if idx % self.G_steps == 0: self.G_optimizer.zero_grad() results = self.model.forward( @@ -278,9 +275,6 @@ def _train_model( mean_step_train_D_loss = np.mean(step_train_loss_D_collector) mean_step_train_G_loss = np.mean(step_train_loss_G_collector) - epoch_train_loss_D_collector.append(mean_step_train_D_loss) - epoch_train_loss_G_collector.append(mean_step_train_G_loss) - # save training loss logs into the tensorboard file for every step if in need # Note: the `training_step` is not the actual number of steps that Discriminator and Generator get # trained, the actual number should be D_steps*training_step and G_steps*training_step accordingly @@ -292,8 +286,8 @@ def _train_model( self._save_log_into_tb_file( training_step, "training", loss_results ) - mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector) - mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector) + mean_epoch_train_D_loss = np.mean(step_train_loss_D_collector) + mean_epoch_train_G_loss = np.mean(step_train_loss_G_collector) if val_loader is not None: self.model.eval() diff --git a/pypots/imputation/usgan/modules/core.py b/pypots/imputation/usgan/modules/core.py index 16504d6b..53a1cb6a 100644 --- a/pypots/imputation/usgan/modules/core.py +++ b/pypots/imputation/usgan/modules/core.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from ....utils.metrics import calc_mse from .submodules import Discriminator from ...brits.modules import _BRITS @@ -62,24 +63,23 @@ def forward( if training: forward_X = inputs["forward"]["X"] forward_missing_mask = inputs["forward"]["missing_mask"] - - inputs["discrimination"] = self.discriminator( - forward_X, forward_missing_mask - ) + imputed_data = results['imputed_data'] if training_object == "discriminator": + inputs["discrimination"] = self.discriminator(imputed_data.detach(), forward_missing_mask) l_D = F.binary_cross_entropy_with_logits( inputs["discrimination"], forward_missing_mask ) results["discrimination_loss"] = l_D else: - inputs["discrimination"] = inputs["discrimination"].detach() - l_G = F.binary_cross_entropy_with_logits( + inputs["discrimination"] = self.discriminator(imputed_data, forward_missing_mask) + l_G = -F.binary_cross_entropy_with_logits( inputs["discrimination"], - 1 - forward_missing_mask, + forward_missing_mask, weight=1 - forward_missing_mask, ) - loss_gene = l_G + self.lambda_mse * results["loss"] + reconstruction_loss = calc_mse(forward_X, results['reconstructed_data'], forward_missing_mask) + 0.1 * calc_mse(results['f_reconstructed_data'], results['b_reconstructed_data']) + loss_gene = l_G + self.lambda_mse * reconstruction_loss results["generation_loss"] = loss_gene return results