Add one example around gemm #183

Merged · 4 commits · May 19, 2024
7 changes: 7 additions & 0 deletions _doc/benchmarks.rst
@@ -91,6 +91,13 @@ The benchmark profiles the execution of Gemm for different
types and configuration. That includes a custom operator
only available on CUDA calling function :epkg:`cublasLtMatmul`.

plot_op_gemm2_cuda
++++++++++++++++++

See :ref:`l-example-op-gemm2_cuda`.

One big Gemm or two smaller Gemms?

plot_op_mul_cuda
++++++++++++++++

270 changes: 270 additions & 0 deletions _doc/examples/plot_op_gemm2_cuda.py
@@ -0,0 +1,270 @@
"""
.. _l-example-op-gemm2_cuda:

Gemm Exploration with CUDA
==========================

One big Gemm or two smaller Gemms?

Cache Performance
+++++++++++++++++
"""

from onnx_extended.args import get_parsed_args

script_args = get_parsed_args(
    "plot_op_gemm2_cuda",
    description=__doc__,
    config=(
        "small",
        "small, short optimization (default), "
        "medium for medium sizes, "
        "large for big sizes",
    ),
    warmup=3,
    repeat=5,
    itype=(1, "1 or 10 for float or float16"),
    expose="config,itype,warmup,repeat",
)

itype = script_args.itype
config = script_args.config
print(f"config={config}")
print(f"itype={itype}")

if config == "small":
    sizes = (256, 512, 1024)
elif config == "medium":
    sizes = (512, 1024, 2048)
elif config == "large":
    sizes = (1024, 2048, 4096, 8192)
else:
    try:
        sizes = list(map(int, config.split(",")))
    except (ValueError, TypeError) as e:
        raise AssertionError(f"Unexpected config value {config!r}.") from e
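
########################################
# ``config`` also accepts a comma-separated list of sizes, so the
# benchmark can presumably be run as
# ``python plot_op_gemm2_cuda.py --config 96,192,384``
# (assuming :func:`get_parsed_args` exposes the names listed in
# ``expose`` as command-line flags).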

import time
import numpy as np
import onnx.helper as oh
from tqdm import tqdm
from pandas import DataFrame
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs


def get_model1(itype):
    return oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Gemm", ["X", "Y"], ["XY"]),
                oh.make_node("Gemm", ["X", "Z"], ["XZ"]),
                oh.make_node("Concat", ["XY", "XZ"], ["XYZ"], axis=1),
            ],
            "nd",
            [
                oh.make_tensor_value_info("X", itype, [None, None]),
                oh.make_tensor_value_info("Y", itype, [None, None]),
                oh.make_tensor_value_info("Z", itype, [None, None]),
            ],
            [oh.make_tensor_value_info("XYZ", itype, [None, None])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=9,
    )


print(onnx_simple_text_plot(get_model1(itype)))


########################################
# And the other model


def get_model2(itype):
    return oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Concat", ["Y", "Z"], ["YZ"], axis=1),
                oh.make_node("Gemm", ["X", "YZ"], ["XYZ"]),
            ],
            "nd",
            [
                oh.make_tensor_value_info("X", itype, [None, None]),
                oh.make_tensor_value_info("Y", itype, [None, None]),
                oh.make_tensor_value_info("Z", itype, [None, None]),
            ],
            [oh.make_tensor_value_info("XYZ", itype, [None, None])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=9,
    )


print(onnx_simple_text_plot(get_model2(itype)))
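
###########################################
# Both graphs compute the same result because of the block matrix
# identity ``X @ [Y | Z] = [X @ Y | X @ Z]``. A quick numpy check
# (an editor's sketch, not part of the original script):

_x, _y, _z = (np.random.randn(8, 8).astype(np.float32) for _ in range(3))
assert np.allclose(
    np.hstack([_x @ _y, _x @ _z]),  # model 1: two Gemms, then Concat(axis=1)
    _x @ np.hstack([_y, _z]),  # model 2: Concat(axis=1), then one Gemm
    atol=1e-5,
)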

###########################################
# InferenceSession
# ++++++++++++++++

has_cuda = "CUDAExecutionProvider" in get_available_providers()

if has_cuda:
    dtype = np.float32 if itype == 1 else np.float16

    x = np.random.randn(16, 16).astype(dtype)
    y = np.random.randn(16, 16).astype(dtype)
    z = np.random.randn(16, 16).astype(dtype)
    feeds = dict(X=x, Y=y, Z=z)

    sess1 = InferenceSession(
        get_model1(itype).SerializeToString(), providers=["CUDAExecutionProvider"]
    )
    expected = sess1.run(None, feeds)[0]

#########################################
# The other model.

if has_cuda:
    opts = SessionOptions()
    opts.register_custom_ops_library(get_ort_ext_libs()[0])

    sess2 = InferenceSession(
        get_model2(itype).SerializeToString(), opts, providers=["CUDAExecutionProvider"]
    )
    got = sess2.run(None, feeds)[0]

########################################
# Discrepancies

if has_cuda:
    diff = np.abs(got - expected).max()
    print(f"diff={diff}")


############################################
# Benchmark
# +++++++++
#
# Some code to avoid measuring the copy of the data from host to device.


def move_inputs(sess, feeds):
    from onnxruntime.capi._pybind_state import (
        SessionIOBinding,
        OrtDevice as C_OrtDevice,
        OrtValue as C_OrtValue,
    )

    input_names = [i.name for i in sess.get_inputs()]

    ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)

    feed_ort_value = [
        (name, C_OrtValue.ortvalue_from_numpy(feeds[name], ort_device))
        for name in input_names
    ]

    bind = SessionIOBinding(sess._sess)
    for name, value in feed_ort_value:
        bind.bind_input(
            name, ort_device, feeds[name].dtype, value.shape(), value.data_ptr()
        )
    for o in sess.get_outputs():
        bind.bind_output(o.name, ort_device)
    return bind, feed_ort_value
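
#####################################
# With inputs and outputs bound on the CUDA device, the benchmark below
# should measure only the graph execution: the host-to-device copies
# happen once, inside ``move_inputs``.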


###################################
# Benchmark function


def benchmark(sess, sizes, label):
    data = []
    for size in tqdm(sizes):
        x = np.random.randn(size, size).astype(dtype)
        y = np.random.randn(size, size).astype(dtype)
        z = np.random.randn(size, size).astype(dtype)
        feeds = dict(X=x, Y=y, Z=z)
        bind, cuda_feeds = move_inputs(sess, feeds)

        begin = time.perf_counter()
        for i in range(script_args.warmup):
            # sess.run(None, feeds)
            sess._sess.run_with_iobinding(bind, None)
        warmup = time.perf_counter() - begin

        times = []
        for i in range(script_args.repeat):
            begin = time.perf_counter()
            # sess.run(None, feeds)
            sess._sess.run_with_iobinding(bind, None)
            times.append(time.perf_counter() - begin)

        npt = np.array(times)
        obs = dict(
            warmup=warmup,
            time=npt.mean(),
            std=npt.std(),
            min=npt.min(),
            max=npt.max(),
            repeat=script_args.repeat,
            size=size,
            label=label,
        )
        data.append(obs)
    return data


#######################################
# Not Fused.

if has_cuda:
    print(f"sizes={sizes}")

    data_mul = benchmark(sess1, sizes, "Not Fused")

#######################################
# Fused.

if has_cuda:
    data_mulmul = benchmark(sess2, sizes, "Fused")


##########################################
# Data
# ++++

if has_cuda:
    df = DataFrame(data_mul + data_mulmul)
    df.to_csv("plot_op_gemm2_cuda.csv", index=False)
    df.to_excel("plot_op_gemm2_cuda.xlsx", index=False)
    print(df.head())

#####################
# Pivot.

if has_cuda:
    pivot = df.pivot(index="size", columns="label", values="time")
    pivot["ratio"] = pivot["Fused"] / pivot["Not Fused"]
    print(pivot)

    ax = pivot[["Not Fused", "Fused"]].plot(
        logx=True,
        logy=True,
        title=f"One big Gemm vs two smaller Gemms on CUDA\nitype={itype}",
    )
    ax.get_figure().savefig("plot_op_gemm2_cuda.png")

##############################
# It seems the fused version (one single bigger Gemm) is about 33% faster:
# the ratio column (Fused / Not Fused) stays below 1, one larger kernel
# usually keeping the GPU busier than two smaller ones.
25 changes: 21 additions & 4 deletions _unittests/ut_ortops/test_optim_cuda.py
@@ -621,7 +621,7 @@ def test_mulsub_cuda_negative(self):
        self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub", negative=True)
        self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub", negative=True)

-    def _rotary_cuda(self, itype, side):
+    def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
        import onnxruntime
        from onnx_extended.ortops.optim.cuda import get_ort_ext_libs

@@ -651,12 +651,12 @@ def _rotary_cuda(self, itype, side):
        )

        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
-        x = (np.arange(18 * 4) + 1).reshape((3, 2, 3, 4)).astype(dtype)
+        x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype)
        splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64)

        expected = x.copy()
        half = x.shape[-1] // 2
-        if side == "right":
+        if side == "left":
            expected[:, :, :, :half] = x[:, :, :, half:]
            expected[:, :, :, half:] = -x[:, :, :, :half]
        else:
@@ -670,15 +670,32 @@ def _rotary_cuda(self, itype, side):
            model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]
        )
        got = sess.run(None, feeds)[0]

+        # rexp = expected.reshape((-1, expected.shape[-1]))
+        # rgot = got.reshape((-1, got.shape[-1]))
+        # for i in range(rgot.shape[0]):
+        #     self.assertEqualArray(
+        #         rexp[i],
+        #         rgot[i],
+        #         msg=f"row {i} is wrong,\nexp={rexp[i]}\ngot={rgot[i]}",
+        #     )
        self.assertEqualArray(expected, got)

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_rotary_cuda(self):
        self._rotary_cuda(TensorProto.FLOAT, "left")
        self._rotary_cuda(TensorProto.FLOAT16, "left")
        self._rotary_cuda(TensorProto.FLOAT, "right")
-        self._rotary_cuda(TensorProto.FLOAT16, "left")
+        self._rotary_cuda(TensorProto.FLOAT16, "right")

+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_bigger_rotary_cuda(self):
+        sh = (2, 2, 1024, 8)
+        self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh)
+        self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh)
+        self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh)
+        self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh)

    def _mul_sigmoid_cuda(self, itype):
        import onnxruntime
        from onnx_extended.ortops.optim.cuda import get_ort_ext_libs
24 changes: 11 additions & 13 deletions onnx_extended/ortops/optim/cuda/rotary.cu
@@ -20,6 +20,14 @@ struct GridDim {
};
};

template <typename T> __device__ __inline__ T _neg(const T x) { return -x; }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half _neg(const half x) {
  return __float2half(-__half2float(x));
}
#endif
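
// Note added for this review: native arithmetic on half is only available
// on recent GPU architectures, hence the float round-trip fallback above;
// the _neg helper keeps that check in a single place instead of
// duplicating the kernel body under #if/#else as before.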

template <typename T, RotarySide side>
__global__ void _RotaryKernelLeft(T *output_data, const T *input_data, CUDA_LONG half_N,
                                  CUDA_LONG half_stride) {
@@ -28,23 +36,13 @@ __global__ void _RotaryKernelLeft(T *output_data, const T *input_data, CUDA_LONG
    return;
  CUDA_LONG last = id % half_stride;
  id = (id - last) * 2 + last;
-#if __CUDA_ARCH__ < 700
-  if (side == RotarySide::LEFT) {
+  if (side == RotarySide::RIGHT) {
    output_data[id + half_stride] = input_data[id];
-    output_data[id] = __float2half(-__half2float(input_data[id + half_stride]));
+    output_data[id] = _neg(input_data[id + half_stride]);
  } else {
-    output_data[id + half_stride] = __float2half(-__half2float(input_data[id]));
+    output_data[id + half_stride] = _neg(input_data[id]);
    output_data[id] = input_data[id + half_stride];
  }
-#else
-  if (side == RotarySide::LEFT) {
-    output_data[id + half_stride] = input_data[id];
-    output_data[id] = -input_data[id + half_stride];
-  } else {
-    output_data[id + half_stride] = -input_data[id];
-    output_data[id] = input_data[id + half_stride];
-  }
-#endif
}

template <typename T>
1 change: 1 addition & 0 deletions pyproject.toml
@@ -322,6 +322,7 @@ line-length = 88
max-complexity = 10

[tool.ruff.lint.per-file-ignores]
"_doc/examples/plot_op_gemm2_cuda.py" = ["E402"]
"_doc/examples/plot_op_mul_cuda.py" = ["E402"]
"_doc/examples/plot_op_scatternd_cuda.py" = ["E402"]
"_doc/examples/plot_op_scatternd_mask_cuda.py" = ["E402"]