import os

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.nn.parallel import DistributedDataParallel as DDP

from networks.discriminator import Discriminator
from networks.generator import Generator
from vgg19 import VGGLoss


def requires_grad(net, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``net``."""
    for p in net.parameters():
        p.requires_grad = flag


def update_v(state_dict, k, v):
    """Merge checkpoint tensor ``v`` into ``state_dict[k]``, tolerating channel mismatches.

    Cases are tried in order:

    * identical shapes: ``v`` replaces the entry;
    * the model tensor has size 4 along dim 0: ``v`` fills its first 3 slices;
    * the model tensor has size 4 along dim 1: ``v`` fills ``[:, :3]``;
    * ``v`` has size 3 along dim 1: ``v`` fills ``[:, :3]`` of the model tensor.

    The partial cases write into ``state_dict[k]`` in place. Any other mismatch is
    reported with ``print`` and the entry is left as it was.

    Each ``shape[i]`` access is guarded by a rank check, so mismatched 0-D/1-D
    tensors (e.g. biases of a different width) reach the ``print`` fallback
    instead of raising ``IndexError``.
    """
    target = state_dict[k]
    if target.shape == v.shape:
        state_dict[k] = v
    elif target.dim() >= 1 and target.shape[0] == 4:
        target[:3] = v
    elif target.dim() >= 2 and target.shape[1] == 4:
        target[:, :3] = v
    elif target.dim() >= 2 and v.dim() >= 2 and v.shape[1] == 3:
        target[:, :3] = v
    else:
        print(k, target.shape, v.shape)


class Trainer(nn.Module):
    """Own a DDP-wrapped ``Generator`` together with its optimizer, LR schedule and losses."""

    # Absolute target-gradient magnitude above which a pixel counts as an "edge"
    # in ``calc_depth_loss``. Pixels at or below it feed the smoothness term.
    edge_threshold = 0.1

    def __init__(self, args, device, rank):
        """Build the generator, optimizer, scheduler and loss modules.

        Args:
            args: configuration object. Attributes read here: ``batch_size``,
                ``size``, ``latent_dim_style``, ``latent_dim_motion``,
                ``channel_multiplier``, ``in_channels``,
                ``latent_dim_depth_motion``, ``distilling``, ``lr``, ``lr_freq``,
                ``lambda_loss_l1``, ``lambda_loss_sm``, ``lambda_loss_gr`` and
                ``lambda_loss_sp``.
            device: device the generator and the fixed 3x3 kernels are moved to.
            rank: local device id for DDP; the VGG loss module is also moved here.
        """
        super(Trainer, self).__init__()

        self.args = args
        self.batch_size = args.batch_size

        self.gen = Generator(args.size, args.latent_dim_style, args.latent_dim_motion,
                             args.channel_multiplier, args.in_channels,
                             args.latent_dim_depth_motion,
                             distilling=args.distilling).to(device)

        # distributed computing
        self.gen = DDP(self.gen, device_ids=[rank], find_unused_parameters=True)

        g_reg_ratio = 1
        # Only parameters that currently require grad are optimized.
        self.g_optim = optim.Adam(
            filter(lambda p: p.requires_grad, self.gen.parameters()),
            lr=args.lr * g_reg_ratio,
            betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio)
        )
        # Stepped once per ``gen_update`` call, so ``lr_freq`` is counted in
        # generator updates rather than epochs.
        self.g_scheduler = torch.optim.lr_scheduler.StepLR(self.g_optim, step_size=args.lr_freq, gamma=0.2)

        self.criterion_vgg = VGGLoss().to(rank)

        self.lambda_loss_l1 = args.lambda_loss_l1
        self.lambda_loss_sm = args.lambda_loss_sm
        # NOTE(review): stored but not used by the losses computed below.
        self.lambda_loss_gr = args.lambda_loss_gr
        self.lambda_loss_sp = args.lambda_loss_sp

        # Fixed 1-in/1-out 3x3 kernels: central differences along x and y, and a
        # 4-neighbour Laplacian used by the smoothness term.
        self.gradient_x_weight = torch.Tensor([[0., 0., 0.],
                                               [1., 0., -1.],
                                               [0., 0., 0.]]).view(1, 1, 3, 3).to(device)
        self.gradient_y_weight = torch.Tensor([[0., 1., 0.],
                                               [0., 0., 0.],
                                               [0., -1., 0.]]).view(1, 1, 3, 3).to(device)
        self.smooth_loss_weight = torch.Tensor([[0., -1., 0.],
                                                [-1., 4., -1.],
                                                [0., -1., 0.]]).view(1, 1, 3, 3).to(device)
        # Added under the square root in the Charbonnier loss so it stays differentiable at 0.
        self.eps = 1e-6

    def g_nonsaturating_loss(self, fake_pred):
        """Non-saturating generator GAN loss: ``mean(softplus(-fake_pred))``."""
        return F.softplus(-fake_pred).mean()

    def calc_charbonnier_loss(self, X, Y):
        """Return ``mean(sqrt((X - Y)**2 + eps))``, a smooth L1-like distance."""
        diff = X - Y
        error = torch.sqrt(diff * diff + self.eps)
        return torch.mean(error)

    def calc_depth_loss(self, img_recon, img_target):
        """Compute the single-channel reconstruction losses.

        The inputs are convolved with 1-in/1-out kernels, so they are assumed to be
        ``(N, 1, H, W)``; ``gen_update`` passes channel 3 with ``unsqueeze(1)``.

        Returns:
            dict with
            ``"l1_loss"``: weighted L1 distance;
            ``"smooth_loss"``: weighted Charbonnier distance between Laplacians,
            restricted to pixels where neither target gradient exceeds
            ``edge_threshold``;
            ``"structure_preserve_loss"``: weighted Charbonnier distance between
            5x5 max-pooled gradient magnitudes, restricted to target edge pixels.
        """
        l1_loss = F.l1_loss(img_recon, img_target) * self.lambda_loss_l1

        x_grad_recon = F.conv2d(img_recon, weight=self.gradient_x_weight, padding=1)
        y_grad_recon = F.conv2d(img_recon, weight=self.gradient_y_weight, padding=1)
        x_grad_target = F.conv2d(img_target, weight=self.gradient_x_weight, padding=1)
        y_grad_target = F.conv2d(img_target, weight=self.gradient_y_weight, padding=1)

        # Reconstruction-based Pairwise Depth Dataset for Depth Image Enhancement Using CNN
        # Edge masks come from the target only, so both images are compared on the same pixels.
        x_mask = (x_grad_target.abs() > self.edge_threshold).float()
        y_mask = (y_grad_target.abs() > self.edge_threshold).float()

        def pooled_magnitude(grad):
            # Stride-1 5x5 max-pool keeps the spatial size (padding=2) and
            # tolerates small misalignments of edges.
            return F.max_pool2d(grad.abs(), kernel_size=5, padding=2, stride=1)

        xl = self.calc_charbonnier_loss(pooled_magnitude(x_grad_recon * x_mask),
                                        pooled_magnitude(x_grad_target * x_mask))
        yl = self.calc_charbonnier_loss(pooled_magnitude(y_grad_recon * y_mask),
                                        pooled_magnitude(y_grad_target * y_mask))
        structure_preserve_loss = (xl + yl) * self.lambda_loss_sp

        # Complement of the edge masks: 1 where neither |grad| exceeds the threshold.
        smooth_mask = (1.0 - x_mask) * (1.0 - y_mask)
        lap_recon = F.conv2d(img_recon, weight=self.smooth_loss_weight, padding=1) * smooth_mask
        lap_target = F.conv2d(img_target, weight=self.smooth_loss_weight, padding=1) * smooth_mask
        smooth_loss = self.calc_charbonnier_loss(lap_recon, lap_target) * self.lambda_loss_sm

        depth_loss = {"l1_loss": l1_loss,
                      "smooth_loss": smooth_loss,
                      "structure_preserve_loss": structure_preserve_loss}
        return depth_loss

    def gen_update(self, img_source, img_target, distilling=False):
        """Run one generator optimization step and return ``(losses, reconstruction)``.

        The summed loss is back-propagated, then the optimizer and the LR scheduler
        each take one step.

        In distilling mode, ``self.gen`` returns a pair: a reference output (its
        first 3 channels serve as the target) and a dict with ``'out_inpaint'``,
        ``'out_warp'`` and ``'mask'``. The student outputs are trained with
        L1 + VGG losses, the warp output under ``mask``.

        Otherwise, a target with more than 3 channels gets a VGG loss on channels
        0-2 and ``calc_depth_loss`` on channel 3. A target with at most 3 channels
        gets L1 + VGG on the whole tensor.
        """
        self.gen.train()
        self.gen.zero_grad()

        if distilling:
            img_target_recon, student_recon_result = self.gen(img_source, img_target)
            student_img_target_recon = student_recon_result['out_inpaint']
            img_target_recon = img_target_recon[:, :3, :, :]
            vgg_loss = self.criterion_vgg(student_img_target_recon, img_target_recon).mean()
            l1_loss = F.l1_loss(student_img_target_recon, img_target_recon) * self.lambda_loss_l1

            student_img_target_warp = student_recon_result['out_warp']
            mask = student_recon_result['mask']
            masked_warp = student_img_target_warp * mask
            masked_target = img_target_recon * mask
            vgg_loss_warp = self.criterion_vgg(masked_warp, masked_target).mean()
            l1_loss_warp = F.l1_loss(masked_warp, masked_target) * self.lambda_loss_l1

            losses = {'l1_loss': l1_loss, 'vgg_loss': vgg_loss,
                      'l1_loss_warp': l1_loss_warp, 'vgg_loss_warp': vgg_loss_warp}
            img_target_recon = student_img_target_recon
        else:
            img_target_recon = self.gen(img_source, img_target)
            if img_target.shape[1] > 3:
                vgg_loss = self.criterion_vgg(img_target_recon[:, :3, :, :], img_target[:, :3, :, :]).mean()
                losses = self.calc_depth_loss(img_target_recon[:, 3, :, :].unsqueeze(1),
                                              img_target[:, 3, :, :].unsqueeze(1))
                losses['vgg_loss'] = vgg_loss
            else:
                vgg_loss = self.criterion_vgg(img_target_recon, img_target).mean()
                l1_loss = F.l1_loss(img_target_recon, img_target) * self.lambda_loss_l1
                losses = {'l1_loss': l1_loss, 'vgg_loss': vgg_loss}

        g_loss = sum(losses.values())
        g_loss.backward()
        self.g_optim.step()
        self.g_scheduler.step()

        return losses, img_target_recon

    def sample(self, img_source, img_target, distilling=False):
        """Run the generator without gradients and return ``(img_recon, img_source_ref)``.

        ``img_source_ref`` is the output of ``self.gen(img_source, None)``; in
        distilling mode it is the first element of that call's output, cut to
        3 channels.

        The generator is left in eval mode; ``gen_update`` switches it back to
        train mode.
        """
        with torch.no_grad():
            self.gen.eval()
            if distilling:
                # NOTE(review): this returns the generator's second output as-is.
                # ``gen_update`` treats that output as a dict and extracts
                # ``'out_inpaint'`` -- confirm callers expect the whole result here.
                _, student_img_recon = self.gen(img_source, img_target)
                img_recon = student_img_recon
                img_source_ref, _ = self.gen(img_source, None)
                img_source_ref = img_source_ref[:, :3, :, :]
            else:
                img_recon = self.gen(img_source, img_target)
                img_source_ref = self.gen(img_source, None)

        return img_recon, img_source_ref

    def resume(self, resume_ckpt):
        """Load generator weights from ``resume_ckpt`` and return the iteration to resume at.

        The iteration is parsed from the file name (``<iter>.pt``, the pattern
        ``save`` writes). Each checkpoint entry is merged into the current weights
        via ``update_v``; keys the model lacks are printed and skipped. Only
        ``ckpt["gen"]`` is restored -- the saved optimizer state is not loaded.
        """
        print("load model:", resume_ckpt)
        cpu = torch.device('cpu')
        try:
            # Checkpoints from ``save`` also pickle ``self.args`` (a non-tensor
            # object), which torch >= 2.6 rejects under its ``weights_only=True``
            # default.
            # SECURITY: full unpickling can execute arbitrary code -- only resume
            # from trusted checkpoint files.
            ckpt = torch.load(resume_ckpt, map_location=cpu, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the ``weights_only`` keyword.
            ckpt = torch.load(resume_ckpt, map_location=cpu)
        ckpt_name = os.path.basename(resume_ckpt)
        start_iter = int(os.path.splitext(ckpt_name)[0])

        # Shallow copy: tensors are shared with the live model, and ``update_v``
        # may write into them in place before ``load_state_dict`` runs.
        model_weights = self.gen.module.state_dict().copy()
        for k, v in ckpt["gen"].items():
            if k in model_weights:
                update_v(model_weights, k, v)
            else:
                print(k, v.shape)
        self.gen.module.load_state_dict(model_weights, strict=False)

        return start_iter

    def save(self, idx, checkpoint_path):
        """Write generator weights, optimizer state and ``args`` to ``<checkpoint_path>/<idx:06d>.pt``."""
        torch.save(
            {
                "gen": self.gen.module.state_dict(),
                "g_optim": self.g_optim.state_dict(),
                "args": self.args
            },
            f"{checkpoint_path}/{str(idx).zfill(6)}.pt"
        )