
support group norm lowering (#874)
* support group norm lowering

* [pdl] groupnorm support

Co-authored-by: Wenyi Zhao <kevin.zwy@alibaba-inc.com>
zzpmiracle and wyzero committed Dec 26, 2022
1 parent 71bb36e commit f8107f2
Showing 6 changed files with 345 additions and 3 deletions.
@@ -77,6 +77,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::gelu",
"aten::gelu_backward",
"aten::glu",
"aten::group_norm",
"aten::hardsigmoid",
"aten::hardswish",
"aten::hardtanh",
@@ -102,6 +102,229 @@ bool isOpTriviallyDeadDisc(Operation* op) {
// add pre-defined pdll patterns here.
std::string getTorchPredefinedPDLPatterns() {
std::string preDefinedPatterns;
#if defined(PLATFORM_ALIBABA) and defined(ENABLE_BLADE_GEMM)
preDefinedPatterns += R"pdll(
Rewrite ConvertToF16(value: Value) -> Value {
let f16_dtype = op<torch.constant.int> {value = attr<"5">} -> (type<"!torch.int">);
let old_type = GetTorchTensorType(value);
let new_type = ConvertTorchTensorElemType(old_type, attr<"\"f16\"">);
let false_val = op<torch.constant.bool> {value = attr<"false">} -> (type<"!torch.bool">);
let none_val = op<torch.constant.none> -> (type<"!torch.none">);
let f16_value = op<torch.aten.to.dtype>(
value, f16_dtype, false_val, false_val, none_val
) -> (new_type);
return f16_value.0;
}
Rewrite ConvertToF32(value: Value) -> Value {
let f32_dtype = op<torch.constant.int> {value = attr<"6">} -> (type<"!torch.int">);
let old_type = GetTorchTensorType(value);
let new_type = ConvertTorchTensorElemType(old_type, attr<"\"f32\"">);
let false_val = op<torch.constant.bool> {value = attr<"false">} -> (type<"!torch.bool">);
let none_val = op<torch.constant.none> -> (type<"!torch.none">);
let f32_value = op<torch.aten.to.dtype>(
value, f32_dtype, false_val, false_val, none_val
) -> (new_type);
return f32_value.0;
}
Pattern TorchGroupNormOpF32 {
/// match phase: define the pattern
let eps_attr : Attr;
let eps = op<torch.constant.float> { value = eps_attr };
let gn = op<torch.aten.group_norm>(
input: Value,
num_group: Value,
weight: Value,
bias: Value,
eps.0,
cudnn_enabled: Value
) -> (old_type: Type);
CheckNotTorchNone(bias);
CheckTorchConstantInt(num_group);
CheckTorchTensorElemType(input, attr<"\"f32\"">);
/// rewrite phase
rewrite gn with {
let f16_input = ConvertToF16(input);
let f16_weight = ConvertToF16(weight);
let f16_bias = ConvertToF16(bias);
let f16_output = ConvertToF16(gn.0);
/// 1. create custom call op
let inputs = PackValue_3(attr<"\"in\"">, f16_input, f16_weight, f16_bias);
let outputs = PackValue_1(attr<"\"out\"">, f16_output);
let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs);
/// 2. set attrs that are used by bladedisc.
SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">);
SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">);
SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">);
SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">);
SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">);
SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">);
/// 3. set attrs that are directly passed to the custom call kernel.
let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group);
SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr);
SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr);
SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"false">);
let rs = UnpackValue_1(infos.new_outputs);
let new_output = ConvertToF32(rs);
replace gn with new_output;
};
}
Pattern TorchGroupNormWithSiluOpF32 {
/// match phase: define the pattern
let eps_attr : Attr;
let eps = op<torch.constant.float> { value = eps_attr };
let gn = op<torch.aten.group_norm>(
input: Value,
num_group: Value,
weight: Value,
bias: Value,
eps.0,
cudnn_enabled: Value
) -> (old_type: Type);
let silu = op<torch.aten.silu>(gn.0);
CheckNotTorchNone(bias);
CheckTorchConstantInt(num_group);
CheckTorchTensorElemType(input, attr<"\"f32\"">);
/// rewrite phase
rewrite silu with {
let f16_input = ConvertToF16(input);
let f16_weight = ConvertToF16(weight);
let f16_bias = ConvertToF16(bias);
let f16_output = ConvertToF16(silu.0);
/// 1. create custom call op
let inputs = PackValue_3(attr<"\"in\"">, f16_input, f16_weight, f16_bias);
let outputs = PackValue_1(attr<"\"out\"">, f16_output);
let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs);
/// 2. set attrs that are used by bladedisc.
SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">);
SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">);
SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">);
SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">);
SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">);
SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">);
/// 3. set attrs that are directly passed to the custom call kernel.
let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group);
SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr);
SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr);
SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"true">);
let rs = UnpackValue_1(infos.new_outputs);
let new_output = ConvertToF32(rs);
replace silu with new_output;
};
}
Pattern TorchGroupNormOpF16 {
/// match phase: define the pattern
let eps_attr : Attr;
let eps = op<torch.constant.float> { value = eps_attr };
let gn = op<torch.aten.group_norm>(
input: Value,
num_group: Value,
weight: Value,
bias: Value,
eps.0,
cudnn_enabled: Value
) -> (old_type: Type);
CheckNotTorchNone(bias);
CheckTorchConstantInt(num_group);
CheckTorchTensorElemType(input, attr<"\"f16\"">);
/// rewrite phase
rewrite gn with {
/// 1. create custom call op
let inputs = PackValue_3(attr<"\"in\"">, input, weight, bias);
let outputs = PackValue_1(attr<"\"out\"">, gn.0);
let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs);
/// 2. set attrs that are used by bladedisc.
SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">);
SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">);
SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">);
SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">);
SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">);
SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">);
/// 3. set attrs that are directly passed to the custom call kernel.
let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group);
SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr);
SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr);
SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"false">);
let rs = UnpackValue_1(infos.new_outputs);
replace gn with rs;
};
}
Pattern TorchGroupNormWithSiluOpF16 {
/// match phase: define the pattern
let eps_attr : Attr;
let eps = op<torch.constant.float> { value = eps_attr };
let gn = op<torch.aten.group_norm>(
input: Value,
num_group: Value,
weight: Value,
bias: Value,
eps.0,
cudnn_enabled: Value
) -> (old_type: Type);
let silu = op<torch.aten.silu>(gn.0);
CheckNotTorchNone(bias);
CheckTorchConstantInt(num_group);
CheckTorchTensorElemType(input, attr<"\"f16\"">);
/// rewrite phase
rewrite silu with {
/// 1. create custom call op
let inputs = PackValue_3(attr<"\"in\"">, input, weight, bias);
let outputs = PackValue_1(attr<"\"out\"">, silu.0);
let infos = CreateTorchCustomCall(attr<"\"op\"">, inputs, outputs);
/// 2. set attrs that are used by bladedisc.
SetAttr(infos.op, attr<"\"call_target_name\"">, attr<"\"ral_pdll_group_norm\"">);
SetAttr(infos.op, attr<"\"input_placements\"">, attr<"\"x,x,x\"">);
SetAttr(infos.op, attr<"\"output_placements\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"device\"">, attr<"\"x\"">);
SetAttr(infos.op, attr<"\"input_layouts\"">, attr<"\"NCHW,*,*\"">);
SetAttr(infos.op, attr<"\"output_layouts\"">, attr<"\"NCHW\"">);
SetAttr(infos.op, attr<"\"expected_input_layouts\"">, attr<"\"NHWC,*,*\"">);
SetAttr(infos.op, attr<"\"expected_output_layouts\"">, attr<"\"NHWC\"">);
/// 3. set attrs that are directly passed to the custom call kernel.
let num_group_attr = ConvertTorchConstantIntToI64Attr(num_group);
SetCustomAttr(infos.op, attr<"\"num_group\"">, num_group_attr);
SetCustomAttr(infos.op, attr<"\"eps\"">, eps_attr);
SetCustomAttr(infos.op, attr<"\"silu\"">, attr<"true">);
let rs = UnpackValue_1(infos.new_outputs);
replace silu with rs;
};
}
)pdll";
#endif

return preDefinedPatterns;
}
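
For reference, a minimal PyTorch sketch (assumed semantics, not the actual RAL kernel) of what the ral_pdll_group_norm custom call emitted by these patterns is expected to compute: a standard group norm, optionally fused with SiLU when the silu custom attribute is true. Note that the f32 patterns cast the operands to f16 before the call and cast the result back to f32, so a single half-precision kernel serves both element types.

import torch
import torch.nn.functional as F

def ral_group_norm_reference(x, weight, bias, num_groups, eps, silu=False):
    # Numerical reference only: the real kernel is expected to consume f16
    # NHWC buffers (see the expected_input_layouts/expected_output_layouts
    # attrs above), while the surrounding PDLL pattern inserts the dtype casts.
    y = F.group_norm(x, num_groups, weight, bias, eps)
    return F.silu(y) if silu else y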

@@ -157,13 +380,19 @@ void ApplyDiscPdlPatternsPass::runOnOperation() {
MLIRContext* context = &getContext();
RewritePatternSet patterns(context);

+ auto pdll_include_dirs = mlir::disc_ral::ParseFileString(pdll_include_dirs_);
+
  (void)mlir::disc_ral::populateDiscPdlPatternsFromString(
-     &patterns, getTorchPredefinedPDLPatterns());
+     &patterns,
+     getTorchPredefinedPDLPatterns(),
+     pdll_include_dirs,
+     torch::kDefaultHelperFunctionDeclarations,
+     torch::registerPredefinedHelperFunctions);

  (void)mlir::disc_ral::populateDiscPdlPatternsFromFiles(
      &patterns,
      mlir::disc_ral::ParseFileString(pdll_files_),
-     mlir::disc_ral::ParseFileString(pdll_include_dirs_),
+     pdll_include_dirs,
      torch::kDefaultHelperFunctionDeclarations,
      torch::registerPredefinedHelperFunctions);

33 changes: 33 additions & 0 deletions pytorch_blade/tests/disc/ops/test_group_norm.py
@@ -0,0 +1,33 @@
# Copyright 2021 The BladeDISC Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import unittest

from torch_blade.version import cuda_available
from tests.disc.testing_base import DiscTestCase

class TestDiscGroupNorm(DiscTestCase):
def _test_group_norm(self, groupnorm):
test_data = torch.randn([2, 320, 64, 64], device=self.device)
annotation = ([-1, -1, -1, -1], torch.float)
self._test_disc(groupnorm, [annotation], (test_data,))

def test_groupnorm_module(self):
groupnorm = torch.nn.GroupNorm(32, 320, affine=False)
self._test_group_norm(groupnorm)

def test_groupnorm_module_has_affine(self):
groupnorm = torch.nn.GroupNorm(32, 320, affine=True)
self._test_group_norm(groupnorm)

if __name__ == "__main__":
unittest.main()
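
A natural follow-up, not included in this commit, would be a test that exercises the fused path matched by the TorchGroupNormWithSilu* patterns; a hedged sketch reusing the same harness (module and test names are hypothetical):

class GroupNormSilu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = torch.nn.GroupNorm(32, 320, affine=True)

    def forward(self, x):
        # GroupNorm followed by SiLU, the shape matched by TorchGroupNormWithSiluOp*.
        return torch.nn.functional.silu(self.gn(x))

class TestDiscGroupNormSilu(DiscTestCase):
    def test_groupnorm_silu(self):
        test_data = torch.randn([2, 320, 64, 64], device=self.device)
        annotation = ([-1, -1, -1, -1], torch.float)
        self._test_disc(GroupNormSilu(), [annotation], (test_data,))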
37 changes: 37 additions & 0 deletions pytorch_blade/tests/torch-disc-pdll/utils.cpp
@@ -34,6 +34,7 @@ const std::string kDefaultHelperFunctionDeclarations = R"pdll(
Constraint CheckTorchConstantBoolFalse(v : Value);
Constraint CheckTorchConstantIntList(v : Value);
Constraint CheckTorchValueTensorLiteral(v : Value);
Constraint CheckTorchTensorElemType(v: Value, type_str: Attr);
Rewrite CreateTorchCustomCall(tag : Attr, inputs : ValueRange, outputs : ValueRange) -> (op: Op, new_outputs : ValueRange);
Rewrite ConvertTorchConstantIntListToI64DenseElemsAttr(cst: Value) -> Attr;
@@ -150,6 +151,40 @@ static LogicalResult checkTorchValueTensorLiteral(
return success();
}

static LogicalResult checkTorchTensorElemType(
PatternRewriter& rewriter,
ArrayRef<PDLValue> values) {
assert(values.size() == 2);

auto v = values[0].cast<Value>();
auto tensorTy = v.getType().dyn_cast<Torch::ValueTensorType>();
if (!tensorTy)
return failure();

auto type_str =
values[1].cast<Attribute>().cast<StringAttr>().getValue().str();

std::unordered_map<std::string, Type> typeconvert_dict = {
{"i1", rewriter.getI1Type()},
{"ui8",
IntegerType::get(rewriter.getContext(), 8, IntegerType::Unsigned)},
{"i8", IntegerType::get(rewriter.getContext(), 8, IntegerType::Signed)},
{"i32", IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed)},
{"ui32",
IntegerType::get(rewriter.getContext(), 32, IntegerType::Unsigned)},
{"i64", IntegerType::get(rewriter.getContext(), 64, IntegerType::Signed)},
{"ui64",
IntegerType::get(rewriter.getContext(), 64, IntegerType::Unsigned)},
{"f16", rewriter.getF16Type()},
{"bf16", rewriter.getBF16Type()},
{"f32", rewriter.getF32Type()}};

assert(typeconvert_dict.find(type_str) != typeconvert_dict.end());

return (tensorTy.getDtype() == typeconvert_dict[type_str]) ? success()
: failure();
}

static void getTorchTensorType(
PatternRewriter& rewriter,
PDLResultList& results,
@@ -305,6 +340,8 @@ void registerPredefinedHelperFunctions(PDLPatternModule& pdlPatterns) {
"CheckTorchConstantIntList", checkTorchConstantIntList);
pdlPatterns.registerConstraintFunction(
"CheckTorchValueTensorLiteral", checkTorchValueTensorLiteral);
pdlPatterns.registerConstraintFunction(
"CheckTorchTensorElemType", checkTorchTensorElemType);
}

} // namespace torch