-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds function to analyse and plot a profile from onnxruntime (#71)
* add function to plot a profile * add command line * improves graph * add shape * plot * rename into with_shape
- Loading branch information
Showing
11 changed files
with
887 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,3 +16,4 @@ Graphs | |
tools_graph | ||
tools_transformer | ||
tools_manipulations | ||
tools_other |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
|
||
=========== | ||
Other tools | ||
=========== | ||
|
||
Profiling | ||
========= | ||
|
||
js_profile_to_dataframe | ||
+++++++++++++++++++++++ | ||
|
||
.. autofunction:: onnx_extended.tools.js_profile.js_profile_to_dataframe | ||
|
||
plot_ort_profile | ||
++++++++++++++++ | ||
|
||
.. autofunction:: onnx_extended.tools.js_profile.plot_ort_profile |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
import os | ||
import unittest | ||
import numpy as np | ||
from onnx import TensorProto | ||
from onnx.checker import check_model | ||
from onnx.helper import ( | ||
make_model, | ||
make_graph, | ||
make_node, | ||
make_opsetid, | ||
make_tensor_value_info, | ||
) | ||
from onnx.numpy_helper import from_array | ||
from onnxruntime import InferenceSession, SessionOptions | ||
import matplotlib.pyplot as plt | ||
from onnx_extended.ext_test_case import ExtTestCase, ignore_warnings | ||
from onnx_extended.tools.js_profile import js_profile_to_dataframe, plot_ort_profile | ||
|
||
|
||
class TestJsProfile(ExtTestCase): | ||
def _get_model(self): | ||
model_def0 = make_model( | ||
make_graph( | ||
[ | ||
make_node("Add", ["X", "init1"], ["X1"]), | ||
make_node("Abs", ["X"], ["X2"]), | ||
make_node("Add", ["X", "init3"], ["inter"]), | ||
make_node("Mul", ["X1", "inter"], ["Xm"]), | ||
make_node("Sub", ["X2", "Xm"], ["final"]), | ||
], | ||
"test", | ||
[make_tensor_value_info("X", TensorProto.FLOAT, [None])], | ||
[make_tensor_value_info("final", TensorProto.FLOAT, [None])], | ||
[ | ||
from_array(np.array([1], dtype=np.float32), name="init1"), | ||
from_array(np.array([3], dtype=np.float32), name="init3"), | ||
], | ||
), | ||
opset_imports=[make_opsetid("", 18)], | ||
) | ||
check_model(model_def0) | ||
return model_def0 | ||
|
||
def test_js_profile_to_dataframe(self): | ||
sess_options = SessionOptions() | ||
sess_options.enable_profiling = True | ||
sess = InferenceSession( | ||
self._get_model().SerializeToString(), | ||
sess_options, | ||
providers=["CPUExecutionProvider"], | ||
) | ||
for _ in range(11): | ||
sess.run(None, dict(X=np.arange(10).astype(np.float32))) | ||
prof = sess.end_profiling() | ||
|
||
df = js_profile_to_dataframe(prof, first_it_out=True) | ||
self.assertEqual(df.shape, (189, 17)) | ||
self.assertEqual( | ||
set(df.columns), | ||
set( | ||
[ | ||
"cat", | ||
"pid", | ||
"tid", | ||
"dur", | ||
"ts", | ||
"ph", | ||
"name", | ||
"args_op_name", | ||
"op_name", | ||
"args_thread_scheduling_stats", | ||
"args_output_size", | ||
"args_parameter_size", | ||
"args_activation_size", | ||
"args_node_index", | ||
"args_provider", | ||
"event_name", | ||
"iteration", | ||
] | ||
), | ||
) | ||
|
||
df = js_profile_to_dataframe(prof, agg=True) | ||
self.assertEqual(df.shape, (17, 1)) | ||
self.assertEqual(list(df.columns), ["dur"]) | ||
|
||
df = js_profile_to_dataframe(prof, agg_op_name=True) | ||
self.assertEqual(df.shape, (189, 17)) | ||
self.assertEqual( | ||
set(df.columns), | ||
set( | ||
[ | ||
"cat", | ||
"pid", | ||
"tid", | ||
"dur", | ||
"ts", | ||
"ph", | ||
"name", | ||
"args_op_name", | ||
"op_name", | ||
"args_thread_scheduling_stats", | ||
"args_output_size", | ||
"args_parameter_size", | ||
"args_activation_size", | ||
"args_node_index", | ||
"args_provider", | ||
"event_name", | ||
"iteration", | ||
] | ||
), | ||
) | ||
|
||
os.remove(prof) | ||
|
||
@ignore_warnings(UserWarning) | ||
def test_plot_profile_2(self): | ||
sess_options = SessionOptions() | ||
sess_options.enable_profiling = True | ||
sess = InferenceSession( | ||
self._get_model().SerializeToString(), | ||
sess_options, | ||
providers=["CPUExecutionProvider"], | ||
) | ||
for _ in range(11): | ||
sess.run(None, dict(X=np.arange(10).astype(np.float32))) | ||
prof = sess.end_profiling() | ||
|
||
df = js_profile_to_dataframe(prof, first_it_out=True) | ||
|
||
fig, ax = plt.subplots(1, 2, figsize=(10, 5)) | ||
plot_ort_profile(df, ax[0], ax[1], "test_title") | ||
# fig.savefig("graph1.png") | ||
self.assertNotEmpty(fig) | ||
|
||
os.remove(prof) | ||
|
||
@ignore_warnings(UserWarning) | ||
def test_plot_profile_2_shape(self): | ||
sess_options = SessionOptions() | ||
sess_options.enable_profiling = True | ||
sess = InferenceSession( | ||
self._get_model().SerializeToString(), | ||
sess_options, | ||
providers=["CPUExecutionProvider"], | ||
) | ||
for _ in range(11): | ||
sess.run(None, dict(X=np.arange(10).astype(np.float32))) | ||
prof = sess.end_profiling() | ||
|
||
df = js_profile_to_dataframe(prof, first_it_out=True, with_shape=True) | ||
|
||
fig, ax = plt.subplots(1, 2, figsize=(10, 5)) | ||
plot_ort_profile(df, ax[0], ax[1], "test_title") | ||
# fig.savefig("graph1.png") | ||
self.assertNotEmpty(fig) | ||
|
||
os.remove(prof) | ||
|
||
@ignore_warnings(UserWarning) | ||
def test_plot_profile_agg(self): | ||
sess_options = SessionOptions() | ||
sess_options.enable_profiling = True | ||
sess = InferenceSession( | ||
self._get_model().SerializeToString(), | ||
sess_options, | ||
providers=["CPUExecutionProvider"], | ||
) | ||
for _ in range(11): | ||
sess.run(None, dict(X=np.arange(10).astype(np.float32))) | ||
prof = sess.end_profiling() | ||
|
||
df = js_profile_to_dataframe(prof, first_it_out=True, agg=True) | ||
|
||
fig, ax = plt.subplots(1, 1, figsize=(10, 5)) | ||
plot_ort_profile(df, ax, title="test_title") | ||
fig.tight_layout() | ||
# fig.savefig("graph2.png") | ||
self.assertNotEmpty(fig) | ||
|
||
os.remove(prof) | ||
|
||
def _get_model_domain(self): | ||
model_def0 = make_model( | ||
make_graph( | ||
[ | ||
make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]), | ||
make_node( | ||
"CustomGemmFloat", | ||
["X", "Xt"], | ||
["final"], | ||
domain="onnx_extented.ortops.tutorial.cpu", | ||
), | ||
], | ||
"test", | ||
[make_tensor_value_info("X", TensorProto.FLOAT, [None, None])], | ||
[make_tensor_value_info("final", TensorProto.FLOAT, [None, None])], | ||
), | ||
opset_imports=[ | ||
make_opsetid("", 18), | ||
make_opsetid("onnx_extented.ortops.tutorial.cpu", 1), | ||
], | ||
) | ||
check_model(model_def0) | ||
return model_def0 | ||
|
||
@ignore_warnings(UserWarning) | ||
def test_plot_domain_agg(self): | ||
from onnx_extended.ortops.tutorial.cpu import get_ort_ext_libs | ||
|
||
sess_options = SessionOptions() | ||
sess_options.enable_profiling = True | ||
sess_options.register_custom_ops_library(get_ort_ext_libs()[0]) | ||
sess = InferenceSession( | ||
self._get_model_domain().SerializeToString(), | ||
sess_options, | ||
providers=["CPUExecutionProvider"], | ||
) | ||
for _ in range(11): | ||
sess.run(None, dict(X=np.arange(16).astype(np.float32).reshape((-1, 4)))) | ||
prof = sess.end_profiling() | ||
|
||
df = js_profile_to_dataframe(prof, first_it_out=True) | ||
|
||
fig, ax = plt.subplots(1, 1, figsize=(10, 5)) | ||
plot_ort_profile(df, ax, title="test_title") | ||
fig.tight_layout() | ||
# fig.savefig("graph3.png") | ||
self.assertNotEmpty(fig) | ||
|
||
os.remove(prof) | ||
|
||
|
||
if __name__ == "__main__": | ||
import logging | ||
|
||
for name in [ | ||
"matplotlib.font_manager", | ||
"PIL.PngImagePlugin", | ||
"matplotlib", | ||
"matplotlib.pyplot", | ||
]: | ||
log = logging.getLogger(name) | ||
log.setLevel(logging.ERROR) | ||
unittest.main(verbosity=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import os | ||
import tempfile | ||
import unittest | ||
from contextlib import redirect_stdout | ||
from io import StringIO | ||
import numpy as np | ||
from onnx import TensorProto | ||
from onnx.checker import check_model | ||
from onnx.helper import ( | ||
make_graph, | ||
make_model, | ||
make_node, | ||
make_opsetid, | ||
make_tensor_value_info, | ||
) | ||
from onnxruntime import InferenceSession, SessionOptions | ||
from onnx_extended.ext_test_case import ExtTestCase | ||
from onnx_extended._command_lines_parser import ( | ||
get_parser_plot, | ||
main, | ||
) | ||
|
||
|
||
class TestCommandLines2(ExtTestCase): | ||
def test_parser_plot(self): | ||
st = StringIO() | ||
with redirect_stdout(st): | ||
get_parser_plot().print_help() | ||
text = st.getvalue() | ||
self.assertIn("kind", text) | ||
self.assertIn("verbose", text) | ||
|
||
def test_command_store(self): | ||
X = make_tensor_value_info("X", TensorProto.FLOAT, [None]) | ||
Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None]) | ||
graph = make_graph( | ||
[ | ||
make_node("Neg", ["X"], ["res"]), | ||
make_node("Cos", ["res"], ["Z"]), | ||
], | ||
"g", | ||
[X], | ||
[Z], | ||
) | ||
onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)]) | ||
check_model(onnx_model) | ||
|
||
sess_options = SessionOptions() | ||
sess_options.enable_profiling = True | ||
sess = InferenceSession( | ||
onnx_model.SerializeToString(), | ||
sess_options, | ||
providers=["CPUExecutionProvider"], | ||
) | ||
for _ in range(11): | ||
sess.run(None, dict(X=np.arange(10).astype(np.float32))) | ||
prof = sess.end_profiling() | ||
|
||
with tempfile.TemporaryDirectory() as root: | ||
csv = os.path.join(root, "o.csv") | ||
png = os.path.join(root, "o.png") | ||
args = [ | ||
"plot", | ||
"-i", | ||
prof, | ||
"-k", | ||
"profile_node", | ||
"-c", | ||
csv, | ||
"-o", | ||
png, | ||
"-v", | ||
] | ||
st = StringIO() | ||
with redirect_stdout(st): | ||
main(args) | ||
self.assertIn("[plot_profile] save", st.getvalue()) | ||
self.assertExists(png) | ||
self.assertExists(csv) | ||
|
||
os.remove(prof) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main(verbosity=2) |
Oops, something went wrong.