-
Notifications
You must be signed in to change notification settings - Fork 72
/
summaries.py
23 lines (20 loc) · 1.16 KB
/
summaries.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os
import torch
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
from dataloaders.dataloader_utils import decode_seg_map_sequence
class TensorboardSummary(object):
    """Helper around a tensorboardX ``SummaryWriter`` for segmentation runs.

    Creates a writer rooted at ``directory`` and logs grids of input images,
    predicted label maps and ground-truth label maps.
    """

    def __init__(self, directory):
        """Remember ``directory``, the log directory used by ``create_summary``."""
        self.directory = directory

    def create_summary(self):
        """Create and return a ``SummaryWriter`` logging into ``self.directory``."""
        # os.fspath is what the former single-argument os.path.join reduced to:
        # strings pass through unchanged, path-like objects become strings.
        return SummaryWriter(log_dir=os.fspath(self.directory))

    def visualize_image(self, writer, dataset, image, target, output, global_step,
                        num_images=3):
        """Log image, predicted-label and ground-truth-label grids to ``writer``.

        Only the first ``num_images`` samples of each batch are logged, with
        ``num_images`` tiles per grid row.

        Args:
            writer: object exposing ``add_image(tag, img, step)``, e.g. the
                writer returned by ``create_summary``.
            dataset: forwarded unchanged to ``decode_seg_map_sequence``
                (presumably selects the label colour palette — confirm).
            image: batch of input images; min-max normalised for display.
            target: batch of ground-truth label maps; dim 1 is squeezed when
                it has size 1 (assumes (N, 1, H, W) or (N, H, W) — TODO confirm).
            output: batch of network outputs; the predicted label is the
                argmax over dim 1 (assumes (N, C, H, W) class scores — TODO
                confirm).
            global_step: step value recorded with every logged image.
            num_images: number of samples to log; defaults to the previously
                hard-coded 3.
        """
        # detach() replaces the legacy `.data`; clone() keeps make_grid's
        # normalisation from ever touching the caller's tensor.
        image_grid = make_grid(image[:num_images].detach().clone().cpu(),
                               num_images, normalize=True)
        writer.add_image('Image', image_grid, global_step)

        predicted = output[:num_images].argmax(dim=1)
        writer.add_image('Predicted label',
                         self._label_grid(predicted, dataset, num_images),
                         global_step)

        groundtruth = torch.squeeze(target[:num_images], 1)
        writer.add_image('Groundtruth label',
                         self._label_grid(groundtruth, dataset, num_images),
                         global_step)

    @staticmethod
    def _label_grid(label_maps, dataset, nrow):
        """Colour-decode a batch of label maps and tile them into one grid."""
        decoded = decode_seg_map_sequence(label_maps.detach().cpu().numpy(),
                                          dataset=dataset)
        # The former `range=(0, 255)` argument was ignored because
        # normalize=False, and newer torchvision releases reject `range`
        # entirely (it was renamed `value_range`), so it is omitted here.
        return make_grid(decoded, nrow, normalize=False)