diff --git a/intermediate_source/inductor_debug_cpu.py b/intermediate_source/inductor_debug_cpu.py
index 40bc736334..d418deda18 100644
--- a/intermediate_source/inductor_debug_cpu.py
+++ b/intermediate_source/inductor_debug_cpu.py
@@ -313,65 +313,66 @@ def call(args):
 # ---------------------
 #
 # For this part, we will describe how to analyze the inductor model performance.
-# Firsly, we choose an eager model as a baseline. We set up a benchmark to compare the end to end performance between eager model and inductor model.
+# First, we choose an eager model as a baseline. We set up a benchmark to compare the end-to-end performance between the eager model and the inductor model.
 
-from transformers import T5ForConditionalGeneration
+from transformers import MobileBertForQuestionAnswering
+import torch
 # init an eager model
-eager_model = T5ForConditionalGeneration.from_pretrained("t5-small")
-seq_length = 1024
-bs = 4
+model = MobileBertForQuestionAnswering.from_pretrained("csarron/mobilebert-uncased-squad-v2")
+seq_length = 128
+bs = 128
 vocab_size = model.config.vocab_size
 input = torch.randint(0, vocab_size, (bs, seq_length), dtype=torch.int64)
 input_dict = {"input_ids": input}
-input_dict["decoder_input_ids"] = input
+
 # init inductor model
 inductor_model = torch.compile(model)
-compiled(**input_dict)
-eager_t = 0
-inductor_t = 0
-for _ in range(100):
-    model(**input_dict)
-for _ in range(1000):
-    eager_start = time.time()
-    model(**input_dict)
-    eager_end = time.time()
-    eager_t += eager_end - eager_start
-
-for _ in range(100):
-    model(**input_dict)
-for _ in range(1000):
-    inductor_start = time.time()
-    compiled(**input_dict)
-    inductor_end = time.time()
-    inductor_t += inductor_end - inductor_start
-
-print(model.__class__)
-print("eager use:", eager_t)
-print("inductor use:", inductor_t)
-print("ratio:", eager_t / inductor_t)
+with torch.no_grad():
+    inductor_model(**input_dict)
+
+NUM_ITERS=100
+import timeit
+with torch.no_grad():
+    # warmup
+    for _ in range(10):
+        model(**input_dict)
+    eager_t = timeit.timeit("model(**input_dict)", number=NUM_ITERS, globals=globals())
+
+with torch.no_grad():
+    # warmup
+    for _ in range(10):
+        inductor_model(**input_dict)
+    inductor_t = timeit.timeit("inductor_model(**input_dict)", number=NUM_ITERS, globals=globals())
+print(f"eager use: {eager_t * 1000 / NUM_ITERS} ms/iter")
+print(f"inductor use: {inductor_t * 1000 / NUM_ITERS} ms/iter")
+print(f"speed up ratio: {eager_t / inductor_t}")
+
 ######################################################################
 # Output:
 #
 # .. code-block:: shell
 #
-#     eager use: 410.12550354003906
-#     inductor use: 478.59081745147705
-#     ratio: 0.8569439458198976
+#     eager use: 802.1023553796113 ms/iter
+#     inductor use: 339.95180135127157 ms/iter
+#     speed up ratio: 2.359459053287382
 #
-# We see that the inductor model execution time is longer than the eager model, which does not meet our expectation.
-# To deep dive op-level performance, we can use `Pytorch Profiler `_
+# The inductor model speed-up is about 2.36x.
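+#
+# As a quick sanity check on the ``timeit`` numbers above, the same comparison can also be done with
+# ``torch.utils.benchmark``. This is only a sketch -- it is not used for any of the numbers reported
+# in this tutorial, and the measured times will differ from machine to machine:
+#
+# .. code-block:: python
+#
+#    from torch.utils import benchmark
+#
+#    with torch.no_grad():
+#        eager_m = benchmark.Timer(stmt="model(**input_dict)", globals=globals()).timeit(NUM_ITERS)
+#        inductor_m = benchmark.Timer(stmt="inductor_model(**input_dict)", globals=globals()).timeit(NUM_ITERS)
+#    print(eager_m)
+#    print(inductor_m)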
 #
-# To enable kernel profile in inductor, we need set ``enable_kernel_profile`` by:
+#
+# Second, we can dive deep into op-level performance to understand where the speed-up comes from.
+# `PyTorch Profiler `_ is a good tool to help us.
+# To enable kernel profiling with the inductor model, we need to set ``enable_kernel_profile``:
 from torch._inductor import config
 config.cpp.enable_kernel_profile = True
 
 ######################################################################
 # Following the steps in `Pytorch Profiler `_
-# we are able to get the profiling table and trace files.
+# we can get the profiling table and trace files.
 from torch.profiler import profile, schedule, ProfilerActivity
+RESULT_DIR = "./prof_trace"
 my_schedule = schedule(
     skip_first=10,
     wait=5,
@@ -382,10 +383,10 @@ def call(args):
 def trace_handler(p):
     output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=20)
     print(output)
-    p.export_chrome_trace(RESULT_DIR + "/" + str(p.step_num) + ".json")
+    p.export_chrome_trace(f"{RESULT_DIR}/{p.step_num}.json")
 
-for _ in range(nwarmup):
-    model(**input_dict)
+for _ in range(10):
+    model(**input_dict)  # inductor_model(**input_dict) to get inductor model profiling
 
 total = 0
 with profile(
@@ -394,210 +395,109 @@ def trace_handler(p):
     on_trace_ready=trace_handler
 ) as p:
     for _ in range(100):
-        begin = time.time()
-        model(**input_dict)
-        end=time.time()
-        total += (end - begin)
+        model(**input_dict)  # inductor_model(**input_dict) to get inductor model profiling
         p.step()
-print("latency: {} ms".format(1000*(total)/100))
 
 ######################################################################
-# We will get the following profile tables for eager model:
+# We will get the following profile table for the eager model:
 #
 # .. code-block:: shell
 #
-#    -----------------------  ------------  ------------  ------------  ------------  ------------  ------------
-#                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-#    -----------------------  ------------  ------------  ------------  ------------  ------------  ------------
-#                   aten::mm        33.33%     138.616ms        33.33%     138.616ms       1.429ms            97
-#                 aten::add_        19.38%      80.596ms        19.38%      80.596ms       4.242ms            19
-#                  aten::bmm        18.78%      78.104ms        18.78%      78.104ms       2.170ms            36
-#             aten::_softmax        11.32%      47.082ms        11.32%      47.082ms       2.616ms            18
-#                aten::copy_         3.89%      16.190ms         3.89%      16.190ms     103.121us           157
-#              ProfilerStep*         3.53%      14.702ms       100.00%     415.949ms     415.949ms             1
-#                  aten::add         2.37%       9.849ms         2.39%       9.958ms     144.319us            69
-#                  aten::mul         1.13%       4.693ms         1.14%       4.726ms      65.639us            72
-#            aten::clamp_min         0.85%       3.541ms         0.85%       3.541ms     295.083us            12
-#         aten::index_select         0.84%       3.480ms         1.06%       4.401ms       1.100ms             4
-#               aten::linear         0.63%       2.637ms        33.95%     141.194ms       1.456ms            97
-#                  aten::pow         0.61%       2.520ms         0.61%       2.554ms      79.812us            32
-#               aten::matmul         0.50%       2.067ms        56.53%     235.132ms       1.768ms           133
-#               aten::select         0.22%     900.000us         0.22%     910.000us     113.750us             8
-#                  aten::log         0.18%     740.000us         0.18%     740.000us     370.000us             2
-#         aten::_unsafe_view         0.17%     718.000us         0.17%     718.000us       3.840us           187
-#                  aten::sum         0.17%     715.000us         0.20%     831.000us      25.969us            32
-#            aten::transpose         0.15%     642.000us         0.18%     741.000us       3.963us           187
-#              aten::reshape         0.15%     622.000us         3.66%      15.241ms      88.098us           173
-#                aten::fill_         0.15%     613.000us         0.15%     613.000us      15.718us            39
-#    -----------------------  ------------  ------------  ------------  ------------  ------------  ------------
-#    Self CPU time total: 415.949ms
+#    -------------------------  ------------  ------------  ------------  ------------  ------------  ------------
+#                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
+#    -------------------------  ------------  ------------  ------------  ------------  ------------  ------------
+#                  aten::addmm        33.36%     270.520ms        45.73%     370.814ms       1.024ms           362
+#                    aten::add        19.89%     161.276ms        19.89%     161.276ms
444.287us 363 +# aten::copy_ 14.97% 121.416ms 14.97% 121.416ms 248.803us 488 +# aten::mul 9.02% 73.151ms 9.02% 73.154ms 377.082us 194 +# aten::clamp_min 8.81% 71.444ms 8.81% 71.444ms 744.208us 96 +# aten::bmm 5.46% 44.258ms 5.46% 44.258ms 922.042us 48 +# ProfilerStep* 3.00% 24.362ms 100.00% 810.920ms 810.920ms 1 +# aten::div 2.85% 23.071ms 2.89% 23.447ms 976.958us 24 +# aten::_softmax 1.00% 8.087ms 1.00% 8.087ms 336.958us 24 +# aten::linear 0.32% 2.624ms 46.48% 376.888ms 1.041ms 362 +# aten::clone 0.23% 1.859ms 2.77% 22.430ms 228.878us 98 +# aten::t 0.14% 1.162ms 0.31% 2.502ms 6.912us 362 +# aten::view 0.14% 1.161ms 0.14% 1.161ms 1.366us 850 +# aten::transpose 0.12% 938.000us 0.17% 1.377ms 3.567us 386 +# aten::index_select 0.12% 933.000us 0.12% 952.000us 317.333us 3 +# aten::expand 0.11% 865.000us 0.12% 986.000us 2.153us 458 +# aten::matmul 0.10% 808.000us 8.31% 67.420ms 1.405ms 48 +# aten::cat 0.09% 701.000us 0.09% 703.000us 703.000us 1 +# aten::as_strided 0.08% 656.000us 0.08% 656.000us 0.681us 963 +# aten::relu 0.05% 420.000us 8.86% 71.864ms 748.583us 96 +# ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ +# Self CPU time total: 810.920ms # # Similarly, get the table for the inductor model: # # .. code-block:: shell # -# -------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ -# Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls -# -------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ -# mkl::_mkl_linear 28.24% 133.979ms 28.39% 134.689ms 1.389ms 97 -# aten::bmm 15.65% 74.250ms 15.65% 74.251ms 2.063ms 36 -# graph_0_cpp_fused__softmax_7 4.24% 20.123ms 4.24% 20.123ms 20.123ms 1 -# graph_0_cpp_fused__softmax_42 4.17% 19.773ms 4.17% 19.773ms 19.773ms 1 -# graph_0_cpp_fused__softmax_35 4.16% 19.751ms 4.16% 19.751ms 19.751ms 1 -# graph_0_cpp_fused__softmax_21 4.15% 19.674ms 4.15% 19.674ms 19.674ms 1 -# graph_0_cpp_fused__softmax_14 4.14% 19.654ms 4.14% 19.654ms 19.654ms 1 -# graph_0_cpp_fused__softmax_28 4.13% 19.576ms 4.13% 19.576ms 19.576ms 1 -# graph_0_cpp_fused__softmax_56 2.83% 13.404ms 2.83% 13.404ms 13.404ms 1 -# graph_0_cpp_fused__softmax_80 2.82% 13.371ms 2.82% 13.371ms 13.371ms 1 -# graph_0_cpp_fused__softmax_68 2.81% 13.323ms 2.81% 13.323ms 13.323ms 1 -# graph_0_cpp_fused__softmax_92 2.80% 13.297ms 2.80% 13.297ms 13.297ms 1 -# graph_0_cpp_fused__softmax_104 2.78% 13.208ms 2.78% 13.208ms 13.208ms 1 -# graph_0_cpp_fused__softmax_2 2.63% 12.468ms 2.63% 12.468ms 12.468ms 1 -# ProfilerStep* 1.61% 7.616ms 100.00% 474.360ms 474.360ms 1 -# graph_0_cpp_fused__softmax_73 0.49% 2.320ms 0.49% 2.320ms 2.320ms 1 -# graph_0_cpp_fused__softmax_85 0.49% 2.309ms 0.49% 2.309ms 2.309ms 1 -# graph_0_cpp_fused__softmax_97 0.48% 2.283ms 0.48% 2.283ms 2.283ms 1 -# graph_0_cpp_fused__softmax_61 0.48% 2.268ms 0.48% 2.268ms 2.268ms 1 -# graph_0_cpp_fused__softmax_49 0.48% 2.255ms 0.48% 2.255ms 2.255ms 1 -# -------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ -# Self CPU time total: 474.360ms -# -# We can search the most time consuming ``graph_0_cpp_fused__softmax_7`` in ``output_code.py`` to see the generated code: - - -cpp_fused__softmax_7 = async_compile.cpp(''' +# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ +# Name Self CPU 
%      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
+#    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
+#                                           mkl::_mkl_linear        68.52%     230.662ms        68.79%     231.573ms     639.704us           362
+#                                                  aten::bmm         8.02%      26.991ms         8.02%      26.992ms     562.333us            48
+#                                              ProfilerStep*         3.35%      11.292ms       100.00%     336.642ms     336.642ms             1
+#              graph_0_cpp_fused_constant_pad_nd_embedding_0         0.27%     915.000us         0.27%     915.000us     915.000us             1
+#                                                aten::empty         0.27%     911.000us         0.27%     911.000us       2.517us           362
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_151         0.27%     901.000us         0.27%     901.000us     901.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_226         0.27%     899.000us         0.27%     899.000us     899.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_361         0.27%     898.000us         0.27%     898.000us     898.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_121         0.27%     895.000us         0.27%     895.000us     895.000us             1
+#              graph_0_cpp_fused__mkl_linear_add_mul_relu_31         0.27%     893.000us         0.27%     893.000us     893.000us             1
+#              graph_0_cpp_fused__mkl_linear_add_mul_relu_76         0.26%     892.000us         0.26%     892.000us     892.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_256         0.26%     892.000us         0.26%     892.000us     892.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_346         0.26%     892.000us         0.26%     892.000us     892.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_241         0.26%     891.000us         0.26%     891.000us     891.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_316         0.26%     891.000us         0.26%     891.000us     891.000us             1
+#              graph_0_cpp_fused__mkl_linear_add_mul_relu_91         0.26%     890.000us         0.26%     890.000us     890.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_106         0.26%     890.000us         0.26%     890.000us     890.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_211         0.26%     890.000us         0.26%     890.000us     890.000us             1
+#              graph_0_cpp_fused__mkl_linear_add_mul_relu_61         0.26%     889.000us         0.26%     889.000us     889.000us             1
+#             graph_0_cpp_fused__mkl_linear_add_mul_relu_286         0.26%     889.000us         0.26%     889.000us     889.000us             1
+#    -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
+#    Self CPU time total: 336.642ms
+#
+# From the profiling table of the eager model, we can see that the most time-consuming ops are [aten::addmm, aten::add, aten::copy_, aten::mul, aten::clamp_min, aten::bmm].
+# Comparing with the profiling table of the inductor model, we notice an ``mkl::_mkl_linear`` entry and a family of fused kernels named ``graph_0_cpp_fused_*``. They are the major
+# optimizations that the inductor model applies. Let us discuss them separately.
+# (1) Regarding ``mkl::_mkl_linear``: you may notice that the number of calls to this kernel is 362, which is exactly the same as for ``aten::linear`` in the eager model profiling table.
+# The CPU total of ``aten::linear`` is 376.888ms, while it is 231.573ms for ``mkl::_mkl_linear``. This suggests the inductor model speeds up the "linear" part by ~1.63x.
+# (2) Regarding the non-linear part: the end-to-end latency of the eager/inductor model is 802ms/339ms, so the speed-up of the non-linear part is ~3.9x.
+# Let's read the generated code to understand how the inductor achieves this impressive optimization.
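+#
+# As a rough back-of-the-envelope check of the two figures above (a sketch that simply reuses the
+# numbers from the profiling tables; your own numbers will differ from run to run):
+#
+# .. code-block:: python
+#
+#    eager_linear_ms = 376.888      # aten::linear CPU total in the eager table
+#    inductor_linear_ms = 231.573   # mkl::_mkl_linear CPU total in the inductor table
+#    eager_e2e_ms, inductor_e2e_ms = 802.1, 339.95
+#
+#    print(eager_linear_ms / inductor_linear_ms)             # ~1.63x for the linear part
+#    print((eager_e2e_ms - eager_linear_ms)
+#          / (inductor_e2e_ms - inductor_linear_ms))         # ~3.9x for everything else
+#
+# Now, back to the generated code.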
You are able to find the generated code by +# searching ``cpp_fused__mkl_linear_add_mul_relu_151`` in ``output_code.py`` +# + + +cpp_fused__mkl_linear_add_mul_relu_151 = async_compile.cpp(''' #include -#include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h" +#include "/tmp/torchinductor_root/lr/clrlgu27q4ggd472umdzwsu6qcpqxcuusjxqvx2hwitjbujiiz7z.h" extern "C" void kernel(float* in_out_ptr0, - const float* in_ptr1, - float* out_ptr0, - float* out_ptr1) + const float* in_ptr0, + const float* in_ptr1, + const float* in_ptr2, + const float* in_ptr3) { - RECORD_FUNCTION("graph_0_cpp_fused__softmax_7", c10::ArrayRef({})); - auto in_ptr0 = in_out_ptr0; + RECORD_FUNCTION("graph_0_cpp_fused__mkl_linear_add_mul_relu_151", c10::ArrayRef({})); #pragma omp parallel num_threads(32) { - { - #pragma omp for collapse(2) - for(long i0=static_cast(0L); i0(4L); i0+=static_cast(1L)) - { - for(long i1=static_cast(0L); i1(8L); i1+=static_cast(1L)) - { - #pragma GCC ivdep - for(long i2=static_cast(0L); i2(1024L); i2+=static_cast(1L)) - { - { - float tmp_acc0 = -std::numeric_limits::infinity(); - for(long i3=static_cast(0L); i3(1024L); i3+=static_cast(1L)) - { - auto tmp0 = in_ptr0[static_cast(i3 + (1024L*i2) + (1048576L*i1) + (8388608L*i0))]; - auto tmp1 = static_cast(i3 + ((-1L)*i2)); - auto tmp2 = static_cast(0); - auto tmp3 = tmp1 > tmp2; - auto tmp4 = static_cast(tmp3); - auto tmp5 = static_cast(16); - auto tmp6 = decltype(tmp4)(tmp4 * tmp5); - auto tmp7 = tmp6 + tmp2; - auto tmp8 = std::abs(tmp1); - auto tmp9 = static_cast(8); - auto tmp10 = tmp8 < tmp9; - auto tmp11 = static_cast(tmp8); - auto tmp12 = static_cast(8.0); - auto tmp13 = tmp11 / tmp12; - auto tmp14 = std::log(tmp13); - auto tmp15 = static_cast(2.772588722239781); - auto tmp16 = tmp14 / tmp15; - auto tmp17 = decltype(tmp16)(tmp16 * tmp12); - auto tmp18 = static_cast(tmp17); - auto tmp19 = tmp18 + tmp9; - auto tmp20 = static_cast(15); - auto tmp21 = min_propagate_nan(tmp19, tmp20); - auto tmp22 = tmp10 ? 
tmp8 : tmp21; - auto tmp23 = tmp7 + tmp22; - auto tmp24 = in_ptr1[static_cast(i1 + (8L*tmp23))]; - auto tmp25 = static_cast(0.0); - auto tmp26 = tmp24 + tmp25; - auto tmp27 = tmp0 + tmp26; - tmp_acc0 = max_propagate_nan(tmp_acc0, tmp27); - } - out_ptr0[static_cast(i2 + (1024L*i1) + (8192L*i0))] = tmp_acc0; - } - } - } - } - } - { - #pragma omp for collapse(2) - for(long i0=static_cast(0L); i0(4L); i0+=static_cast(1L)) - { - for(long i1=static_cast(0L); i1(8L); i1+=static_cast(1L)) - { - #pragma GCC ivdep - for(long i2=static_cast(0L); i2(1024L); i2+=static_cast(1L)) - { - #pragma GCC ivdep - for(long i3=static_cast(0L); i3(1024L); i3+=static_cast(1L)) - { - auto tmp0 = in_out_ptr0[static_cast(i3 + (1024L*i2) + (1048576L*i1) + (8388608L*i0))]; - auto tmp28 = out_ptr0[static_cast(i2 + (1024L*i1) + (8192L*i0))]; - auto tmp1 = static_cast(i3 + ((-1L)*i2)); - auto tmp2 = static_cast(0); - auto tmp3 = tmp1 > tmp2; - auto tmp4 = static_cast(tmp3); - auto tmp5 = static_cast(16); - auto tmp6 = decltype(tmp4)(tmp4 * tmp5); - auto tmp7 = tmp6 + tmp2; - auto tmp8 = std::abs(tmp1); - auto tmp9 = static_cast(8); - auto tmp10 = tmp8 < tmp9; - auto tmp11 = static_cast(tmp8); - auto tmp12 = static_cast(8.0); - auto tmp13 = tmp11 / tmp12; - auto tmp14 = std::log(tmp13); - auto tmp15 = static_cast(2.772588722239781); - auto tmp16 = tmp14 / tmp15; - auto tmp17 = decltype(tmp16)(tmp16 * tmp12); - auto tmp18 = static_cast(tmp17); - auto tmp19 = tmp18 + tmp9; - auto tmp20 = static_cast(15); - auto tmp21 = min_propagate_nan(tmp19, tmp20); - auto tmp22 = tmp10 ? tmp8 : tmp21; - auto tmp23 = tmp7 + tmp22; - auto tmp24 = in_ptr1[static_cast(i1 + (8L*tmp23))]; - auto tmp25 = static_cast(0.0); - auto tmp26 = tmp24 + tmp25; - auto tmp27 = tmp0 + tmp26; - auto tmp29 = tmp27 - tmp28; - in_out_ptr0[static_cast(i3 + (1024L*i2) + (1048576L*i1) + (8388608L*i0))] = tmp29; - } - } - } - } - } { #pragma omp for - for(long i0=static_cast(0L); i0(33554432L); i0+=static_cast(16L)) - { - auto tmp0 = at::vec::Vectorized::loadu(in_out_ptr0 + static_cast(i0)); - auto tmp1 = tmp0.exp(); - tmp1.store(in_out_ptr0 + static_cast(i0)); - } - } - { - #pragma omp for - for(long i0=static_cast(0L); i0(32768L); i0+=static_cast(1L)) + for(long i0=static_cast(0L); i0(16384L); i0+=static_cast(1L)) { + for(long i1=static_cast(0L); i1(512L); i1+=static_cast(8L)) { - #pragma omp declare reduction(+:at::vec::Vectorized:omp_out += omp_in) initializer(omp_priv={{0}}) - float tmp_acc0 = 0; - auto tmp_acc0_vec = at::vec::Vectorized(tmp_acc0); - for(long i1=static_cast(0L); i1(1024L); i1+=static_cast(16L)) - { - auto tmp0 = at::vec::Vectorized::loadu(in_out_ptr0 + static_cast(i1 + (1024L*i0))); - tmp_acc0_vec += tmp0; - } - tmp_acc0 += at::vec::vec_reduce_all([](at::vec::Vectorized& x, at::vec::Vectorized&y) {return x + y;}, tmp_acc0_vec); - out_ptr1[static_cast(i0)] = tmp_acc0; + auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + static_cast(i1 + (512L*i0))); + auto tmp1 = at::vec::Vectorized::loadu(in_ptr1 + static_cast(i1)); + auto tmp3 = at::vec::Vectorized::loadu(in_out_ptr0 + static_cast(i1 + (512L*i0))); + auto tmp5 = at::vec::Vectorized::loadu(in_ptr2 + static_cast(i1)); + auto tmp7 = at::vec::Vectorized::loadu(in_ptr3 + static_cast(i1)); + auto tmp2 = tmp0 + tmp1; + auto tmp4 = tmp2 + tmp3; + auto tmp6 = tmp4 * tmp5; + auto tmp8 = tmp6 + tmp7; + tmp8.store(in_out_ptr0 + static_cast(i1 + (512L*i0))); } } } @@ -606,9 +506,57 @@ def trace_handler(p): ''') ###################################################################### -# With the kernel name 
``cpp_fused__softmax_*`` and considering the profile
-# results together, we may suspect the generated code for ``softmax`` is
-# inefficient. We encourage you to report an issue with all you findings above.
+# From the generated code above, we can see that this kernel performs a typical `Loop Fusion `_ over [add, add, mul, add].
+# We can infer the sizes and strides of the inputs and benchmark this [add, add, mul, add] pattern on its own.
+
+import torch
+def func(x0, x1, x3, x5, x7):
+    x2 = x0 + x1
+    x4 = x2 + x3
+    x6 = x4 * x5
+    x8 = x6 + x7
+    x3 = x8
+    return x3
+
+x0 = torch.rand(16384, 512)
+x1 = torch.rand(1, 512)
+x3 = torch.zeros(16384, 512)
+x5 = torch.rand(1, 512)
+x7 = torch.rand(1, 512)
+
+input = (x0, x1, x3, x5, x7)
+inductor_func = torch.compile(func)
+with torch.no_grad():
+    inductor_func(*input)
+
+import timeit
+NUM_ITERS=1000
+with torch.no_grad():
+    # warmup
+    for _ in range(10):
+        func(*input)
+    eager_t = timeit.timeit("func(*input)", number=NUM_ITERS, globals=globals())
+
+with torch.no_grad():
+    # warmup
+    for _ in range(10):
+        inductor_func(*input)
+    inductor_t = timeit.timeit("inductor_func(*input)", number=NUM_ITERS, globals=globals())
+print(f"eager use: {eager_t * 1000 / NUM_ITERS} ms/iter")
+print(f"inductor use: {inductor_t * 1000 / NUM_ITERS} ms/iter")
+print(f"speed up ratio: {eager_t / inductor_t}")
+######################################################################
+# Output:
+#
+# .. code-block:: shell
+#
+#     eager use: 5.780875144992024 ms/iter
+#     inductor use: 0.9588955780491233 ms/iter
+#     speed up ratio: 6.0286805751604735
+
+
+# This is just an example. The profiling table shows that, in this model, all element-wise ops are automatically fused by the inductor. You can read more of the generated kernels in
+# ``output_code.py``.
 ######################################################################
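+# (Optional) To double-check the fusion claim above, you can reuse the kernel-profiling switch that
+# was enabled earlier (``config.cpp.enable_kernel_profile = True``) and profile ``inductor_func``;
+# expect a single ``graph_*_cpp_fused_*`` entry rather than four separate aten ops. This is a sketch
+# only and is not part of the tutorial's measured results:
+#
+# .. code-block:: python
+#
+#    from torch.profiler import profile, ProfilerActivity
+#
+#    with torch.no_grad(), profile(activities=[ProfilerActivity.CPU]) as p:
+#        for _ in range(10):
+#            inductor_func(*input)
+#    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
+
+######################################################################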