From bba8d594c25327d4901ea5482688692497b1b6fd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 17 May 2024 15:44:20 +0200 Subject: [PATCH 1/4] add one example around gemm --- _doc/benchmarks.rst | 7 + _doc/examples/plot_op_gemm2_cuda.py | 270 ++++++++++++++++++++++++++++ pyproject.toml | 1 + 3 files changed, 278 insertions(+) create mode 100644 _doc/examples/plot_op_gemm2_cuda.py diff --git a/_doc/benchmarks.rst b/_doc/benchmarks.rst index a97d8b1d..80647d55 100644 --- a/_doc/benchmarks.rst +++ b/_doc/benchmarks.rst @@ -91,6 +91,13 @@ The benchmark profiles the execution of Gemm for different types and configuration. That includes a custom operator only available on CUDA calling function :epkg:`cublasLtMatmul`. +plot_op_gemm2_cuda +++++++++++++++++++ + +See :ref:`l-example-op-gemm2_cuda`. + +One big Gemm or two smaller gemm. + plot_op_mul_cuda ++++++++++++++++ diff --git a/_doc/examples/plot_op_gemm2_cuda.py b/_doc/examples/plot_op_gemm2_cuda.py new file mode 100644 index 00000000..177e0fb5 --- /dev/null +++ b/_doc/examples/plot_op_gemm2_cuda.py @@ -0,0 +1,270 @@ +""" +.. _l-example-op-gemm2_cuda: + +Gemm Exploration with CUDA +========================== + +One big Gemm or two smaller gemm? + +Cache Performance ++++++++++++++++++ +""" + +from onnx_extended.args import get_parsed_args + +script_args = get_parsed_args( + "plot_op_gemm2_cuda", + description=__doc__, + config=( + "small", + "small, short optimization (default), " + "medium for medium sizes, " + "large for big sizes", + ), + warmup=3, + repeat=5, + itype=(1, "1 or 10 for float or float16"), + expose="config,itype,warmup,repeat", +) + +itype = script_args.itype +config = script_args.config +print(f"config={config}") +print(f"itype={itype}") + +if config == "small": + sizes = (256, 512, 1024) +elif config == "medium": + sizes = (512, 1024, 2048) +elif config == "large": + sizes = (1024, 2048, 4096, 8192) +else: + try: + sizes = list(map(int, config.split(","))) + except (ValueError, TypeError) as e: + raise AssertionError(f"Unexpected config value {config!r}.") from e + +import time +import numpy as np +import onnx.helper as oh +from tqdm import tqdm +from pandas import DataFrame +from onnxruntime import InferenceSession, SessionOptions, get_available_providers +from onnx_array_api.plotting.text_plot import onnx_simple_text_plot +from onnx_extended.ortops.optim.cuda import get_ort_ext_libs + + +def get_model1(itype): + return oh.make_model( + oh.make_graph( + [ + oh.make_node("Gemm", ["X", "Y"], ["XY"]), + oh.make_node("Gemm", ["X", "Z"], ["XZ"]), + oh.make_node("Concat", ["XY", "XZ"], ["XYZ"], axis=1), + ], + "nd", + [ + oh.make_tensor_value_info("X", itype, [None, None]), + oh.make_tensor_value_info("Y", itype, [None, None]), + oh.make_tensor_value_info("Z", itype, [None, None]), + ], + [oh.make_tensor_value_info("XYZ", itype, [None, None])], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + + +print(onnx_simple_text_plot(get_model1(itype))) + + +######################################## +# And the other model + + +def get_model2(itype): + return oh.make_model( + oh.make_graph( + [ + oh.make_node("Concat", ["Y", "Z"], ["YZ"], axis=1), + oh.make_node("Gemm", ["X", "YZ"], ["XYZ"]), + ], + "nd", + [ + oh.make_tensor_value_info("X", itype, [None, None]), + oh.make_tensor_value_info("Y", itype, [None, None]), + oh.make_tensor_value_info("Z", itype, [None, None]), + ], + [oh.make_tensor_value_info("XYZ", itype, [None, None])], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + + +print(onnx_simple_text_plot(get_model2(itype))) + +########################################### +# InferenceSession +# ++++++++++++++++ + +has_cuda = "CUDAExecutionProvider" in get_available_providers() + +if has_cuda: + + dtype = np.float32 if itype == 1 else np.float16 + + x = np.random.randn(16, 16).astype(dtype) + y = np.random.randn(16, 16).astype(dtype) + z = np.random.randn(16, 16).astype(dtype) + feeds = dict(X=x, Y=y, Z=z) + + sess1 = InferenceSession( + get_model1(itype).SerializeToString(), providers=["CUDAExecutionProvider"] + ) + expected = sess1.run(None, feeds)[0] + +######################################### +# The other model. + +if has_cuda: + + opts = SessionOptions() + opts.register_custom_ops_library(get_ort_ext_libs()[0]) + + sess2 = InferenceSession( + get_model2(itype).SerializeToString(), opts, providers=["CUDAExecutionProvider"] + ) + got = sess2.run(None, feeds)[0] + +######################################## +# Discrepancies + +if has_cuda: + + diff = np.abs(got - expected).max() + print(f"diff={diff}") + + +############################################ +# Benchmark +# +++++++++ +# +# some code to avoid measuring copying the data from host to device + + +def move_inputs(sess, feeds): + from onnxruntime.capi._pybind_state import ( + SessionIOBinding, + OrtDevice as C_OrtDevice, + OrtValue as C_OrtValue, + ) + + input_names = [i.name for i in sess.get_inputs()] + + ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + + feed_ort_value = [ + (name, C_OrtValue.ortvalue_from_numpy(feeds[name], ort_device)) + for name in input_names + ] + + bind = SessionIOBinding(sess._sess) + for name, value in feed_ort_value: + bind.bind_input( + name, ort_device, feeds[name].dtype, value.shape(), value.data_ptr() + ) + for o in sess.get_outputs(): + bind.bind_output(o.name, ort_device) + return bind, feed_ort_value + + +################################### +# Benchmark function + + +def benchmark(sess, sizes, label): + + data = [] + for size in tqdm(sizes): + + x = np.random.randn(size, size).astype(dtype) + y = np.random.randn(size, size).astype(dtype) + z = np.random.randn(size, size).astype(dtype) + feeds = dict(X=x, Y=y, Z=z) + bind, cuda_feeds = move_inputs(sess, feeds) + + begin = time.perf_counter() + for i in range(script_args.warmup): + # sess.run(None, feeds) + sess._sess.run_with_iobinding(bind, None) + warmup = time.perf_counter() - begin + + times = [] + for i in range(script_args.repeat): + begin = time.perf_counter() + # sess.run(None, feeds) + sess._sess.run_with_iobinding(bind, None) + times.append(time.perf_counter() - begin) + + npt = np.array(times) + obs = dict( + warmup=warmup, + time=npt.mean(), + std=npt.std(), + min=npt.min(), + max=npt.max(), + repeat=script_args.repeat, + size=size, + label=label, + ) + data.append(obs) + return data + + +####################################### +# Not Fused. + +if has_cuda: + + print(f"sizes={sizes}") + + data_mul = benchmark(sess1, sizes, "Not Fused") + +####################################### +# Fused. + +if has_cuda: + + data_mulmul = benchmark(sess2, sizes, "Fused") + + +########################################## +# Data +# ++++ + +if has_cuda: + + df = DataFrame(data_mul + data_mulmul) + df.to_csv("plot_op_gemm2_cuda.csv", index=False) + df.to_csv("plot_op_gemm2_cuda.xlsx", index=False) + print(df.head()) + +##################### +# Pivot. + +if has_cuda: + + pivot = df.pivot(index="size", columns="label", values="time") + pivot["ratio"] = pivot["Fused"] / pivot["Not Fused"] + print(pivot) + + ax = pivot[["Not Fused", "Fused"]].plot( + logx=True, + logy=True, + title=f"Fused/Unfused element wise multiplication on CUDA\nitype={itype}", + ) + ax.get_figure().savefig("plot_op_gemm2_cuda.png") + +############################## +# It seems the fused operator is 33% faster. diff --git a/pyproject.toml b/pyproject.toml index 10982b0c..d241d9cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -322,6 +322,7 @@ line-length = 88 max-complexity = 10 [tool.ruff.lint.per-file-ignores] +"_doc/examples/plot_op_gemm2_cuda.py" = ["E402"] "_doc/examples/plot_op_mul_cuda.py" = ["E402"] "_doc/examples/plot_op_scatternd_cuda.py" = ["E402"] "_doc/examples/plot_op_scatternd_mask_cuda.py" = ["E402"] From a25ef9b6ee39fbb85d5b9accd5a1b4a91df825e9 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sat, 18 May 2024 12:25:37 +0200 Subject: [PATCH 2/4] fix issues --- _unittests/ut_ortops/test_optim_cuda.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_ortops/test_optim_cuda.py b/_unittests/ut_ortops/test_optim_cuda.py index 1b73db25..da36e10d 100644 --- a/_unittests/ut_ortops/test_optim_cuda.py +++ b/_unittests/ut_ortops/test_optim_cuda.py @@ -621,7 +621,7 @@ def test_mulsub_cuda_negative(self): self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub", negative=True) self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub", negative=True) - def _rotary_cuda(self, itype, side): + def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)): import onnxruntime from onnx_extended.ortops.optim.cuda import get_ort_ext_libs @@ -651,7 +651,7 @@ def _rotary_cuda(self, itype, side): ) dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 - x = (np.arange(18 * 4) + 1).reshape((3, 2, 3, 4)).astype(dtype) + x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype) splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64) expected = x.copy() @@ -670,15 +670,33 @@ def _rotary_cuda(self, itype, side): model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"] ) got = sess.run(None, feeds)[0] + + rexp = expected.reshape((-1, expected.shape[-1])) + rgot = got.reshape((-1, got.shape[-1])) + print(expected.shape, rexp.shape, rgot.shape) + for i in range(rgot.shape[0]): + self.assertEqualArray( + rexp[i], + rgot[i], + msg=f"row {i} is wrong,\nexp={rexp[i]}\ngot={rgot[i]}", + ) self.assertEqualArray(expected, got) @unittest.skipIf(not has_cuda(), reason="cuda not available") def test_rotary_cuda(self): self._rotary_cuda(TensorProto.FLOAT, "left") - self._rotary_cuda(TensorProto.FLOAT16, "left") self._rotary_cuda(TensorProto.FLOAT, "right") + self._rotary_cuda(TensorProto.FLOAT16, "left") self._rotary_cuda(TensorProto.FLOAT16, "right") + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_bigger_rotary_cuda(self): + sh = (2, 2, 1024, 8) + self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh) + self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh) + self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh) + self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh) + def _mul_sigmoid_cuda(self, itype): import onnxruntime from onnx_extended.ortops.optim.cuda import get_ort_ext_libs From b066517fe8035ac026c7e8f7ac372944b52916ee Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sat, 18 May 2024 12:32:53 +0200 Subject: [PATCH 3/4] fix implementation for lower architecture --- _unittests/ut_ortops/test_optim_cuda.py | 17 ++++++++--------- onnx_extended/ortops/optim/cuda/rotary.cu | 22 ++++++++++------------ 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/_unittests/ut_ortops/test_optim_cuda.py b/_unittests/ut_ortops/test_optim_cuda.py index da36e10d..088fefeb 100644 --- a/_unittests/ut_ortops/test_optim_cuda.py +++ b/_unittests/ut_ortops/test_optim_cuda.py @@ -671,15 +671,14 @@ def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)): ) got = sess.run(None, feeds)[0] - rexp = expected.reshape((-1, expected.shape[-1])) - rgot = got.reshape((-1, got.shape[-1])) - print(expected.shape, rexp.shape, rgot.shape) - for i in range(rgot.shape[0]): - self.assertEqualArray( - rexp[i], - rgot[i], - msg=f"row {i} is wrong,\nexp={rexp[i]}\ngot={rgot[i]}", - ) + # rexp = expected.reshape((-1, expected.shape[-1])) + # rgot = got.reshape((-1, got.shape[-1])) + # for i in range(rgot.shape[0]): + # self.assertEqualArray( + # rexp[i], + # rgot[i], + # msg=f"row {i} is wrong,\nexp={rexp[i]}\ngot={rgot[i]}", + # ) self.assertEqualArray(expected, got) @unittest.skipIf(not has_cuda(), reason="cuda not available") diff --git a/onnx_extended/ortops/optim/cuda/rotary.cu b/onnx_extended/ortops/optim/cuda/rotary.cu index 47c7f5e5..c64b432a 100644 --- a/onnx_extended/ortops/optim/cuda/rotary.cu +++ b/onnx_extended/ortops/optim/cuda/rotary.cu @@ -20,6 +20,14 @@ struct GridDim { }; }; +template __device__ __inline__ T _neg(const T x) { return -x; } + +#if __CUDA_ARCH__ < 700 +template <> __device__ __inline__ half _neg(const half x) { + return __float2half(-__half2float(x)); +} +#endif + template __global__ void _RotaryKernelLeft(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) { @@ -28,23 +36,13 @@ __global__ void _RotaryKernelLeft(T *output_data, const T *input_data, CUDA_LONG return; CUDA_LONG last = id % half_stride; id = (id - last) * 2 + last; -#if __CUDA_ARCH__ < 700 if (side == RotarySide::LEFT) { output_data[id + half_stride] = input_data[id]; - output_data[id] = __float2half(-__half2float(input_data[id + half_stride])); + output_data[id] = _neg(input_data[id + half_stride]); } else { - output_data[id + half_stride] = __float2half(-__half2float(input_data[id])); + output_data[id + half_stride] = _neg(input_data[id]); output_data[id] = input_data[id + half_stride]; } -#else - if (side == RotarySide::LEFT) { - output_data[id + half_stride] = input_data[id]; - output_data[id] = -input_data[id + half_stride]; - } else { - output_data[id + half_stride] = -input_data[id]; - output_data[id] = input_data[id + half_stride]; - } -#endif } template From fe94e853bafec385d369474580df7ec8d60cd800 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 19 May 2024 11:16:26 +0200 Subject: [PATCH 4/4] fix rotary --- _unittests/ut_ortops/test_optim_cuda.py | 2 +- onnx_extended/ortops/optim/cuda/rotary.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_ortops/test_optim_cuda.py b/_unittests/ut_ortops/test_optim_cuda.py index 088fefeb..47d6342b 100644 --- a/_unittests/ut_ortops/test_optim_cuda.py +++ b/_unittests/ut_ortops/test_optim_cuda.py @@ -656,7 +656,7 @@ def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)): expected = x.copy() half = x.shape[-1] // 2 - if side == "right": + if side == "left": expected[:, :, :, :half] = x[:, :, :, half:] expected[:, :, :, half:] = -x[:, :, :, :half] else: diff --git a/onnx_extended/ortops/optim/cuda/rotary.cu b/onnx_extended/ortops/optim/cuda/rotary.cu index c64b432a..e71fce71 100644 --- a/onnx_extended/ortops/optim/cuda/rotary.cu +++ b/onnx_extended/ortops/optim/cuda/rotary.cu @@ -36,7 +36,7 @@ __global__ void _RotaryKernelLeft(T *output_data, const T *input_data, CUDA_LONG return; CUDA_LONG last = id % half_stride; id = (id - last) * 2 + last; - if (side == RotarySide::LEFT) { + if (side == RotarySide::RIGHT) { output_data[id + half_stride] = input_data[id]; output_data[id] = _neg(input_data[id + half_stride]); } else {