Adds function to analyse and plot a profile from onnxruntime (#71)
* add function to plot a profile

* add command line

* improves graph

* add shape

* plot

* rename into with_shape
xadupre authored Sep 13, 2023
1 parent e288bc2 commit a76b74d
Showing 11 changed files with 887 additions and 96 deletions.
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
@@ -4,7 +4,7 @@ Change Logs
0.2.0
+++++

- * :pr:`68`: add CPU implementation for CustomGemmFloat8
+ * :pr:`68`, :pr:`69`, :pr:`70`: add CPU implementation for CustomGemmFloat8
* :pr:`67`: add a function to extract a subgraph of a model
* :pr:`59`, :pr:`60`, :pr:`61`, :pr:`62`, :pr:`63`, :pr:`65`,
:pr:`66`, :pr:`68`, :pr:`69`:
12 changes: 12 additions & 0 deletions _doc/api/command_lines.rst
@@ -26,6 +26,18 @@ Split the model and the coefficients. The coefficients go to an external file.
from onnx_extended._command_lines_parser import get_parser_external
get_parser_external().print_help()

plot
====

Plots graphs from an onnxruntime profiling output (a usage sketch follows this diff).

.. runpython::

from onnx_extended._command_lines_parser import get_parser_plot
get_parser_plot().print_help()

.. autofunction:: onnx_extended._command_lines.cmd_plot

print
=====

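A rough usage sketch of the new plot command, driven through main from onnx_extended._command_lines_parser exactly as the unit test added further below does; the file names ort_profile.json, nodes.csv and nodes.png are hypothetical placeholders.

from onnx_extended._command_lines_parser import main

# "ort_profile.json" is a placeholder for the JSON file returned by
# InferenceSession.end_profiling() when profiling is enabled.
main(
    [
        "plot",
        "-i", "ort_profile.json",
        "-k", "profile_node",
        "-c", "nodes.csv",   # also dump the profiling data to a CSV file
        "-o", "nodes.png",   # save the figure
        "-v",                # verbose: prints "[plot_profile] save ..." messages
    ]
)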
1 change: 1 addition & 0 deletions _doc/api/tools.rst
@@ -16,3 +16,4 @@ Graphs
tools_graph
tools_transformer
tools_manipulations
tools_other
17 changes: 17 additions & 0 deletions _doc/api/tools_other.rst
@@ -0,0 +1,17 @@

===========
Other tools
===========

Profiling
=========

js_profile_to_dataframe
+++++++++++++++++++++++

.. autofunction:: onnx_extended.tools.js_profile.js_profile_to_dataframe

plot_ort_profile
++++++++++++++++

.. autofunction:: onnx_extended.tools.js_profile.plot_ort_profile
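For orientation, a minimal sketch showing how these two functions combine, closely following the unit tests added below; model.onnx and the input name X are hypothetical placeholders for a real model.

import numpy as np
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession, SessionOptions
from onnx_extended.tools.js_profile import js_profile_to_dataframe, plot_ort_profile

# Enable onnxruntime profiling and run the model a few times.
opts = SessionOptions()
opts.enable_profiling = True
sess = InferenceSession("model.onnx", opts, providers=["CPUExecutionProvider"])
for _ in range(11):
    sess.run(None, {"X": np.arange(10).astype(np.float32)})
prof = sess.end_profiling()  # path of the JSON profile written by onnxruntime

# Convert the profile into a pandas DataFrame and plot it.
df = js_profile_to_dataframe(prof, first_it_out=True)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
plot_ort_profile(df, ax[0], ax[1], "profiling")
fig.savefig("profile.png")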
245 changes: 245 additions & 0 deletions _unittests/ut_tools/test_js_profile.py
@@ -0,0 +1,245 @@
import os
import unittest
import numpy as np
from onnx import TensorProto
from onnx.checker import check_model
from onnx.helper import (
make_model,
make_graph,
make_node,
make_opsetid,
make_tensor_value_info,
)
from onnx.numpy_helper import from_array
from onnxruntime import InferenceSession, SessionOptions
import matplotlib.pyplot as plt
from onnx_extended.ext_test_case import ExtTestCase, ignore_warnings
from onnx_extended.tools.js_profile import js_profile_to_dataframe, plot_ort_profile


class TestJsProfile(ExtTestCase):
def _get_model(self):
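        # Builds a small test graph computing final = Abs(X) - (X + 1) * (X + 3).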
model_def0 = make_model(
make_graph(
[
make_node("Add", ["X", "init1"], ["X1"]),
make_node("Abs", ["X"], ["X2"]),
make_node("Add", ["X", "init3"], ["inter"]),
make_node("Mul", ["X1", "inter"], ["Xm"]),
make_node("Sub", ["X2", "Xm"], ["final"]),
],
"test",
[make_tensor_value_info("X", TensorProto.FLOAT, [None])],
[make_tensor_value_info("final", TensorProto.FLOAT, [None])],
[
from_array(np.array([1], dtype=np.float32), name="init1"),
from_array(np.array([3], dtype=np.float32), name="init3"),
],
),
opset_imports=[make_opsetid("", 18)],
)
check_model(model_def0)
return model_def0

def test_js_profile_to_dataframe(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
sess = InferenceSession(
self._get_model().SerializeToString(),
sess_options,
providers=["CPUExecutionProvider"],
)
for _ in range(11):
sess.run(None, dict(X=np.arange(10).astype(np.float32)))
prof = sess.end_profiling()

df = js_profile_to_dataframe(prof, first_it_out=True)
self.assertEqual(df.shape, (189, 17))
self.assertEqual(
set(df.columns),
set(
[
"cat",
"pid",
"tid",
"dur",
"ts",
"ph",
"name",
"args_op_name",
"op_name",
"args_thread_scheduling_stats",
"args_output_size",
"args_parameter_size",
"args_activation_size",
"args_node_index",
"args_provider",
"event_name",
"iteration",
]
),
)

df = js_profile_to_dataframe(prof, agg=True)
self.assertEqual(df.shape, (17, 1))
self.assertEqual(list(df.columns), ["dur"])

df = js_profile_to_dataframe(prof, agg_op_name=True)
self.assertEqual(df.shape, (189, 17))
self.assertEqual(
set(df.columns),
set(
[
"cat",
"pid",
"tid",
"dur",
"ts",
"ph",
"name",
"args_op_name",
"op_name",
"args_thread_scheduling_stats",
"args_output_size",
"args_parameter_size",
"args_activation_size",
"args_node_index",
"args_provider",
"event_name",
"iteration",
]
),
)

os.remove(prof)

@ignore_warnings(UserWarning)
def test_plot_profile_2(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
sess = InferenceSession(
self._get_model().SerializeToString(),
sess_options,
providers=["CPUExecutionProvider"],
)
for _ in range(11):
sess.run(None, dict(X=np.arange(10).astype(np.float32)))
prof = sess.end_profiling()

df = js_profile_to_dataframe(prof, first_it_out=True)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
plot_ort_profile(df, ax[0], ax[1], "test_title")
# fig.savefig("graph1.png")
self.assertNotEmpty(fig)

os.remove(prof)

@ignore_warnings(UserWarning)
def test_plot_profile_2_shape(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
sess = InferenceSession(
self._get_model().SerializeToString(),
sess_options,
providers=["CPUExecutionProvider"],
)
for _ in range(11):
sess.run(None, dict(X=np.arange(10).astype(np.float32)))
prof = sess.end_profiling()

df = js_profile_to_dataframe(prof, first_it_out=True, with_shape=True)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
plot_ort_profile(df, ax[0], ax[1], "test_title")
# fig.savefig("graph1.png")
self.assertNotEmpty(fig)

os.remove(prof)

@ignore_warnings(UserWarning)
def test_plot_profile_agg(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
sess = InferenceSession(
self._get_model().SerializeToString(),
sess_options,
providers=["CPUExecutionProvider"],
)
for _ in range(11):
sess.run(None, dict(X=np.arange(10).astype(np.float32)))
prof = sess.end_profiling()

df = js_profile_to_dataframe(prof, first_it_out=True, agg=True)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
plot_ort_profile(df, ax, title="test_title")
fig.tight_layout()
# fig.savefig("graph2.png")
self.assertNotEmpty(fig)

os.remove(prof)

def _get_model_domain(self):
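        # Builds a graph using the custom operator CustomGemmFloat
        # from the onnx_extented.ortops.tutorial.cpu domain.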
model_def0 = make_model(
make_graph(
[
make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]),
make_node(
"CustomGemmFloat",
["X", "Xt"],
["final"],
domain="onnx_extented.ortops.tutorial.cpu",
),
],
"test",
[make_tensor_value_info("X", TensorProto.FLOAT, [None, None])],
[make_tensor_value_info("final", TensorProto.FLOAT, [None, None])],
),
opset_imports=[
make_opsetid("", 18),
make_opsetid("onnx_extented.ortops.tutorial.cpu", 1),
],
)
check_model(model_def0)
return model_def0

@ignore_warnings(UserWarning)
def test_plot_domain_agg(self):
from onnx_extended.ortops.tutorial.cpu import get_ort_ext_libs

sess_options = SessionOptions()
sess_options.enable_profiling = True
sess_options.register_custom_ops_library(get_ort_ext_libs()[0])
sess = InferenceSession(
self._get_model_domain().SerializeToString(),
sess_options,
providers=["CPUExecutionProvider"],
)
for _ in range(11):
sess.run(None, dict(X=np.arange(16).astype(np.float32).reshape((-1, 4))))
prof = sess.end_profiling()

df = js_profile_to_dataframe(prof, first_it_out=True)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
plot_ort_profile(df, ax, title="test_title")
fig.tight_layout()
# fig.savefig("graph3.png")
self.assertNotEmpty(fig)

os.remove(prof)


if __name__ == "__main__":
import logging

for name in [
"matplotlib.font_manager",
"PIL.PngImagePlugin",
"matplotlib",
"matplotlib.pyplot",
]:
log = logging.getLogger(name)
log.setLevel(logging.ERROR)
unittest.main(verbosity=2)
@@ -31,7 +31,7 @@
)


- class TestCommandLines(ExtTestCase):
+ class TestCommandLines1(ExtTestCase):
def test_main_parser(self):
st = StringIO()
with redirect_stdout(st):
85 changes: 85 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines2.py
@@ -0,0 +1,85 @@
import os
import tempfile
import unittest
from contextlib import redirect_stdout
from io import StringIO
import numpy as np
from onnx import TensorProto
from onnx.checker import check_model
from onnx.helper import (
make_graph,
make_model,
make_node,
make_opsetid,
make_tensor_value_info,
)
from onnxruntime import InferenceSession, SessionOptions
from onnx_extended.ext_test_case import ExtTestCase
from onnx_extended._command_lines_parser import (
get_parser_plot,
main,
)


class TestCommandLines2(ExtTestCase):
def test_parser_plot(self):
st = StringIO()
with redirect_stdout(st):
get_parser_plot().print_help()
text = st.getvalue()
self.assertIn("kind", text)
self.assertIn("verbose", text)

def test_command_store(self):
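        # Builds a small model (Z = Cos(Neg(X))), profiles a few runs,
        # then exercises the 'plot' command line on the resulting profile.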
X = make_tensor_value_info("X", TensorProto.FLOAT, [None])
Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None])
graph = make_graph(
[
make_node("Neg", ["X"], ["res"]),
make_node("Cos", ["res"], ["Z"]),
],
"g",
[X],
[Z],
)
onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)])
check_model(onnx_model)

sess_options = SessionOptions()
sess_options.enable_profiling = True
sess = InferenceSession(
onnx_model.SerializeToString(),
sess_options,
providers=["CPUExecutionProvider"],
)
for _ in range(11):
sess.run(None, dict(X=np.arange(10).astype(np.float32)))
prof = sess.end_profiling()

with tempfile.TemporaryDirectory() as root:
csv = os.path.join(root, "o.csv")
png = os.path.join(root, "o.png")
args = [
"plot",
"-i",
prof,
"-k",
"profile_node",
"-c",
csv,
"-o",
png,
"-v",
]
st = StringIO()
with redirect_stdout(st):
main(args)
self.assertIn("[plot_profile] save", st.getvalue())
self.assertExists(png)
self.assertExists(csv)

os.remove(prof)


if __name__ == "__main__":
unittest.main(verbosity=2)