From 810d236d21c8a1b56c75ecb58980ab465d2131c0 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Tue, 12 Oct 2021 01:30:23 +0100 Subject: [PATCH] [TIR][USMP] Add a parallel to serial for loop converter pass (#8469) * [TIR][USMP] Add a parallel to serial for loop converter pass This is an optional pass to convert all parallel for loops in TIR to serial ones for different reasons such as executor does not support parallel launch of for loops (e.g., AoT) or allocating space for parallel for loops might not be desired. * Additionally adding FFI scaffolding for USMP Change-Id: Id5e8ccb90140d2d3ae113b20a3ca152a54497c45 * [TIR][USMP] Add a parallel to serial for loop converter pass * remove unused import Change-Id: I29d5fdec92120418596f9dba1d6630f65620a603 * [TIR][USMP] Add a parallel to serial for loop converter pass *moved the pass to tir namespace Change-Id: I74720ca2f566066b3a4f22f504d8f0f684c99dc2 * [TIR][USMP] Add a parallel to serial for loop converter pass * fixed docstring Change-Id: I73bb9867fe2ed6a86f65666493c5c6e3edf87b49 * [TIR][USMP] Add a parallel to serial for loop converter pass * fixed mypy lint error Change-Id: I226ef27d5536674fbe4b2d2c6ff47b8cb3b41431 --- include/tvm/tir/transform.h | 9 +++ python/tvm/tir/transform/transform.py | 11 +++ .../transforms/convert_for_loops_serial.cc | 75 +++++++++++++++++++ ..._tir_transform_convert_for_loops_serial.py | 62 +++++++++++++++ 4 files changed, 157 insertions(+) create mode 100644 src/tir/transforms/convert_for_loops_serial.cc create mode 100644 tests/python/unittest/test_tir_transform_convert_for_loops_serial.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e94b966bc0fc..017078bd7bf7 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -463,6 +463,15 @@ TVM_DLL Pass UnifyThreadBinding(); */ TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); +/*! + * \brief This pass is post-scheduling pass to convert all + * Parallel For loops to Serial ones. This is run + * to attain lesser memory and/or executor/backend + * does not support parallel launch of For loops. + * \return The pass. + */ +TVM_DLL Pass ConvertForLoopsToSerial(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f072f6b38a43..1abba77a801f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -715,3 +715,14 @@ def MergeDynamicSharedMemoryAllocations(): The result pass """ return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore + + +def ConvertForLoopsToSerial(): + """Convert Parallel For Loops to Serial For Loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ConvertForLoopsToSerial() # type: ignore diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc new file mode 100644 index 000000000000..d01ae8a45113 --- /dev/null +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/transforms/convert_for_loops_serial.cc + * \brief Convert all for loops to serial for lesser memory consumption + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class ForLoopSerialConverter : public StmtExprMutator { + public: + ForLoopSerialConverter() = default; + Stmt operator()(const PrimFunc& func); + + private: + Stmt VisitStmt_(const ForNode* op) override; +}; + +Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) { + if (op->kind == ForKind::kParallel) { + return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding, + op->annotations, op->span); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) { + return this->VisitStmt(func->body); +} + +PrimFunc ConvertForLoopsToSerial(PrimFunc func) { + PrimFuncNode* fptr = func.CopyOnWrite(); + fptr->body = ForLoopSerialConverter()(func); + return func; +} + +namespace transform { + +Pass ConvertForLoopsToSerial() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ConvertForLoopsToSerial(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") + .set_body_typed(ConvertForLoopsToSerial); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py new file mode 100644 index 000000000000..272e0d45410f --- /dev/null +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir, script +from tvm.script import ty +from tvm.tir import stmt_functor + +# fmt: off +@tvm.script.tir +def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global") + for i0_i1_fused_3 in tir.parallel(0, 28): + for i2_3, i3_3 in tir.grid(28, 192): + tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784): + for ax3_2 in tir.serial(0, 16): + Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global") + tir.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in tir.serial(0, 192): + tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) + tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) +# fmt: on + + +def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(): + primfunc = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2 + mod = tvm.IRModule.from_expr(primfunc) + mod = tvm.tir.transform.ConvertForLoopsToSerial()(mod) + + def verify_serial_loops(stmt): + if isinstance(stmt, tvm.tir.For): + assert stmt.kind == tvm.tir.ForKind.SERIAL + + for _, primfunc in mod.functions.items(): + stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) + + +if __name__ == "__main__": + pytest.main([__file__])