Skip to content

Commit

Permalink
feat: export for unet_mha
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau authored and beniz committed Dec 2, 2022
1 parent 67bede6 commit b4c3cfd
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 33 deletions.
78 changes: 51 additions & 27 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def __init__(self, opt, rank):

if "segformer" in self.opt.G_netG:
self.onnx_opset_version = 11
elif "ittr" in self.opt.G_netG:
elif "ittr" in self.opt.G_netG or "unet_mha" in self.opt.G_netG:
self.onnx_opset_version = 12
else:
self.onnx_opset_version = 9
Expand Down Expand Up @@ -795,39 +795,63 @@ def export_networks(self, epoch):

net = getattr(self, "net" + name)

input_nc = self.opt.model_input_nc
if self.opt.model_multimodal:
input_nc += self.opt.train_mm_nz

dummy_input = torch.randn(
1,
input_nc,
self.opt.data_crop_size,
self.opt.data_crop_size,
device=self.device,
)

# onnx

if (
not "ittr" in self.opt.G_netG
and not "unet_mha" in self.opt.G_netG
and not "palette" in self.opt.model_type
):
export_path_onnx = save_path.replace(".pth", ".onnx")

torch.onnx.export(
net,
dummy_input,
export_path_onnx,
verbose=False,
opset_version=self.onnx_opset_version,
)
input_nc = self.opt.model_input_nc
if self.opt.model_multimodal:
input_nc += self.opt.train_mm_nz

dummy_input = torch.randn(
1,
input_nc,
self.opt.data_crop_size,
self.opt.data_crop_size,
device=self.device,
)

# onnx
if not "ittr" in self.opt.G_netG:
export_path_onnx = save_path.replace(".pth", ".onnx")

export_device = torch.device("cpu")
net = net.to(export_device)

torch.onnx.export(
net,
self.get_dummy_input(export_device),
export_path_onnx,
verbose=False,
opset_version=self.onnx_opset_version,
)
net.to(self.device)

# jit
if self.opt.train_export_jit and not "segformer" in self.opt.G_netG:
export_path_jit = save_path.replace(".pth", ".pt")
jit_model = torch.jit.trace(net, self.get_dummy_input())
jit_model.save(export_path_jit)

def get_dummy_input(self, device=None):
input_nc = self.opt.model_input_nc
if self.opt.model_multimodal:
input_nc += self.opt.train_mm_nz

if device is None:
device = self.device
dummy_input = torch.randn(
1,
input_nc,
self.opt.data_crop_size,
self.opt.data_crop_size,
device=device,
)

# jit
if self.opt.train_export_jit and not "segformer" in self.opt.G_netG:
export_path_jit = save_path.replace(".pth", ".pt")
jit_model = torch.jit.trace(net, dummy_input)
jit_model.save(export_path_jit)
return dummy_input

def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
Expand Down
2 changes: 1 addition & 1 deletion models/modules/unet_generator_attn/unet_generator_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True)
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

def _forward(self, x):
b, c, *spatial = x.shape
Expand Down
10 changes: 5 additions & 5 deletions scripts/export_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
help="optional model configuration, e.g /path/to/segformer_config_b0.py",
)
parser.add_argument("--img-size", default=256, type=int, help="square image size")
parser.add_argument("--cpu", action="store_true", help="whether to export for CPU")
parser.add_argument("--cuda", action="store_true", help="whether to export using gpu")
parser.add_argument("--bw", action="store_true", help="whether input/output is bw")
parser.add_argument(
"--padding-type",
Expand Down Expand Up @@ -80,14 +80,14 @@
model.eval()
model.load_state_dict(torch.load(args.model_in_file))

if not args.cpu:
if args.cuda:
model = model.cuda()

# export to ONNX via tracing
if args.cpu:
device = "cpu"
else:
if args.cuda:
device = "cuda"
else:
device = "cpu"
dummy_input = torch.randn(1, input_nc, args.img_size, args.img_size, device=device)


Expand Down

0 comments on commit b4c3cfd

Please sign in to comment.