Skip to content

Commit

Permalink
updating the validation part
Browse files Browse the repository at this point in the history
  • Loading branch information
victorcaquilpan committed Jul 25, 2024
1 parent 7b95524 commit 0986b32
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
2 changes: 2 additions & 0 deletions options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def initialize(self, parser):
parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
parser.add_argument('--wdb_disabled', action='store_true', default=False, help='Wandb logging')

# network saving and loading parameters
parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ opencv-python==3.4.8.29
pillow
wandb
protobuf==3.20.*
elasticdeform==0.5.1
elasticdeform==0.5.1
torchmetrics
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@

if __name__ == '__main__':
opt = TrainOptions().parse() # get training options
wandb.init(project="testing-maskgan", name=opt.name)
if not opt.wdb_disabled:
wandb.init(project="testing-maskgan", name=opt.name)
train_dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
dataset_size = len(train_dataset) # get the number of images in the dataset.
print('The number of training images = %d' % dataset_size)
Expand Down
5 changes: 2 additions & 3 deletions util/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import torch
import numpy as np
import wandb
from torchmetrics.image import PeakSignalNoiseRatio
import monai
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

def val_visualizations_over_batches(real_a,
real_b,
Expand Down Expand Up @@ -41,7 +40,7 @@ def validation(val_set, model, opt):
# Getting MSE
metric_mse = torch.nn.MSELoss()
# Getting SSIM
metric_ssim = monai.metrics.SSIMMetric(spatial_dims=2, reduction = 'mean')
metric_ssim = StructuralSimilarityIndexMeasure(data_range=None,reduction="elementwise_mean")
# Getting PSNR
metric_psnr = PeakSignalNoiseRatio()

Expand Down

0 comments on commit 0986b32

Please sign in to comment.