diff --git a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp index 0bc9e283926..3d13694eaa5 100644 --- a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp +++ b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp @@ -77,6 +77,7 @@ const std::unordered_set &GetTorchMlirWhiteList() { "aten::gelu", "aten::gelu_backward", "aten::glu", + "aten::group_norm", "aten::hardsigmoid", "aten::hardswish", "aten::hardtanh", diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp index 9a8e9377034..8cecfcfd95e 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp @@ -102,6 +102,229 @@ bool isOpTriviallyDeadDisc(Operation* op) { // add pre-defined pdll patterns here. 
std::string getTorchPredefinedPDLPatterns() { std::string preDefinedPatterns; +#if defined(PLATFORM_ALIBABA) and defined(ENABLE_BLADE_GEMM) /* Group-norm patterns, only compiled in when both macros are defined: each pattern emits a custom call whose call_target_name is "ral_pdll_group_norm"; the F32 variants wrap that call in ConvertToF16 / ConvertToF32. */ + preDefinedPatterns += R"pdll( + Rewrite ConvertToF16(value: Value) -> Value { + let f16_dtype = op {value = attr<"5">} -> (type<"!torch.int">); + let old_type = GetTorchTensorType(value); + let new_type = ConvertTorchTensorElemType(old_type, attr<"\"f16\"">); + let false_val = op {value = attr<"false">} -> (type<"!torch.bool">); + let none_val = op -> (type<"!torch.none">); + let f16_value = op( + value, f16_dtype, false_val, false_val, none_val + ) -> (new_type); + + return f16_value.0; + } + + Rewrite ConvertToF32(value: Value) -> Value { + let f32_dtype = op {value = attr<"6">} -> (type<"!torch.int">); + let old_type = GetTorchTensorType(value); + let new_type = ConvertTorchTensorElemType(old_type, attr<"\"f32\"">); + let false_val = op {value = attr<"false">} -> (type<"!torch.bool">); + let none_val = op -> (type<"!torch.none">); + let f32_value = op( + value, f32_dtype, false_val, false_val, none_val + ) -> (new_type); + + return f32_value.0; + } + + Pattern TorchGroupNormOpF32 { + /// match phase: define the pattern + let eps_attr : Attr; + let eps = op { value = eps_attr }; + let gn = op( + input: Value, + num_group: Value, + weight: Value, + bias: Value, + eps.0, + cudnn_enabled: Value + ) -> (old_type: Type); + CheckNotTorchNone(bias); + CheckTorchConstantInt(num_group); + CheckTorchTensorElemType(input, attr<"\"f32\"">); + + /// rewrite phase + rewrite gn with { + let f16_input = ConvertToF16(input); + let f16_weight = ConvertToF16(weight); + let f16_bias = ConvertToF16(bias); + let f16_output = ConvertToF16(gn.0); + + /// 1. create custom call op + let inputs = PackValue_3(attr<"\"in\"">, f16_input, f16_weight, f16_bias); + let outputs = PackValue_1(attr<"\"out\"">, f16_output); + let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs); + + /// 2. set attrs that are used by bladedisc. 
+ SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">); + SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">); + SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">); + SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">); + SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">); + SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">); + + /// 3. set attrs that are directly passed to the custom call kernel. + let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group); + SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr); + SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr); + SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"false">); + + let rs = UnpackValue_1(infos.new_outputs); + let new_output = ConvertToF32(rs); + + replace gn with new_output; + }; + } + + Pattern TorchGroupNormWithSiluOpF32 { + /// match phase: define the pattern + let eps_attr : Attr; + let eps = op { value = eps_attr }; + let gn = op( + input: Value, + num_group: Value, + weight: Value, + bias: Value, + eps.0, + cudnn_enabled: Value + ) -> (old_type: Type); + let silu = op(gn.0); + CheckNotTorchNone(bias); + CheckTorchConstantInt(num_group); + CheckTorchTensorElemType(input, attr<"\"f32\"">); + + /// rewrite phase + rewrite silu with { + let f16_input = ConvertToF16(input); + let f16_weight = ConvertToF16(weight); + let f16_bias = ConvertToF16(bias); + let f16_output = ConvertToF16(silu.0); + + /// 1. create custom call op + let inputs = PackValue_3(attr<"\"in\"">, f16_input, f16_weight, f16_bias); + let outputs = PackValue_1(attr<"\"out\"">, f16_output); + let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs); + + /// 2. set attrs that are used by bladedisc. 
+ SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">); + SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">); + SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">); + SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">); + SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">); + SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">); + + /// 3. set attrs that are directly passed to the custom call kernel. + let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group); + SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr); + SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr); + SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"true">); + + let rs = UnpackValue_1(infos.new_outputs); + let new_output = ConvertToF32(rs); + + replace silu with new_output; + }; + } + + Pattern TorchGroupNormOpF16 { + /// match phase: define the pattern + let eps_attr : Attr; + let eps = op { value = eps_attr }; + let gn = op( + input: Value, + num_group: Value, + weight: Value, + bias: Value, + eps.0, + cudnn_enabled: Value + ) -> (old_type: Type); + CheckNotTorchNone(bias); + CheckTorchConstantInt(num_group); + CheckTorchTensorElemType(input, attr<"\"f16\"">); + + /// rewrite phase + rewrite gn with { + /// 1. create custom call op + let inputs = PackValue_3(attr<"\"in\"">, input, weight, bias); + let outputs = PackValue_1(attr<"\"out\"">, gn.0); + let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs); + + /// 2. set attrs that are used by bladedisc. 
+ SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">); + SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">); + SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">); + SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">); + SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">); + SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">); + + /// 3. set attrs that are directly passed to the custom call kernel. + let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group); + SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr); + SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr); + SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"false">); + + let rs = UnpackValue_1(infos.new_outputs); + replace gn with rs; + }; + } + + Pattern TorchGroupNormWithSiluOpF16 { + /// match phase: define the pattern + let eps_attr : Attr; + let eps = op { value = eps_attr }; + let gn = op( + input: Value, + num_group: Value, + weight: Value, + bias: Value, + eps.0, + cudnn_enabled: Value + ) -> (old_type: Type); + let silu = op(gn.0); + CheckNotTorchNone(bias); + CheckTorchConstantInt(num_group); + CheckTorchTensorElemType(input, attr<"\"f16\"">); + + /// rewrite phase + rewrite silu with { + /// 1. create custom call op + let inputs = PackValue_3(attr<"\"in\"">, input, weight, bias); + let outputs = PackValue_1(attr<"\"out\"">, silu.0); + let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs); + + /// 2. set attrs that are used by bladedisc. 
+ SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">); + SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">); + SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">); + SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">); + SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">); + SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">); + SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">); + + /// 3. set attrs that are directly passed to the custom call kernel. + let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group); + SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr); + SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr); + SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"true">); + + let rs = UnpackValue_1(infos.new_outputs); + replace silu with rs; + }; + } + + )pdll"; +#endif /* PLATFORM_ALIBABA and ENABLE_BLADE_GEMM */ + return preDefinedPatterns; } @@ -157,13 +380,19 @@ void ApplyDiscPdlPatternsPass::runOnOperation() { MLIRContext* context = &getContext(); RewritePatternSet patterns(context); + auto pdll_include_dirs = mlir::disc_ral::ParseFileString(pdll_include_dirs_); /* parsed once and passed to both populate calls below */ + (void)mlir::disc_ral::populateDiscPdlPatternsFromString( - &patterns, getTorchPredefinedPDLPatterns()); + &patterns, + getTorchPredefinedPDLPatterns(), + pdll_include_dirs, + torch::kDefaultHelperFunctionDeclarations, + torch::registerPredefinedHelperFunctions); (void)mlir::disc_ral::populateDiscPdlPatternsFromFiles( &patterns, mlir::disc_ral::ParseFileString(pdll_files_), - mlir::disc_ral::ParseFileString(pdll_include_dirs_), + pdll_include_dirs, torch::kDefaultHelperFunctionDeclarations, torch::registerPredefinedHelperFunctions); diff --git a/pytorch_blade/tests/disc/ops/test_group_norm.py b/pytorch_blade/tests/disc/ops/test_group_norm.py new file mode 100644 index 00000000000..a817f28167d --- /dev/null +++ 
b/pytorch_blade/tests/disc/ops/test_group_norm.py
@@ -0,0 +1,37 @@
+# Copyright 2021 The BladeDISC Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import unittest
+
+from torch_blade.version import cuda_available
+from tests.disc.testing_base import DiscTestCase
+
+class TestDiscGroupNorm(DiscTestCase):
+    """Runs torch.nn.GroupNorm modules through DiscTestCase._test_disc."""
+
+    def _test_group_norm(self, groupnorm):
+        """Feed one random [2, 320, 64, 64] tensor to `groupnorm` via _test_disc."""
+        shape = [2, 320, 64, 64]
+        sample = torch.randn(shape, device=self.device)
+        # Every dimension is annotated as -1; the element type is torch.float.
+        self._test_disc(groupnorm, [([-1] * len(shape), torch.float)], (sample,))
+
+    def test_groupnorm_module(self):
+        """GroupNorm(32, 320) constructed with affine=False."""
+        self._test_group_norm(torch.nn.GroupNorm(32, 320, affine=False))
+
+    def test_groupnorm_module_has_affine(self):
+        """GroupNorm(32, 320) constructed with affine=True."""
+        self._test_group_norm(torch.nn.GroupNorm(32, 320, affine=True))
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/pytorch_blade/tests/torch-disc-pdll/utils.cpp b/pytorch_blade/tests/torch-disc-pdll/utils.cpp
index ae6e71e6c5d..a0543c5012f 100644
--- a/pytorch_blade/tests/torch-disc-pdll/utils.cpp
+++ b/pytorch_blade/tests/torch-disc-pdll/utils.cpp
@@ -34,6 +34,7 @@ const std::string kDefaultHelperFunctionDeclarations = R"pdll(
   Constraint CheckTorchConstantBoolFalse(v : Value);
   Constraint CheckTorchConstantIntList(v : Value);
   Constraint CheckTorchValueTensorLiteral(v : Value);
+  Constraint CheckTorchTensorElemType(v: Value, type_str: Attr);
 
   Rewrite CreateTorchCustomCall(tag : Attr, inputs : ValueRange, outputs :
ValueRange) -> (op: Op, new_outputs : ValueRange);
   Rewrite ConvertTorchConstantIntListToI64DenseElemsAttr(cst: Value) -> Attr;
@@ -150,6 +151,48 @@ static LogicalResult checkTorchValueTensorLiteral(
   return success();
 }
 
+// PDL constraint: succeeds when the tensor value in values[0] has the element
+// type named by the string in values[1] (one of the keys of typeconvert_dict
+// below, e.g. "f16" or "f32"); fails otherwise.
+static LogicalResult checkTorchTensorElemType(
+    PatternRewriter& rewriter,
+    ArrayRef values) {
+  assert(values.size() == 2);
+
+  auto v = values[0].cast();
+  auto tensorTy = v.getType().dyn_cast();
+  if (!tensorTy)
+    return failure();
+
+  auto type_str =
+      values[1].cast().cast().getValue().str();
+
+  std::unordered_map typeconvert_dict = {
+      {"i1", rewriter.getI1Type()},
+      {"ui8",
+       IntegerType::get(rewriter.getContext(), 8, IntegerType::Unsigned)},
+      {"i8", IntegerType::get(rewriter.getContext(), 8, IntegerType::Signed)},
+      {"i32", IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed)},
+      {"ui32",
+       IntegerType::get(rewriter.getContext(), 32, IntegerType::Unsigned)},
+      {"i64", IntegerType::get(rewriter.getContext(), 64, IntegerType::Signed)},
+      {"ui64",
+       IntegerType::get(rewriter.getContext(), 64, IntegerType::Unsigned)},
+      {"f16", rewriter.getF16Type()},
+      {"bf16", rewriter.getBF16Type()},
+      {"f32", rewriter.getF32Type()}};
+
+  // An unknown name is a bug in the pattern, so keep the debug assert, but
+  // never fall through to operator[], which would insert a default entry
+  // once the assert is compiled out (NDEBUG).
+  auto it = typeconvert_dict.find(type_str);
+  assert(it != typeconvert_dict.end());
+  if (it == typeconvert_dict.end())
+    return failure();
+
+  return (tensorTy.getDtype() == it->second) ? success() : failure();
+}
+
 static void getTorchTensorType(
     PatternRewriter& rewriter,
     PDLResultList& results,
@@ -305,6 +348,8 @@ void registerPredefinedHelperFunctions(PDLPatternModule& pdlPatterns) {
       "CheckTorchConstantIntList", checkTorchConstantIntList);
   pdlPatterns.registerConstraintFunction(
       "CheckTorchValueTensorLiteral", checkTorchValueTensorLiteral);
+  pdlPatterns.registerConstraintFunction(
+      "CheckTorchTensorElemType", checkTorchTensorElemType);
 }
 
 } // namespace torch
diff --git a/pytorch_blade/third_party/torch-mlir b/pytorch_blade/third_party/torch-mlir
index a9ad6d266f2..591fe9bc304 160000
--- a/pytorch_blade/third_party/torch-mlir
+++ b/pytorch_blade/third_party/torch-mlir
@@ -1 +1 @@
-Subproject commit a9ad6d266f2bf24e319e5860eb93407370892621
+Subproject commit 591fe9bc304eeac0e32255363cb8e7a912391dca
diff --git a/tao_compiler/mlir/xla/ral/context/stream_executor_based_impl.cc b/tao_compiler/mlir/xla/ral/context/stream_executor_based_impl.cc
index 49caff24735..548982b22d2 100644
--- a/tao_compiler/mlir/xla/ral/context/stream_executor_based_impl.cc
+++ b/tao_compiler/mlir/xla/ral/context/stream_executor_based_impl.cc
@@ -1942,6 +1942,51 @@ MemRefType ral_conv_biasadd(ExecutionContext* ctx,
 ///////////////
 ////////////////////////////////////////////////////////////////////////
 
+namespace groupnorm_impl {
+
+#if defined(PLATFORM_ALIBABA) and defined(ENABLE_BLADE_GEMM)
+
+// Group normalization backed by bladnn::groupnorm. Allocates an output buffer
+// with the same sizes as `input` and reads "num_group", "eps" and "silu" from
+// `customAttrs`. Failures are reported through ctx->signalError.
+// Pre-requirement: input has layout NHWC
+template
+MemRefType bladnn_groupnorm(ExecutionContext* ctx, void* stream_handle,
+                            MemRefType input,
+                            MemRefType weight,
+                            MemRefType bias, void* customAttrs) {
+  size_t nElems = Size(input);
+  auto driver = ctx->getDriver(gpu::GPUDriver::name());
+  TAO_CHECK(driver);
+  auto ptr = static_cast(driver->alloc(ctx, nElems * sizeof(T)));
+  auto output = assignMemRef(ptr, input.sizes);
+
+  auto attr = getOrParsePDLAttr(ctx, customAttrs, "ral_groupnorm");
+  if (!attr) {
+    ctx->signalError(Context::FAILURE, "fail to parse custom_attrs\n");
+    // Bail out here: `attr` is dereferenced right below.
+    return output;
+  }
+  auto& dictAttr = attr->as();
+  auto num_group =
+      dictAttr.get("num_group").template as().getValue();
+  auto eps = dictAttr.get("eps").template as().getValue();
+  auto use_silu = dictAttr.get("silu").template as().getValue();
+  auto stream = driver->asCUStream(ctx, stream_handle);
+  bool ret =
+      bladnn::groupnorm(output.data, input.data, weight.data, bias.data,
+                        input.sizes[0], input.sizes[1], input.sizes[2],
+                        input.sizes[3], num_group, use_silu, eps, stream);
+  if (!ret) {
+    ctx->signalError(Context::FAILURE, "fail to call bladnn::groupnorm\n");
+  }
+  return output;
+}
+
+#endif
+
+} // namespace groupnorm_impl
+
 } // namespace se_impl
 } // namespace gpu
 
@@ -1983,6 +2028,8 @@ TAO_RAL_API("ral_pdll_conv_bias", "gpu",
             gpu::se_impl::gpu_conv_impl::ral_conv_biasadd);
 TAO_RAL_API("ral_pdll_conv_bias", "gpu",
             gpu::se_impl::gpu_conv_impl::ral_conv_biasadd);
+TAO_RAL_API("ral_pdll_group_norm", "gpu",
+            gpu::se_impl::groupnorm_impl::bladnn_groupnorm);
 #endif
 
 // compute-intensive fusion