Skip to content

Commit

Permalink
plot
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Sep 12, 2023
1 parent 4d1d5bd commit fc0db33
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
15 changes: 9 additions & 6 deletions _unittests/ut_tools/test_js_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from onnx.numpy_helper import from_array
from onnxruntime import InferenceSession, SessionOptions
import matplotlib.pyplot as plt
from onnx_extended.ext_test_case import ExtTestCase
from onnx_extended.ext_test_case import ExtTestCase, ignore_warnings
from onnx_extended.tools.js_profile import js_profile_to_dataframe, plot_ort_profile


Expand All @@ -23,7 +23,7 @@ def _get_model(self):
make_graph(
[
make_node("Add", ["X", "init1"], ["X1"]),
make_node("Add", ["X", "init2"], ["X2"]),
make_node("Abs", ["X"], ["X2"]),
make_node("Add", ["X", "init3"], ["inter"]),
make_node("Mul", ["X1", "inter"], ["Xm"]),
make_node("Sub", ["X2", "Xm"], ["final"]),
Expand All @@ -33,7 +33,6 @@ def _get_model(self):
[make_tensor_value_info("final", TensorProto.FLOAT, [None])],
[
from_array(np.array([1], dtype=np.float32), name="init1"),
from_array(np.array([2], dtype=np.float32), name="init2"),
from_array(np.array([3], dtype=np.float32), name="init3"),
],
),
Expand All @@ -55,7 +54,7 @@ def test_js_profile_to_dataframe(self):
prof = sess.end_profiling()

df = js_profile_to_dataframe(prof, first_it_out=True)
self.assertEqual(df.shape, (189, 19))
self.assertEqual(df.shape, (189, 17))
self.assertEqual(
set(df.columns),
set(
Expand All @@ -82,11 +81,11 @@ def test_js_profile_to_dataframe(self):
)

df = js_profile_to_dataframe(prof, agg=True)
self.assertEqual(df.shape, (15, 1))
self.assertEqual(df.shape, (17, 1))
self.assertEqual(list(df.columns), ["dur"])

df = js_profile_to_dataframe(prof, agg_op_name=True)
self.assertEqual(df.shape, (189, 19))
self.assertEqual(df.shape, (189, 17))
self.assertEqual(
set(df.columns),
set(
Expand Down Expand Up @@ -114,6 +113,7 @@ def test_js_profile_to_dataframe(self):

os.remove(prof)

@ignore_warnings(UserWarning)
def test_plot_profile_2(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
Expand All @@ -135,6 +135,7 @@ def test_plot_profile_2(self):

os.remove(prof)

@ignore_warnings(UserWarning)
def test_plot_profile_2_shape(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
Expand All @@ -156,6 +157,7 @@ def test_plot_profile_2_shape(self):

os.remove(prof)

@ignore_warnings(UserWarning)
def test_plot_profile_agg(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
Expand Down Expand Up @@ -202,6 +204,7 @@ def _get_model_domain(self):
check_model(model_def0)
return model_def0

@ignore_warnings(UserWarning)
def test_plot_domain_agg(self):
from onnx_extended.ortops.tutorial.cpu import get_ort_ext_libs

Expand Down
19 changes: 14 additions & 5 deletions onnx_extended/tools/js_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def _process_shape(shape_df):
if len(shape_df) == 0:
if isinstance(shape_df, float) or len(shape_df) == 0:
return ""
values = []
for val in shape_df:
Expand Down Expand Up @@ -53,6 +53,15 @@ def sep_event(s):
df.loc[i, "iteration"] = current

if not agg:
if add_shape:
df["args_input_type_shape"] = df["args_input_type_shape"].apply(
_process_shape
)
df["args_output_type_shape"] = df["args_output_type_shape"].apply(
_process_shape
)
else:
df = df.drop(["args_input_type_shape", "args_output_type_shape"], axis=1)
return df

agg_cols = ["cat", "args_node_index", "args_op_name", "args_provider", "event_name"]
Expand Down Expand Up @@ -137,7 +146,7 @@ def _preprocess_graph1(df):
gr_n = df[agg_cols].groupby(agg_cols[1:]).count().sort_values("dur")
gr_n = gr_n.loc[gr_dur.index, :]
gr_n.columns = ["count"]
gr = gr_dur.merge(gr_n)
gr = gr_dur.merge(gr_n, left_index=True, right_index=True)
gr["ratio"] = gr["dur"] / gr["dur"].sum()
return gr_dur, gr_n, gr

Expand Down Expand Up @@ -188,8 +197,8 @@ def plot_ort_profile(
# Aggregation by operator
gr_dur, gr_n, _ = _preprocess_graph1(df)
gr_dur.plot.barh(ax=ax0)
ax0.get_yaxis().set_label_text("")
ax0.set_xticklabels(ax0.get_xticklabels(), fontsize=fontsize)
ax0.get_yaxis().set_label_text("")
ax0.set_yticklabels(
ax0.get_yticklabels(), rotation=45, ha="right", fontsize=fontsize
)
Expand All @@ -198,17 +207,17 @@ def plot_ort_profile(
if ax1 is not None:
gr_n.plot.barh(ax=ax1)
ax1.set_title("n occurences")
ax1.get_yaxis().set_label_text("")
ax1.set_xticklabels(ax1.get_xticklabels(), fontsize=fontsize)
ax1.get_yaxis().set_label_text("")
ax1.set_yticklabels(
ax1.get_yticklabels(), rotation=45, ha="right", fontsize=fontsize
)
return ax0

df = _preprocess_graph2(df)
df[["dur"]].plot.barh(ax=ax0)
ax0.get_yaxis().set_label_text("")
ax0.set_xticklabels(ax0.get_xticklabels(), fontsize=fontsize)
ax0.get_yaxis().set_label_text("")
ax0.set_yticklabels(ax0.get_yticklabels(), fontsize=fontsize)
if title is not None:
ax0.set_title(title)
Expand Down

0 comments on commit fc0db33

Please sign in to comment.