cleaned up checkpointing and acgan loss
Michael Fuest committed Oct 14, 2024
1 parent 88df3dc commit 87bbc81
Showing 4 changed files with 37 additions and 17 deletions.
4 changes: 2 additions & 2 deletions config/model_config.yaml
@@ -1,10 +1,10 @@
-device: 0 # 0, cpu
+device: 1 # 0, cpu
 seq_len: 96 # should not be changed for the current datasets
 input_dim: 2 # or 1 depending on user, but is dynamically set
 noise_dim: 256
 cond_emb_dim: 64
 shuffle: True
-sparse_conditioning_loss_weight: 0.8 # sparse conditioning training sample weight for loss computation [0, 1]
+sparse_conditioning_loss_weight: 0.5 # sparse conditioning training sample weight for loss computation [0, 1]
 freeze_cond_after_warmup: False # specify whether to freeze conditioning module parameters after warmup epochs
 save_cycle: 200 # specify number of epochs to save model after
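
The lowered weight halves the emphasis on sparsely conditioned training samples. A minimal runnable sketch of how this value (the λ below) enters the generator loss, mirroring the weighting in generator/gan/acgan.py further down; all tensor values are illustrative stand-ins:

```python
import torch

# Illustrative stand-ins; values and shapes are not from the repo.
rare_mask = torch.tensor([True, False, False, True])
g_loss_rare = torch.tensor(0.9)      # loss on rare-conditioned samples
g_loss_non_rare = torch.tensor(0.4)  # loss on the remaining samples

_lambda = 0.5  # sparse_conditioning_loss_weight from model_config.yaml
N_r = rare_mask.sum().item()
N_nr = torch.logical_not(rare_mask).sum().item()
N = N_r + N_nr

# lambda = 0.5 reduces to plain frequency weighting; values above 0.5
# up-weight the rare-conditioned samples in the combined loss.
g_loss = _lambda * (N_r / N) * g_loss_rare + (1 - _lambda) * (N_nr / N) * g_loss_non_rare
print(g_loss)  # tensor(0.3250)
```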

10 changes: 7 additions & 3 deletions generator/diffcharge/diffusion.py
@@ -17,7 +17,7 @@
 import copy
 import math
 import os
-from functools import partial
+from datetime import datetime

 import torch
 import torch.nn as nn
@@ -286,6 +286,7 @@ def train_model(self, train_dataset):
         os.makedirs(self.opt.results_folder, exist_ok=True)

         for epoch in tqdm(range(self.opt.n_epochs), desc="Training"):
+            self.train_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
             self.current_epoch = epoch + 1
             batch_loss = []
             for i, (time_series_batch, conditioning_vars_batch) in enumerate(
@@ -368,11 +369,14 @@ def train_model(self, train_dataset):
             self.lr_scheduler.step(epoch_mean_loss)

             if (epoch + 1) % self.opt.save_cycle == 0:
+                os.mkdir(os.path.join(self.opt.results_folder, self.train_timestamp))
+
                 checkpoint_path = os.path.join(
-                    self.opt.results_folder, f"ddpm_checkpoint_epoch_{epoch + 1}.pt"
+                    os.path.join(self.opt.results_folder, self.train_timestamp),
+                    f"diffcharge_checkpoint_{epoch + 1}.pt",
                 )

                 self.save(checkpoint_path, self.current_epoch)
+                print(f"Saved checkpoint at {checkpoint_path}.")

         print("Training complete")
         self.writer.close()
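
Checkpoints now land in a per-run timestamped subdirectory. In this trainer the timestamp is refreshed at the top of every epoch, so each save cycle creates a distinct directory; in the two trainers below it is set once per run, and os.mkdir raises FileExistsError when the same directory is created a second time. A hypothetical defensive variant (an assumption, not the repo's code):

```python
import os

def checkpoint_dir(results_folder: str, train_timestamp: str) -> str:
    """Create (if needed) and return the per-run checkpoint directory.

    os.makedirs with exist_ok=True tolerates repeated save cycles that
    reuse the same timestamped directory, unlike a bare os.mkdir.
    """
    path = os.path.join(results_folder, train_timestamp)
    os.makedirs(path, exist_ok=True)
    return path

# e.g. checkpoint_dir("results", "2024-10-14_09-30-00")
# -> "results/2024-10-14_09-30-00", created on first call, reused after
```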
8 changes: 7 additions & 1 deletion generator/diffusion_ts/gaussian_diffusion.py
@@ -15,6 +15,7 @@
 import copy
 import math
 import os
+from datetime import datetime
 from functools import partial

 import torch
@@ -384,6 +385,7 @@ def forward(self, x, conditioning_vars=None, **kwargs):
         )

     def train_model(self, train_dataset):
+        self.train_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
         self.train()
         self.to(self.device)

@@ -504,10 +506,14 @@ def train_model(self, train_dataset):
             self.scheduler.step(total_loss)

             if (epoch + 1) % self.opt.save_cycle == 0:
+                os.mkdir(os.path.join(self.opt.results_folder, self.train_timestamp))
+
                 checkpoint_path = os.path.join(
-                    self.opt.results_folder, f"checkpoint-{epoch + 1}.pt"
+                    os.path.join(self.opt.results_folder, self.train_timestamp),
+                    f"diffusion_ts_checkpoint_{epoch + 1}.pt",
                 )
                 self.save(checkpoint_path, self.current_epoch)
+
+                print(f"Saved checkpoint at {checkpoint_path}.")

         print("Training complete")
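
With save_cycle: 200 from model_config.yaml, the intended layout groups all of a run's checkpoints under one timestamped directory (timestamp shown is illustrative; saves after the first would need the exist_ok-style directory creation sketched above, since os.mkdir fails on an existing directory):

```text
results_folder/
└── 2024-10-14_09-30-00/
    ├── diffusion_ts_checkpoint_200.pt
    └── diffusion_ts_checkpoint_400.pt
```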
32 changes: 21 additions & 11 deletions generator/gan/acgan.py
@@ -14,6 +14,7 @@
"""

import os
from datetime import datetime

import torch
import torch.nn as nn
@@ -195,6 +196,7 @@ def __init__(self, opt):
         )

     def train_model(self, dataset):
+        self.train_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
         batch_size = self.opt.batch_size
         num_epoch = self.opt.n_epochs
         train_loader = prepare_dataloader(dataset, batch_size)
@@ -283,7 +285,6 @@ def train_model(self, dataset):
                     dataset, current_batch_size, random=True
                 )
                 generated_time_series = self.generator(noise, gen_categorical_vars)
-
                 validity, aux_outputs = self.discriminator(generated_time_series)

                 g_loss_rare = self.adversarial_loss(
@@ -297,6 +298,18 @@
                     * torch.logical_not(rare_mask)
                     * soft_one,
                 )
+
+                if self.opt.include_auxiliary_losses:
+                    for var_name in self.categorical_dims.keys():
+                        labels = gen_categorical_vars[var_name]
+                        g_loss_rare += self.auxiliary_loss(
+                            aux_outputs[var_name] * rare_mask, labels * rare_mask
+                        )
+                        g_loss_non_rare += self.auxiliary_loss(
+                            aux_outputs[var_name] * (torch.logical_not(rare_mask)),
+                            labels * (torch.logical_not(rare_mask)),
+                        )
+
                 _lambda = self.sparse_conditioning_loss_weight
                 N_r = rare_mask.sum().item()
                 N_nr = (torch.logical_not(rare_mask)).sum().item()
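
The auxiliary classification losses now sit ahead of the λ-weighting, split by the rare mask, so they are re-weighted together with the adversarial terms; the old unmasked auxiliary block is removed in the next hunk. A self-contained sketch of the masked terms, with assumed shapes; the unsqueeze needed for broadcasting is an assumption about how the mask lines up with the logits:

```python
import torch
import torch.nn as nn

# Assumed shapes: 4 generated samples, 3 classes for one categorical
# conditioning variable.
auxiliary_loss = nn.CrossEntropyLoss()
aux_logits = torch.randn(4, 3)          # stands in for aux_outputs[var_name]
labels = torch.tensor([2, 1, 0, 2])     # stands in for gen_categorical_vars[var_name]
rare_mask = torch.tensor([1, 0, 0, 1])  # 1 where the conditioning combination is rare

# Masking as in the hunk above: logits and integer labels are both
# multiplied by the mask, so masked-out rows become zero logits with a
# class-0 target rather than being dropped from the batch.
g_loss_rare = auxiliary_loss(aux_logits * rare_mask.unsqueeze(1), labels * rare_mask)
g_loss_non_rare = auxiliary_loss(
    aux_logits * torch.logical_not(rare_mask).unsqueeze(1),
    labels * torch.logical_not(rare_mask),
)
```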
@@ -306,18 +319,12 @@
                     + (1 - _lambda) * (N_nr / N) * g_loss_non_rare
                 )

-                if self.opt.include_auxiliary_losses:
-                    for var_name in self.categorical_dims.keys():
-                        labels = gen_categorical_vars[var_name]
-                        g_loss += self.auxiliary_loss(aux_outputs[var_name], labels)
-
                 g_loss.backward()
                 self.optimizer_G.step()

                 # -------------------
-                # TensorBoard Logging
+                # TensorBoard Loss Logging
                 # -------------------
-                global_step = epoch * len(train_loader) + batch_index
+                # global_step = epoch * len(train_loader) + batch_index

                 # Log overall losses for both generator and discriminator
                 # self.writer.add_scalars('Losses', {'Discriminator': d_loss.item(), 'Generator': g_loss.item()}, global_step)
@@ -340,11 +347,14 @@
                 )

             if (epoch + 1) % self.opt.save_cycle == 0:
+                os.mkdir(os.path.join(self.opt.results_folder, self.train_timestamp))
+
                 checkpoint_path = os.path.join(
-                    self.opt.results_folder, f"acgan_checkpoint_epoch_{epoch + 1}.pt"
+                    os.path.join(self.opt.results_folder, self.train_timestamp),
+                    f"acgan_checkpoint_{epoch + 1}.pt",
                 )

                 self.save(checkpoint_path, self.current_epoch)
+                print(f"Saved checkpoint at {checkpoint_path}.")

         self.writer.close()

