Skip to content

Commit

Permalink
Merge pull request #4 from PINTO0309/fix_irversion
Browse files Browse the repository at this point in the history
Fix to preserve domain and ir_version
  • Loading branch information
PINTO0309 authored Apr 30, 2024
2 parents e9c04e3 + ea41207 commit f94b1ac
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion snc4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from snc4onnx.onnx_network_combine import combine, main

__version__ = '1.0.12'
__version__ = '1.0.13'
18 changes: 13 additions & 5 deletions snc4onnx/onnx_network_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,15 @@ def has_duplicates(seq):
## 1. ONNX load
tmp_onnx_graphs = []
custom_domain_check_onnx_nodes = []
max_ir_version: int = 0
if len(onnx_graphs) > 0:
for onnx_graph in onnx_graphs:
domain: str = onnx_graph.domain
ir_version: int = onnx_graph.ir_version
max_ir_version = ir_version if max_ir_version < ir_version else max_ir_version
gs_graph = gs.import_onnx(onnx_graph)
gs_graph.cleanup().toposort()
tmp_onnx_graphs.append(gs.export_onnx(gs_graph))
tmp_onnx_graphs.append(gs.export_onnx(gs_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
custom_domain_check_onnx_nodes = \
custom_domain_check_onnx_nodes + \
[
Expand All @@ -243,9 +247,13 @@ def has_duplicates(seq):
]
else:
for onnx_path in input_onnx_file_paths:
gs_graph = gs.import_onnx(onnx.load(onnx_path))
onnx_graph = onnx.load(onnx_path)
domain: str = onnx_graph.domain
ir_version: int = onnx_graph.ir_version
max_ir_version = ir_version if max_ir_version < ir_version else max_ir_version
gs_graph = gs.import_onnx(onnx_graph)
gs_graph.cleanup().toposort()
tmp_onnx_graphs.append(gs.export_onnx(gs_graph))
tmp_onnx_graphs.append(gs.export_onnx(gs_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
custom_domain_check_onnx_graph = onnx.load(onnx_path)
custom_domain_check_onnx_nodes = \
custom_domain_check_onnx_nodes + \
Expand Down Expand Up @@ -436,7 +444,7 @@ def has_duplicates(seq):

# Cleaning
src_gs_model.cleanup().toposort()
combined_model = gs.export_onnx(src_gs_model)
combined_model = gs.export_onnx(src_gs_model, do_type_check=False, **{'ir_version': max_ir_version})

## Output of onnx files in the process of fusion
if output_of_onnx_file_in_the_process_of_fusion and output_onnx_file_path:
Expand Down Expand Up @@ -484,7 +492,7 @@ def has_duplicates(seq):
replaced_output_names.append(tmp_replaced_output_name)

gs_combined_model.cleanup().toposort()
combined_model = gs.export_onnx(gs_combined_model)
combined_model = gs.export_onnx(gs_combined_model, do_type_check=False, **{'ir_version': max_ir_version})

## 4. Optimize
try:
Expand Down

0 comments on commit f94b1ac

Please sign in to comment.