debug USGAN #339

Merged: 4 commits, Apr 8, 2024
Changes from all commits
9 changes: 7 additions & 2 deletions pypots/imputation/brits/modules/core.py
@@ -172,7 +172,7 @@ def impute(

         estimations = torch.cat(estimations, dim=1)
         imputed_data = masks * values + (1 - masks) * estimations
-        return imputed_data, hidden_states, reconstruction_loss
+        return imputed_data, estimations, hidden_states, reconstruction_loss

     def forward(self, inputs: dict, direction: str = "forward") -> dict:
         """Forward processing of the NN module.
@@ -190,7 +190,7 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict:
             A dictionary includes all results.

         """
-        imputed_data, hidden_state, reconstruction_loss = self.impute(inputs, direction)
+        imputed_data, estimations, hidden_state, reconstruction_loss = self.impute(inputs, direction)
         # for each iteration, reconstruction_loss increases its value for 3 times
         reconstruction_loss /= self.n_steps * 3

@@ -200,6 +200,7 @@ def forward(self, inputs: dict, direction: str = "forward") -> dict:
             ),  # single direction, has no consistency loss
             "reconstruction_loss": reconstruction_loss,
             "imputed_data": imputed_data,
+            "reconstructed_data": estimations,
             "final_hidden_state": hidden_state,
         }
         return ret_dict
@@ -304,6 +305,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         ret_b = self._reverse(self.rits_b(inputs, "backward"))

         imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2
+        reconstructed_data = (ret_f["reconstructed_data"] + ret_b["reconstructed_data"]) / 2

         results = {
             "imputed_data": imputed_data,
@@ -323,5 +325,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict:

         # `loss` is always the item for backward propagating to update the model
         results["loss"] = loss
+        results['reconstructed_data'] = reconstructed_data
+        results['f_reconstructed_data'] = ret_f['reconstructed_data']
+        results['b_reconstructed_data'] = ret_b['reconstructed_data']

         return results
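With these changes, `impute()` returns the raw `estimations` alongside the imputed series, and the BRITS forward pass exposes them as `reconstructed_data` (plus the per-direction `f_reconstructed_data` / `b_reconstructed_data`), so a downstream model can score the reconstruction against observed values instead of the already-masked `imputed_data`. The sketch below illustrates that distinction; `merge_and_score`, the toy tensors, and the `1e-12` stabilizer are illustrative only, not PyPOTS code.

```python
import torch

def merge_and_score(values, estimations, masks):
    # masks is 1 where a value was observed, 0 where it is missing.
    # imputed_data keeps observed entries and fills gaps with the estimations,
    # so an error computed on it is zero at every observed position.
    imputed_data = masks * values + (1 - masks) * estimations
    # The raw estimations, by contrast, can be scored against the observed
    # entries (a masked MSE, similar in spirit to PyPOTS' calc_mse).
    observed_mse = ((estimations - values) * masks).pow(2).sum() / (masks.sum() + 1e-12)
    return imputed_data, observed_mse

# Tiny usage example with made-up numbers:
values = torch.tensor([[1.0, 0.0, 3.0]])
masks = torch.tensor([[1.0, 0.0, 1.0]])
estimations = torch.tensor([[1.2, 2.0, 2.7]])
imputed, mse = merge_and_score(values, estimations, masks)
```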
14 changes: 4 additions & 10 deletions pypots/imputation/usgan/model.py
@@ -242,17 +242,14 @@ def _train_model(

         try:
             training_step = 0
-            epoch_train_loss_G_collector = []
-            epoch_train_loss_D_collector = []
             for epoch in range(1, self.epochs + 1):
                 self.model.train()
+                step_train_loss_G_collector = []
+                step_train_loss_D_collector = []
                 for idx, data in enumerate(training_loader):
                     training_step += 1
                     inputs = self._assemble_input_for_training(data)

-                    step_train_loss_G_collector = []
-                    step_train_loss_D_collector = []
-
                     if idx % self.G_steps == 0:
                         self.G_optimizer.zero_grad()
                         results = self.model.forward(
@@ -278,9 +275,6 @@ def _train_model(
                     mean_step_train_D_loss = np.mean(step_train_loss_D_collector)
                     mean_step_train_G_loss = np.mean(step_train_loss_G_collector)

-                    epoch_train_loss_D_collector.append(mean_step_train_D_loss)
-                    epoch_train_loss_G_collector.append(mean_step_train_G_loss)
-
                     # save training loss logs into the tensorboard file for every step if in need
                     # Note: the `training_step` is not the actual number of steps that Discriminator and Generator get
                     # trained, the actual number should be D_steps*training_step and G_steps*training_step accordingly
@@ -292,8 +286,8 @@ def _train_model(
                         self._save_log_into_tb_file(
                             training_step, "training", loss_results
                         )
-                mean_epoch_train_D_loss = np.mean(epoch_train_loss_D_collector)
-                mean_epoch_train_G_loss = np.mean(epoch_train_loss_G_collector)
+                mean_epoch_train_D_loss = np.mean(step_train_loss_D_collector)
+                mean_epoch_train_G_loss = np.mean(step_train_loss_G_collector)

                 if val_loader is not None:
                     self.model.eval()
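The training-loop change is a bookkeeping fix: the step-loss collectors are now created once per epoch rather than once per batch, so the epoch-level means cover every step of the epoch instead of only the last batch. A minimal sketch of the corrected pattern follows; the dummy loss values and loop sizes are made up, and the real loop alternates the generator/discriminator optimizer steps shown in the diff.

```python
import numpy as np

n_epochs, n_batches = 2, 5
for epoch in range(1, n_epochs + 1):
    step_train_loss_G_collector = []  # reset once per epoch, not per batch
    step_train_loss_D_collector = []
    for idx in range(n_batches):
        # stand-ins for the generator/discriminator losses of this step
        step_train_loss_G_collector.append(0.5 / (idx + 1))
        step_train_loss_D_collector.append(0.7 / (idx + 1))
    # the epoch means now average every collected step
    mean_epoch_train_G_loss = np.mean(step_train_loss_G_collector)
    mean_epoch_train_D_loss = np.mean(step_train_loss_D_collector)
    print(epoch, mean_epoch_train_G_loss, mean_epoch_train_D_loss)
```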
16 changes: 8 additions & 8 deletions pypots/imputation/usgan/modules/core.py
@@ -14,6 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from ....utils.metrics import calc_mse

 from .submodules import Discriminator
 from ...brits.modules import _BRITS
@@ -62,24 +63,23 @@ def forward(
         if training:
             forward_X = inputs["forward"]["X"]
             forward_missing_mask = inputs["forward"]["missing_mask"]

-            inputs["discrimination"] = self.discriminator(
-                forward_X, forward_missing_mask
-            )
+            imputed_data = results['imputed_data']

             if training_object == "discriminator":
+                inputs["discrimination"] = self.discriminator(imputed_data.detach(), forward_missing_mask)
                 l_D = F.binary_cross_entropy_with_logits(
                     inputs["discrimination"], forward_missing_mask
                 )
                 results["discrimination_loss"] = l_D
             else:
-                inputs["discrimination"] = inputs["discrimination"].detach()
-                l_G = F.binary_cross_entropy_with_logits(
+                inputs["discrimination"] = self.discriminator(imputed_data, forward_missing_mask)
+                l_G = -F.binary_cross_entropy_with_logits(
                     inputs["discrimination"],
-                    1 - forward_missing_mask,
+                    forward_missing_mask,
+                    weight=1 - forward_missing_mask,
                 )
-                loss_gene = l_G + self.lambda_mse * results["loss"]
+                reconstruction_loss = calc_mse(forward_X, results['reconstructed_data'], forward_missing_mask) + 0.1 * calc_mse(results['f_reconstructed_data'], results['b_reconstructed_data'])
+                loss_gene = l_G + self.lambda_mse * reconstruction_loss
                 results["generation_loss"] = loss_gene

         return results
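These edits rework the US-GAN objectives: the discriminator is scored on the detached imputed series, the generator re-runs the discriminator without detaching and tries to fool it only at missing positions (hence the negative BCE weighted by `1 - forward_missing_mask`), and the generator additionally minimizes a masked reconstruction MSE built from the new `reconstructed_data` outputs plus a small forward/backward consistency term. The sketch below lays out that loss arithmetic; `usgan_losses` and `masked_mse` are hypothetical helpers written for illustration, not the PyPOTS API.

```python
import torch
import torch.nn.functional as F

def masked_mse(pred, target, mask):
    # mean squared error restricted to positions where mask == 1
    return ((pred - target) * mask).pow(2).sum() / (mask.sum() + 1e-12)

def usgan_losses(discriminator, imputed_data, forward_X, missing_mask,
                 reconstructed, f_reconstructed, b_reconstructed, lambda_mse=1.0):
    # missing_mask is 1 at observed entries, 0 at missing ones.

    # Discriminator loss: classify which entries of the (detached) imputation
    # are real observations and which were generated.
    logits_d = discriminator(imputed_data.detach(), missing_mask)
    l_D = F.binary_cross_entropy_with_logits(logits_d, missing_mask)

    # Generator adversarial loss: push the discriminator toward mistakes, but
    # only on imputed (missing) entries, hence the (1 - missing_mask) weight.
    logits_g = discriminator(imputed_data, missing_mask)
    l_G = -F.binary_cross_entropy_with_logits(
        logits_g, missing_mask, weight=1 - missing_mask
    )

    # Reconstruction on observed values plus a 0.1-weighted consistency term
    # between the forward and backward RITS reconstructions, as in the diff.
    reconstruction_loss = masked_mse(reconstructed, forward_X, missing_mask) \
        + 0.1 * F.mse_loss(f_reconstructed, b_reconstructed)
    loss_gene = l_G + lambda_mse * reconstruction_loss
    return l_D, loss_gene
```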