From b4c3cfde314f1aef5b59e4d4c4dfec7e8f82ac2b Mon Sep 17 00:00:00 2001
From: pnsuau
Date: Tue, 15 Nov 2022 14:04:45 +0000
Subject: [PATCH] feat: export for unet_mha

---
 models/base_model.py                          | 78 ++++++++++++-------
 .../unet_generator_attn.py                    |  2 +-
 scripts/export_onnx_model.py                  | 10 +--
 3 files changed, 57 insertions(+), 33 deletions(-)

diff --git a/models/base_model.py b/models/base_model.py
index f8cd5a3cf..8eb37167c 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -227,7 +227,7 @@ def __init__(self, opt, rank):
 
         if "segformer" in self.opt.G_netG:
             self.onnx_opset_version = 11
-        elif "ittr" in self.opt.G_netG:
+        elif "ittr" in self.opt.G_netG or "unet_mha" in self.opt.G_netG:
             self.onnx_opset_version = 12
         else:
             self.onnx_opset_version = 9
@@ -795,39 +795,63 @@ def export_networks(self, epoch):
             net = getattr(self, "net" + name)
 
-            input_nc = self.opt.model_input_nc
-            if self.opt.model_multimodal:
-                input_nc += self.opt.train_mm_nz
-
-            dummy_input = torch.randn(
-                1,
-                input_nc,
-                self.opt.data_crop_size,
-                self.opt.data_crop_size,
-                device=self.device,
-            )
-            # onnx
+
             if (
                 not "ittr" in self.opt.G_netG
                 and not "unet_mha" in self.opt.G_netG
                 and not "palette" in self.opt.model_type
             ):
-                export_path_onnx = save_path.replace(".pth", ".onnx")
-
-                torch.onnx.export(
-                    net,
-                    dummy_input,
-                    export_path_onnx,
-                    verbose=False,
-                    opset_version=self.onnx_opset_version,
-                )
+                input_nc = self.opt.model_input_nc
+                if self.opt.model_multimodal:
+                    input_nc += self.opt.train_mm_nz
+
+                dummy_input = torch.randn(
+                    1,
+                    input_nc,
+                    self.opt.data_crop_size,
+                    self.opt.data_crop_size,
+                    device=self.device,
+                )
+
+            # onnx
+            if not "ittr" in self.opt.G_netG:
+                export_path_onnx = save_path.replace(".pth", ".onnx")
+
+                export_device = torch.device("cpu")
+                net = net.to(export_device)
+
+                torch.onnx.export(
+                    net,
+                    self.get_dummy_input(export_device),
+                    export_path_onnx,
+                    verbose=False,
+                    opset_version=self.onnx_opset_version,
+                )
+                net.to(self.device)
+
+            # jit
+            if self.opt.train_export_jit and not "segformer" in self.opt.G_netG:
+                export_path_jit = save_path.replace(".pth", ".pt")
+                jit_model = torch.jit.trace(net, self.get_dummy_input())
+                jit_model.save(export_path_jit)
+
+    def get_dummy_input(self, device=None):
+        input_nc = self.opt.model_input_nc
+        if self.opt.model_multimodal:
+            input_nc += self.opt.train_mm_nz
+
+        if device is None:
+            device = self.device
+        dummy_input = torch.randn(
+            1,
+            input_nc,
+            self.opt.data_crop_size,
+            self.opt.data_crop_size,
+            device=device,
+        )
 
-            # jit
-            if self.opt.train_export_jit and not "segformer" in self.opt.G_netG:
-                export_path_jit = save_path.replace(".pth", ".pt")
-                jit_model = torch.jit.trace(net, dummy_input)
-                jit_model.save(export_path_jit)
+        return dummy_input
 
     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
diff --git a/models/modules/unet_generator_attn/unet_generator_attn.py b/models/modules/unet_generator_attn/unet_generator_attn.py
index cc3627dc0..feb1a6667 100644
--- a/models/modules/unet_generator_attn/unet_generator_attn.py
+++ b/models/modules/unet_generator_attn/unet_generator_attn.py
@@ -242,7 +242,7 @@ def __init__(
         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
     def forward(self, x):
-        return checkpoint(self._forward, (x,), self.parameters(), True)
+        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
 
     def _forward(self, x):
         b, c, *spatial = x.shape
diff --git a/scripts/export_onnx_model.py b/scripts/export_onnx_model.py
index 3a4cae6d7..c1fdbbc17 100644
--- a/scripts/export_onnx_model.py
+++ b/scripts/export_onnx_model.py
@@ -31,7 +31,7 @@
     help="optional model configuration, e.g /path/to/segformer_config_b0.py",
 )
 parser.add_argument("--img-size", default=256, type=int, help="square image size")
-parser.add_argument("--cpu", action="store_true", help="whether to export for CPU")
+parser.add_argument("--cuda", action="store_true", help="whether to export using gpu")
 parser.add_argument("--bw", action="store_true", help="whether input/output is bw")
 parser.add_argument(
     "--padding-type",
@@ -80,14 +80,14 @@
 model.eval()
 model.load_state_dict(torch.load(args.model_in_file))
 
-if not args.cpu:
+if args.cuda:
    model = model.cuda()
 
 # export to ONNX via tracing
-if args.cpu:
-    device = "cpu"
-else:
+if args.cuda:
     device = "cuda"
+else:
+    device = "cpu"
 
 dummy_input = torch.randn(1, input_nc, args.img_size, args.img_size, device=device)
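
Note (not part of the patch): a minimal sketch of how an ONNX file produced by this export path could be sanity-checked with onnxruntime. The file path, channel count and image size below are assumptions for illustration; adjust them to the actual training options (model_input_nc, data_crop_size / --img-size).

    import numpy as np
    import onnxruntime as ort

    # Hypothetical path to an exported generator; replace with your own .onnx file.
    sess = ort.InferenceSession(
        "checkpoints/latest_net_G_A.onnx", providers=["CPUExecutionProvider"]
    )
    input_name = sess.get_inputs()[0].name

    # Assumes a 3-channel model exported at 256x256.
    dummy = np.random.randn(1, 3, 256, 256).astype(np.float32)
    out = sess.run(None, {input_name: dummy})[0]
    print(out.shape)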