diff --git a/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp b/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
index dc241424d88..25aef663fac 100644
--- a/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
+++ b/pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
@@ -655,6 +655,41 @@ class ShapePropagator : public PropertyPropBase {
         }
         return;
       }
+      case prim::ListUnpack: {
+        auto input_node = node->input()->node();
+        if (input_node->matches(
+                "aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[]") ||
+            input_node->matches(
+                "aten::chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[]") ||
+            input_node->matches(
+                "aten::unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[]")) {
+          if (auto self_type =
+                  input_node->input(0)->type()->cast<TensorType>()) {
+            auto sizes_opt = self_type->symbolic_sizes().sizes();
+            auto dim_opt = input_node->get<int64_t>(attr::dim);
+            if (!(sizes_opt && dim_opt))
+              return;
+
+            std::vector<ShapeSymbol> new_sizes = sizes_opt.value();
+            int64_t input_rank = new_sizes.size();
+            int64_t dim =
+                at::maybe_wrap_dim(dim_opt.value(), input_rank, false);
+            if (input_node->matches(
+                    "aten::unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[]")) {
+              new_sizes.erase(new_sizes.begin() + dim);
+            } else {
+              // set default to dynamic
+              new_sizes[dim] = ShapeSymbol::newSymbol();
+            }
+
+            for (size_t i = 0; i < node->outputs().size(); ++i) {
+              if (auto type = node->output(i)->type()->cast<TensorType>())
+                node->output(i)->setType(type->withSymbolicShapes(new_sizes));
+            }
+          }
+        }
+        return;
+      }
       case prim::Constant: {
         if (node->output()->type()->isSubtypeOf(TensorType::get())) {
           node->output()->inferTypeFrom(node->t(attr::value));
diff --git a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
index 85a3dda0f4f..0bc9e283926 100644
--- a/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
+++ b/pytorch_blade/pytorch_blade/compiler/mlir/converters/torch_mlir_op_filter.cpp
@@ -51,6 +51,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
       "aten::bitwise_not",
       "aten::bmm",
       "aten::cat",
+      "aten::chunk",
       "aten::contiguous",
       "aten::_convolution",
       "aten::convolution",
@@ -113,6 +114,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
       "aten::size",
       "aten::slice",
       "aten::softmax",
+      "aten::split",
       "aten::std",
       "aten::squeeze",
       "aten::sub",
diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp
index 839d75002e6..7df778bab29 100644
--- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp
+++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp
@@ -259,6 +259,27 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
 };
 } // namespace
 
+namespace {
+template <typename AtenOpT, typename ArithOpT>
+class ConvertAtenArithOp : public OpConversionPattern<AtenOpT> {
+ public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult matchAndRewrite(
+      AtenOpT op,
+      OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const override {
+    rewriter.replaceOpWithNewOp<ArithOpT>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        adaptor.a(),
+        adaptor.b());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename AtenOpT>
 class ConvertAtenExtractOp : public OpConversionPattern<AtenOpT> {
@@ -1261,6 +1282,14 @@ class DiscConvertTorchToMhlo
     INSERT_UNARY_PATTERN(AtenSinOp, mhlo::SineOp)
 #undef INSERT_UNARY_PATTERN
 
+#define INSERT_ARITH_PATTERN(AtenOp, ArithOp) \
+  target.addIllegalOp<AtenOp>();              \
+  patterns.add<ConvertAtenArithOp<AtenOp, ArithOp>>(typeConverter, context);
+    INSERT_ARITH_PATTERN(AtenAddIntOp, arith::AddIOp)
+    INSERT_ARITH_PATTERN(AtenSubIntOp, arith::SubIOp)
+    INSERT_ARITH_PATTERN(AtenFloordivIntOp, arith::DivSIOp)
+#undef INSERT_ARITH_PATTERN
+
 #define INSERT_EXTRACT_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>();       \
   patterns.add<ConvertAtenExtractOp<AtenOp>>(typeConverter, context);
diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp
index 6d31f370edb..61167264024 100644
--- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp
+++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp
@@ -84,6 +84,63 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
       ConversionPatternRewriter& rewriter) const override;
 };
 
+LogicalResult decomposeSplits(
+    ConversionPatternRewriter& rewriter,
+    OperatorOp op,
+    Value splitSize,
+    Value dim,
+    int64_t chunks,
+    bool keepDim = true) {
+  if (chunks < 0) {
+    return failure();
+  }
+  int64_t dimInt;
+  if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
+    return rewriter.notifyMatchFailure(op, "unknown dim");
+
+  auto self = op.getOperand(0);
+  auto selfTy = self.getType().dyn_cast<BaseTensorType>();
+  ArrayRef<int64_t> inputShape = selfTy.getSizes();
+
+  dimInt = toPositiveDim(dimInt, getTensorRank(self));
+
+  SmallVector<int64_t> sizes;
+  sizes.append(inputShape.begin(), inputShape.end());
+  sizes[dimInt] = kUnknownSize;
+
+  int64_t splitSizeInt = -1;
+  if (matchPattern(splitSize, m_TorchConstantInt(&splitSizeInt)) &&
+      splitSizeInt == 1) {
+    sizes[dimInt] = 1;
+  }
+  Type sliceTy =
+      selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), selfTy.getDtype());
+  sizes.erase(sizes.begin() + dimInt);
+  Type squeezeTy =
+      selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), selfTy.getDtype());
+
+  auto intType = Torch::IntType::get(op.getContext());
+  Location loc = op.getLoc();
+  Value one =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+  Value end =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+  SmallVector<Value> slices;
+  for (int64_t k = 0; k < chunks; ++k) {
+    Value start = end;
+    end = rewriter.create<AtenAddIntOp>(loc, intType, start, splitSize);
+    Value slice = rewriter.create<AtenSliceTensorOp>(
+        loc, sliceTy, self, dim, start, end, one);
+    if (splitSizeInt == 1 && !keepDim) {
+      slice = rewriter.create<AtenSqueezeDimOp>(loc, squeezeTy, slice, dim);
+    }
+    slices.emplace_back(slice);
+  }
+  rewriter.replaceOpWithNewOp<PrimListConstructOp>(
+      op, op.getResult(0).getType(), slices);
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<OperatorOp>::matchAndRewrite(
     OperatorOp op,
@@ -134,6 +191,58 @@ LogicalResult ConvertAtenOp<OperatorOp>::matchAndRewrite(
     rewriter.replaceOpWithNewOp<AtenDivTensorOp>(
         op, outTy, op.getOperand(0), op.getOperand(1));
     return success();
+  } else if ("aten.split.Tensor" == name) {
+    int64_t chunksInt = -1;
+    for (Operation* user : op.getResult(0).getUsers()) {
+      if (mlir::isa<PrimListUnpackOp>(user)) {
+        chunksInt = user->getNumResults();
+        break;
+      }
+    }
+    return decomposeSplits(
+        rewriter, op, op.getOperand(1), op.getOperand(2), chunksInt);
+  } else if ("aten.chunk" == name) {
+    int64_t chunksInt = -1;
+    auto chunks = op.getOperand(1);
+    if (!matchPattern(chunks, m_TorchConstantInt(&chunksInt))) {
+      for (Operation* user : op.getResult(0).getUsers()) {
+        if (mlir::isa<PrimListUnpackOp>(user)) {
+          chunksInt = user->getNumResults();
+          break;
+        }
+      }
+      if (chunksInt < 0) {
+        return rewriter.notifyMatchFailure(op, "unknown chunks");
+      }
+    }
+    auto self = op.getOperand(0);
+    auto dim = op.getOperand(2);
+
+    auto loc = op.getLoc();
+    Value one = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    auto intType = Torch::IntType::get(op.getContext());
+    Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
+    Value dimSizePlusChunk =
+        rewriter.create<AtenAddIntOp>(loc, intType, dimSize, chunks);
+    Value dimSizePlusChunkMinusOne =
+        rewriter.create<AtenSubIntOp>(loc, intType, dimSizePlusChunk, one);
+    Value splitSize = rewriter.create<AtenFloordivIntOp>(
+        loc, intType, dimSizePlusChunkMinusOne, chunks);
+    return decomposeSplits(rewriter, op, splitSize, dim, chunksInt);
+  } else if ("aten.unbind.int" == name) {
+    int64_t chunksInt = -1;
+    for (Operation* user : op.getResult(0).getUsers()) {
+      if (mlir::isa<PrimListUnpackOp>(user)) {
+        chunksInt = user->getNumResults();
+        break;
+      }
+    }
+    auto loc = op.getLoc();
+    Value one = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    return decomposeSplits(
+        rewriter, op, one, op.getOperand(1), chunksInt, /*keepDim*/ false);
   }
 
   return failure();
@@ -359,6 +468,9 @@ class DiscDecomposeComplexOpsPass
       "aten.div_inplace.Tensor",
       "aten.mul_inplace.Tensor",
      "aten.sub_inplace.Tensor",
+      "aten.split.Tensor",
+      "aten.chunk",
+      "aten.unbind.int",
   };
 
   if (illegalSet.find(op.name().str()) != illegalSet.end()) {
diff --git a/pytorch_blade/tests/disc/ops/test_slices.py b/pytorch_blade/tests/disc/ops/test_slices.py
index 80c7f2686a5..a7f5db4fe89 100644
--- a/pytorch_blade/tests/disc/ops/test_slices.py
+++ b/pytorch_blade/tests/disc/ops/test_slices.py
@@ -102,7 +102,6 @@ def func(x, y):
         annotations = [([4, -1, 256], dtype), ([4, -1, 256], dtype)]
         self._test_cvt_to_disc(func, test_data, annotations)
 
-    @skipIfEnableTorchMlir()
     def test_unbind(self):
         x = torch.randn([4, 64, 256], device=self.device)
         y = torch.randn([1, 4, 256], device=self.device)
@@ -124,7 +123,6 @@ def func(x, y):
         with tools.trust_tracing_shape():
             self._test_cvt_to_disc(func, test_data, annotations)
 
-    @skipIfEnableTorchMlir()
     def test_chunk(self):
 
         @torch.jit.script
@@ -132,12 +130,30 @@ def func(x):
             z1, z2, z3, z4, z5, z6 = torch.chunk(x, 6, -1)
             return z1, z2, z3, z4, z5, z6
 
-        print(func.graph)
         x = torch.randn([4, 64, 11], device=self.device)
         self._test_slice(func, x=x)
         x = torch.randn([4, 64, 12], device=self.device)
         self._test_slice(func, x=x)
 
+    def test_split(self):
+
+        @torch.jit.script
+        def func(x):
+            z1, z2, z3, z4, z5, z6 = torch.split(x, 2, -1)
+            return z1, z2, z3, z4, z5, z6
+
+        x = torch.randn([4, 64, 11], device=self.device)
+        annotations = [([-1, -1, 11], torch.float)]
+        self._test_disc(func, annotations, (x, ))
+        annotations = [([-1, -1, -1], torch.float)]
+        self._test_disc(func, annotations, (x, ))
+
+        x = torch.randn([4, 64, 12], device=self.device)
+        annotations = [([4, -1, 12], torch.float)]
+        self._test_disc(func, annotations, (x, ))
+        annotations = [([4, -1, -1], torch.float)]
+        self._test_disc(func, annotations, (x, ))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/pytorch_blade/tests/mhlo/mem_ops.mlir b/pytorch_blade/tests/mhlo/mem_ops.mlir
index de6deb2b722..74c89004670 100644
--- a/pytorch_blade/tests/mhlo/mem_ops.mlir
+++ b/pytorch_blade/tests/mhlo/mem_ops.mlir
@@ -25,90 +25,3 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !tor
   return %0 : !torch.vtensor<[2,4],f32>
 }
 
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.roll(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[CST:.*]] = arith.constant dense<1> : tensor<2xi32>
-// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64
-// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C3_I64:.*]] = arith.constant 3 : i64
-// CHECK: %[[C:.*]]-9_i64 = arith.constant -9 : i64
-// CHECK: %[[T0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[T0]] : index to i64
-// CHECK: %[[T2:.*]] = arith.subi %[[T1]], %[[C3_I64]] : i64
-// CHECK: %[[T3:.*]] = arith.remsi %[[T2]], %[[T1]] : i64
-// CHECK: %[[T4:.*]] = arith.subi %[[C0_I64]], %[[T1]] : i64
-// CHECK: %[[T5:.*]] = arith.maxsi %[[T4]], %[[T3]] : i64
-// CHECK: %[[T6:.*]] = arith.minsi %[[T1]], %[[T5]] : i64
-// CHECK: %[[T7:.*]] = arith.addi %[[T1]], %[[T6]] : i64
-// CHECK: %[[T8:.*]] = arith.cmpi sge, %[[T6]], %[[C0_I64]] : i64
-// CHECK: %[[T9:.*]] = arith.select %[[T8]], %[[T6]], %[[T7]] : i64
-// CHECK: %[[T10:.*]] = arith.trunci %[[T9]] : i64 to i32
-// CHECK: %[[T11:.*]] = arith.trunci %[[T1]] : i64 to i32
-// CHECK: %[[T12:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[T13:.*]] = arith.index_cast %[[T12]] : index to i32
-// CHECK: %[[T14:.*]] = arith.index_cast %[[T0]] : index to i32
-// CHECK: %[[T15:.*]] = arith.cmpi eq, %[[T11]], %[[C0_I32]] : i32
-// CHECK: %[[T16:.*]] = arith.select %[[T15]], %[[T14]], %[[T11]] : i32
-// CHECK: %[[T17:.*]] = tensor.from_elements %[[C0_I32]], %[[T10]] : tensor<2xi32>
-// CHECK: %[[T18:.*]] = tensor.from_elements %[[T13]], %[[T16]] : tensor<2xi32>
-// CHECK: %[[T19:.*]] = mhlo.real_dynamic_slice %[[ARG0]], %[[T17]], %[[T18]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
-// CHECK: %[[T20:.*]] = arith.maxsi %[[T4]], %[[C0_I64]] : i64
-// CHECK: %[[T21:.*]] = arith.minsi %[[T1]], %[[T20]] : i64
-// CHECK: %[[T22:.*]] = arith.addi %[[T1]], %[[T21]] : i64
-// CHECK: %[[T23:.*]] = arith.cmpi sge, %[[T21]], %[[C0_I64]] : i64
-// CHECK: %[[T24:.*]] = arith.select %[[T23]], %[[T21]], %[[T22]] : i64
-// CHECK: %[[T25:.*]] = arith.trunci %[[T24]] : i64 to i32
-// CHECK: %[[T26:.*]] = arith.cmpi eq, %[[T10]], %[[C0_I32]] : i32
-// CHECK: %[[T27:.*]] = arith.select %[[T26]], %[[T14]], %[[T10]] : i32
-// CHECK: %[[T28:.*]] = tensor.from_elements %[[C0_I32]], %[[T25]] : tensor<2xi32>
-// CHECK: %[[T29:.*]] = tensor.from_elements %[[T13]], %[[T27]] : tensor<2xi32>
-// CHECK: %[[T30:.*]] = mhlo.real_dynamic_slice %[[ARG0]], %[[T28]], %[[T29]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
-// CHECK: %[[T31:.*]] = "mhlo.concatenate"(%[[T19]], %[[T30]]) {dimension = 1 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T32:.*]] = tensor.dim %[[T31]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[T33:.*]] = arith.index_cast %[[T32]] : index to i64
-// CHECK: %[[T34:.*]] = arith.subi %[[T33]], %[[C]]-9_i64 : i64
-// CHECK: %[[T35:.*]] = arith.remsi %[[T34]], %[[T33]] : i64
-// CHECK: %[[T36:.*]] = arith.subi %[[C0_I64]], %[[T33]] : i64
-// CHECK: %[[T37:.*]] = arith.maxsi %[[T36]], %[[T35]] : i64
-// CHECK: %[[T38:.*]] = arith.minsi %[[T33]], %[[T37]] : i64
-// CHECK: %[[T39:.*]] = arith.addi %[[T33]], %[[T38]] : i64
-// CHECK: %[[T40:.*]] = arith.cmpi sge, %[[T38]], %[[C0_I64]] : i64
-// CHECK: %[[T41:.*]] = arith.select %[[T40]], %[[T38]], %[[T39]] : i64
-// CHECK: %[[T42:.*]] = arith.trunci %[[T41]] : i64 to i32
-// CHECK: %[[T43:.*]] = arith.trunci %[[T33]] : i64 to i32
-// CHECK: %[[T44:.*]] = arith.index_cast %[[T32]] : index to i32
-// CHECK: %[[T45:.*]] = tensor.dim %[[T31]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[T46:.*]] = arith.index_cast %[[T45]] : index to i32
-// CHECK: %[[T47:.*]] = arith.cmpi eq, %[[T43]], %[[C0_I32]] : i32
-// CHECK: %[[T48:.*]] = arith.select %[[T47]], %[[T44]], %[[T43]] : i32
-// CHECK: %[[T49:.*]] = tensor.from_elements %[[T42]], %[[C0_I32]] : tensor<2xi32>
-// CHECK: %[[T50:.*]] = tensor.from_elements %[[T48]], %[[T46]] : tensor<2xi32>
-// CHECK: %[[T51:.*]] = mhlo.real_dynamic_slice %[[T31]], %[[T49]], %[[T50]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
-// CHECK: %[[T52:.*]] = arith.maxsi %[[T36]], %[[C0_I64]] : i64
-// CHECK: %[[T53:.*]] = arith.minsi %[[T33]], %[[T52]] : i64
-// CHECK: %[[T54:.*]] = arith.addi %[[T33]], %[[T53]] : i64
-// CHECK: %[[T55:.*]] = arith.cmpi sge, %[[T53]], %[[C0_I64]] : i64
-// CHECK: %[[T56:.*]] = arith.select %[[T55]], %[[T53]], %[[T54]] : i64
-// CHECK: %[[T57:.*]] = arith.trunci %[[T56]] : i64 to i32
-// CHECK: %[[T58:.*]] = arith.cmpi eq, %[[T42]], %[[C0_I32]] : i32
-// CHECK: %[[T59:.*]] = arith.select %[[T58]], %[[T44]], %[[T42]] : i32
-// CHECK: %[[T60:.*]] = tensor.from_elements %[[T57]], %[[C0_I32]] : tensor<2xi32>
-// CHECK: %[[T61:.*]] = tensor.from_elements %[[T59]], %[[T46]] : tensor<2xi32>
-// CHECK: %[[T62:.*]] = mhlo.real_dynamic_slice %[[T31]], %[[T60]], %[[T61]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
-// CHECK: %[[T63:.*]] = "mhlo.concatenate"(%[[T51]], %[[T62]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: return %[[T63]] : tensor<?x?xf32>
-func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int-9 = torch.constant.int -9
-  %int3 = torch.constant.int 3
-  %0 = torch.prim.ListConstruct %int3, %int-9 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
-  return %2 : !torch.vtensor<[?,?],f32>
-}
-
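
A minimal sketch of the graph pattern this patch targets (a hypothetical reproducer, not part of the patch; `func` and the shapes are ours). torch.split, torch.chunk and torch.unbind each produce a Tensor[] that TorchScript consumes via prim::ListUnpack: the new shape-analysis case propagates symbolic shapes through the unpack, and DiscDecomposeComplexOps unrolls the list into aten.slice.Tensor ops, computing for aten.chunk split_size = (dim_size + chunks - 1) // chunks (e.g. dim_size = 11 and chunks = 6 give split_size = 2: five chunks of size 2 and a final chunk of size 1).

import torch

@torch.jit.script
def func(x):
    # Scripted multiple assignment lowers to aten::chunk -> prim::ListUnpack;
    # the unpack arity (6 here) is how the decomposition recovers `chunks`
    # when it is not a compile-time constant.
    z1, z2, z3, z4, z5, z6 = torch.chunk(x, 6, -1)
    return z1, z2, z3, z4, z5, z6

print(func.graph)  # aten::chunk feeding prim::ListUnpack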