train_utp_ddp.py
# Tutorial DDP: https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html
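"""Multi-GPU training of the pix2pix-style GAN on UniToPatho with DistributedDataParallel.

Run directly (python train_utp_ddp.py): mp.spawn starts config.NGPU worker processes,
one per GPU, and each worker executes main(gpu). Hyperparameters, checkpoint paths and
wandb settings are presumably defined in config.py.
"""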
import math

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb

import config
import utils
import train_utils
from dataset.pannuke import PanNuke
from discriminator_model import Discriminator
from generator_model import Generator

transform_train = transforms.Compose([
    transforms.FiveCrop(256),
    transforms.Lambda(lambda crops: torch.stack([transforms.RandomHorizontalFlip()(crop) for crop in crops])),
    transforms.Lambda(lambda crops: torch.stack([transforms.RandomVerticalFlip()(crop) for crop in crops])),
    transforms.Lambda(lambda crops: torch.stack([utils.RandomRotate90()(crop) for crop in crops])),
])
transform_test = transforms.Compose([transforms.RandomCrop(1024)])
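
# Note (assumption): FiveCrop(256) turns every training sample into 5 crops, so each train batch
# carries an extra crop dimension of size 5; flattening/handling of that dimension is presumably
# done downstream in train_utils.train_epoch.
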
WANDB_PROJECT_NAME = "unitopatho-generative"


def main(gpu):
    print(f"GPU #{gpu} started")

    # DDP setup
    world_size = config.NGPU * config.NUM_NODES
    nr = 0  # rank of the current node; only a single node is used here
    rank = nr * config.NGPU + gpu
    utils.setup_ddp(rank, world_size)
    torch.cuda.set_device(gpu)
    is_master = rank == 0
    do_wandb_log = config.LOG_WANDB and is_master  # only the master process logs to wandb
    if do_wandb_log:
        train_utils.wandb_init(config.WANDB_KEY_LOGIN, WANDB_PROJECT_NAME)

    # Load models
    num_classes = len(PanNuke.labels())
    disc = Discriminator(in_channels=3 + num_classes).cuda(gpu)
    gen = Generator(in_channels=num_classes, features=64).cuda(gpu)

    # Use SyncBatchNorm for multi-GPU training
    disc = nn.SyncBatchNorm.convert_sync_batchnorm(disc)
    gen = nn.SyncBatchNorm.convert_sync_batchnorm(gen)
    if do_wandb_log:
        print(disc)
        print(gen)

    # Wrap the models with DistributedDataParallel
    disc = nn.parallel.DistributedDataParallel(disc, device_ids=[gpu])
    gen = nn.parallel.DistributedDataParallel(gen, device_ids=[gpu])
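    # DDP keeps a full model replica in every process and averages gradients across processes
    # during backward, so each process can step its own optimizer independently.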

    # Optimizers
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(config.ADAM_BETA1, config.ADAM_BETA2))
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(config.ADAM_BETA1, config.ADAM_BETA2))

    # Losses
    bce = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    # Load checkpoints from wandb
    if config.LOAD_MODEL:
        # map_location remaps tensors saved on cuda:0 onto this process's GPU
        # (presumably forwarded to torch.load inside wandb_load_model).
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        wandb_run_path = "daviderubi/pix2pixgan/1l0hnnnn"  # The wandb run is daviderubi/pix2pixgan/upbeat-river-42
        train_utils.wandb_load_model(wandb_run_path, "disc.pth", disc, opt_disc, config.LEARNING_RATE, map_location)
        train_utils.wandb_load_model(wandb_run_path, "gen.pth", gen, opt_gen, config.LEARNING_RATE, map_location)

    # Load dataset
    train_dataset, test_dataset = train_utils.load_dataset_UTP(transform_train, transform_test)

    # DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=False, batch_size=config.BATCH_SIZE,
                                               num_workers=config.NUM_WORKERS, sampler=train_sampler)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas=world_size, rank=rank)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=config.BATCH_SIZE,
                                              num_workers=config.NUM_WORKERS, sampler=test_sampler)
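    # Note: shuffle=False because shuffling is delegated to the DistributedSampler; each loader
    # batch holds config.BATCH_SIZE samples per GPU, so one step sees
    # config.BATCH_SIZE * world_size samples globally (times 5 crops for the train split).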

    # Gradient scalers for mixed-precision (AMP) training
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
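    # The scalers are presumably used inside train_utils.train_epoch together with
    # torch.cuda.amp.autocast: forward passes run in mixed precision and the scaled
    # generator/discriminator losses are backpropagated through g_scaler / d_scaler.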

    if do_wandb_log:
        # Get some images from the test loader. Every epoch we will log the generated images for this batch on wandb.
        test_batch_im, test_batch_masks = train_utils.wandb_get_images_to_log(test_loader)
        img_masks_test = [PanNuke.get_img_mask(mask) for mask in test_batch_masks]
        wandb.log({"Reals": wandb.Image(torchvision.utils.make_grid(test_batch_im), caption="Reals"),
                   "Masks": wandb.Image(torchvision.utils.make_grid(img_masks_test), caption="Masks")})

    # Training loop
    for epoch in range(config.NUM_EPOCHS):
        train_sampler.set_epoch(epoch)  # required so DistributedSampler reshuffles differently each epoch
        g_adv_loss, g_l1_loss, d_loss = train_utils.train_epoch(disc, gen, train_loader, opt_disc, opt_gen, l1_loss,
                                                                bce, g_scaler, d_scaler, gpu)

        # Save checkpoint.
        if config.SAVE_MODEL and (epoch + 1) % 10 == 0 and is_master:
            print(f"Saving checkpoint at epoch {epoch + 1}...")
            utils.save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN, epoch=epoch + 1)
            utils.save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC, epoch=epoch + 1)
            if do_wandb_log:
                wandb.save(config.CHECKPOINT_GEN)
                wandb.save(config.CHECKPOINT_DISC)

        # Log generated images after the training epoch.
        if do_wandb_log:
            train_utils.wandb_log_epoch(gen, test_batch_masks, g_adv_loss, g_l1_loss, d_loss)

    # Save generator and discriminator models.
    if is_master:
        utils.save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN, epoch=config.NUM_EPOCHS)
        utils.save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC, epoch=config.NUM_EPOCHS)

    # Log some generated images on wandb.
    if do_wandb_log:
        train_utils.wandb_log_generated_images(gen, test_loader, batch_to_log=math.ceil(100 / config.BATCH_SIZE))
        wandb.finish()

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    print(f"Working on {config.DEVICE} device.")
    if "cuda" in str(config.DEVICE):
        for idx in range(torch.cuda.device_count()):
            print(torch.cuda.get_device_properties(idx))

    # DistributedDataParallel: spawn one training process per GPU; mp.spawn passes the
    # process index to main() as `gpu`.
    mp.spawn(main, nprocs=config.NGPU, args=())
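
# Note (assumption): for multi-node training, config.NUM_NODES would be increased and `nr` in
# main() set to each node's index; utils.setup_ddp is presumably where init_process_group and the
# rendezvous settings (e.g. MASTER_ADDR / MASTER_PORT for the default env:// init) are handled.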