Skip to content

Commit

Permalink
fix: only onnx export fort segformer G
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau authored and pnsuau committed May 31, 2022
1 parent 2aecf46 commit 543ce28
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ def __init__(self, opt, rank):

self.margin = self.opt.data_online_context_pixels * 2

if "segformer" in self.opt.G_netG:
self.onnx_opset_version = 11
else:
self.onnx_opset_version = 9

@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Expand Down Expand Up @@ -481,22 +486,20 @@ def export_networks(self, epoch):

# onnx
export_path_onnx = save_path.replace(".pth", ".onnx")
if "segformer" in name:
opset_version = 11
else:
opset_version = 9

torch.onnx.export(
net,
dummy_input,
export_path_onnx,
verbose=False,
opset_version=opset_version,
opset_version=self.onnx_opset_version,
)

# jit
export_path_jit = save_path.replace(".pth", ".pt")
jit_model = torch.jit.trace(net, dummy_input)
jit_model.save(export_path_jit)
if not "segformer" in self.opt.G_netG:
export_path_jit = save_path.replace(".pth", ".pt")
jit_model = torch.jit.trace(net, dummy_input)
jit_model.save(export_path_jit)

def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
Expand Down

0 comments on commit 543ce28

Please sign in to comment.