Skip to content

Commit

Permalink
add aten.split.Tensor (#888)
Browse files Browse the repository at this point in the history
* add aten.split.Tensor

* add unbind & chunk & ut
  • Loading branch information
Tanyo Kwok committed Dec 26, 2022
1 parent 59359ff commit 71bb36e
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 90 deletions.
35 changes: 35 additions & 0 deletions pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,41 @@ class ShapePropagator : public PropertyPropBase {
}
return;
}
case prim::ListUnpack: {
  // Propagate symbolic shapes to the outputs of a ListUnpack whose list is
  // produced by aten::split.Tensor, aten::chunk or aten::unbind.int.
  // Anything else (or missing shape / non-constant dim) leaves the output
  // types untouched.
  auto input_node = node->input()->node();
  const bool is_unbind = input_node->matches(
      "aten::unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[]");
  const bool is_split_like =
      input_node->matches(
          "aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[]") ||
      input_node->matches(
          "aten::chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[]");
  if (!(is_unbind || is_split_like))
    return;

  auto self_type = input_node->input(0)->type()->cast<TensorType>();
  if (!self_type)
    return;

  auto sizes_opt = self_type->symbolic_sizes().sizes();
  auto dim_opt = input_node->get<int64_t>(attr::dim);
  if (!(sizes_opt && dim_opt))
    return;

  std::vector<c10::ShapeSymbol> new_sizes = sizes_opt.value();
  int64_t input_rank = new_sizes.size();
  int64_t dim = at::maybe_wrap_dim(dim_opt.value(), input_rank, false);
  if (is_unbind) {
    // unbind removes the selected dimension from every output.
    new_sizes.erase(new_sizes.begin() + dim);
  }

  for (size_t i = 0; i < node->outputs().size(); ++i) {
    auto type = node->output(i)->type()->cast<TensorType>();
    if (!type)
      continue;
    if (!is_unbind) {
      // The split dimension is dynamic. Use a fresh symbol per output:
      // sharing one symbol would claim that all pieces have the same extent,
      // which does not hold when the last piece is shorter.
      new_sizes[dim] = ShapeSymbol::newSymbol();
    }
    node->output(i)->setType(type->withSymbolicShapes(new_sizes));
  }
  return;
}
case prim::Constant: {
if (node->output()->type()->isSubtypeOf(TensorType::get())) {
node->output()->inferTypeFrom(node->t(attr::value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::bitwise_not",
"aten::bmm",
"aten::cat",
"aten::chunk",
"aten::contiguous",
"aten::_convolution",
"aten::convolution",
Expand Down Expand Up @@ -113,6 +114,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::size",
"aten::slice",
"aten::softmax",
"aten::split",
"aten::std",
"aten::squeeze",
"aten::sub",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,27 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
};
} // namespace

namespace {
// Conversion pattern that replaces a two-operand Torch op `AtenOpT` with the
// op `ArithOpT`. The replacement is built from the converted result type of
// the original op and the adaptor's `a()` and `b()` operands.
template <typename AtenOpT, typename ArithOpT>
class ConvertAtenArithOp : public OpConversionPattern<AtenOpT> {
 public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  LogicalResult matchAndRewrite(
      AtenOpT op,
      OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    // Result type as seen by the target dialect.
    auto resultTy = this->getTypeConverter()->convertType(op.getType());
    auto lhs = adaptor.a();
    auto rhs = adaptor.b();
    rewriter.replaceOpWithNewOp<ArithOpT>(op, resultTy, lhs, rhs);
    return success();
  }
};
} // namespace

namespace {
template <typename AtenOpT>
class ConvertAtenExtractOp : public OpConversionPattern<AtenOpT> {
Expand Down Expand Up @@ -1261,6 +1282,14 @@ class DiscConvertTorchToMhlo
INSERT_UNARY_PATTERN(AtenSinOp, mhlo::SineOp)
#undef INSERT_UNARY_PATTERN

// Register integer arithmetic lowerings: each listed Torch op is marked
// illegal and rewritten to the paired arith op through ConvertAtenArithOp.
#define INSERT_ARITH_PATTERN(AtenOp, ArithOp) \
  target.addIllegalOp<AtenOp>();              \
  patterns.add<ConvertAtenArithOp<AtenOp, ArithOp>>(typeConverter, context);
    INSERT_ARITH_PATTERN(AtenAddIntOp, arith::AddIOp)
    INSERT_ARITH_PATTERN(AtenSubIntOp, arith::SubIOp)
    // NOTE(review): arith::DivSIOp truncates toward zero, whereas floor
    // division rounds toward negative infinity; the two differ when the
    // operands have opposite signs. arith::FloorDivSIOp may be the closer
    // match -- confirm the rest of the pipeline can lower it before switching.
    INSERT_ARITH_PATTERN(AtenFloordivIntOp, arith::DivSIOp)
#undef INSERT_ARITH_PATTERN

#define INSERT_EXTRACT_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenExtractOp<AtenOp>>(typeConverter, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,63 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
ConversionPatternRewriter& rewriter) const override;
};

// Rewrites a list-producing split-like `op` into `chunks` consecutive
// AtenSliceTensorOp results gathered by a PrimListConstructOp.
//
//   rewriter  - rewriter used to create the new ops and replace `op`.
//   op        - operator being replaced; operand 0 is the tensor to slice.
//   splitSize - Torch int value giving the step between slice boundaries.
//   dim       - dimension to slice along; must be a constant Torch int
//               (negative values are wrapped against the operand rank).
//   chunks    - number of slices to emit; a negative value is a match failure.
//   keepDim   - when false and `splitSize` is the constant 1, every slice is
//               additionally squeezed along `dim`.
//
// Returns success() after replacing `op`, or a match failure when the number
// of chunks is unknown, `dim` is not constant or out of range, or operand 0 is
// not a tensor type with known sizes.
LogicalResult decomposeSplits(
    ConversionPatternRewriter& rewriter,
    OperatorOp op,
    Value splitSize,
    Value dim,
    int64_t chunks,
    bool keepDim = true) {
  if (chunks < 0)
    return rewriter.notifyMatchFailure(op, "unknown number of chunks");

  int64_t dimInt;
  if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
    return rewriter.notifyMatchFailure(op, "unknown dim");

  auto self = op.getOperand(0);
  auto selfTy = self.getType().dyn_cast<BaseTensorType>();
  // dyn_cast yields a null type for non-tensor operands; sizes are required
  // to build the result types below.
  if (!selfTy || !selfTy.hasSizes())
    return rewriter.notifyMatchFailure(
        op, "expected a tensor operand with known rank");
  ArrayRef<int64_t> inputShape = selfTy.getSizes();

  const int64_t rank = static_cast<int64_t>(inputShape.size());
  dimInt = toPositiveDim(dimInt, rank);
  // Guard the indexing into `sizes` below.
  if (dimInt < 0 || dimInt >= rank)
    return rewriter.notifyMatchFailure(op, "dim is out of range");

  // Slice result shape: same as the input, with the split dimension unknown
  // unless the split size is the constant 1.
  SmallVector<int64_t> sizes;
  sizes.append(inputShape.begin(), inputShape.end());
  sizes[dimInt] = kUnknownSize;

  int64_t splitSizeInt = -1;
  if (matchPattern(splitSize, m_TorchConstantInt(&splitSizeInt)) &&
      splitSizeInt == 1) {
    sizes[dimInt] = 1;
  }
  Type sliceTy =
      selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), selfTy.getDtype());
  // Squeezed result shape: the slice shape with the split dimension removed.
  sizes.erase(sizes.begin() + dimInt);
  Type squeezeTy =
      selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), selfTy.getDtype());
  const bool squeezeSlices = splitSizeInt == 1 && !keepDim;

  auto intType = Torch::IntType::get(op.getContext());
  Location loc = op.getLoc();
  Value one =
      rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
  Value end =
      rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
  SmallVector<Value, 4> slices;
  for (int64_t k = 0; k < chunks; ++k) {
    // Slice k covers [start, start + splitSize) along `dim`.
    Value start = end;
    end = rewriter.create<AtenAddIntOp>(loc, intType, start, splitSize);
    Value slice = rewriter.create<AtenSliceTensorOp>(
        loc, sliceTy, self, dim, start, end, one);
    if (squeezeSlices) {
      slice = rewriter.create<AtenSqueezeDimOp>(loc, squeezeTy, slice, dim);
    }
    slices.emplace_back(slice);
  }
  rewriter.replaceOpWithNewOp<PrimListConstructOp>(
      op, op.getResult(0).getType(), slices);
  return success();
}

template <>
LogicalResult ConvertAtenOp<OperatorOp>::matchAndRewrite(
OperatorOp op,
Expand Down Expand Up @@ -134,6 +191,58 @@ LogicalResult ConvertAtenOp<OperatorOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<AtenDivTensorOp>(
op, outTy, op.getOperand(0), op.getOperand(1));
return success();
} else if ("aten.split.Tensor" == name) {
int64_t chunksInt = -1;
for (Operation* user : op.getResult(0).getUsers()) {
if (mlir::isa<PrimListUnpackOp>(user)) {
chunksInt = user->getNumResults();
break;
}
}
return decomposeSplits(
rewriter, op, op.getOperand(1), op.getOperand(2), chunksInt);
} else if ("aten.chunk" == name) {
int64_t chunksInt = -1;
auto chunks = op.getOperand(1);
if (!matchPattern(chunks, m_TorchConstantInt(&chunksInt))) {
for (Operation* user : op.getResult(0).getUsers()) {
if (mlir::isa<PrimListUnpackOp>(user)) {
chunksInt = user->getNumResults();
break;
}
}
if (chunksInt < 0) {
return rewriter.notifyMatchFailure(op, "unknown chunks");
}
}
auto self = op.getOperand(0);
auto dim = op.getOperand(2);

auto loc = op.getLoc();
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
auto intType = Torch::IntType::get(op.getContext());
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
Value dimSizePlusChunk =
rewriter.create<AtenAddIntOp>(loc, intType, dimSize, chunks);
Value dimSizePlusChunkMinusOne =
rewriter.create<AtenSubIntOp>(loc, intType, dimSizePlusChunk, one);
Value splitSize = rewriter.create<AtenFloordivIntOp>(
loc, intType, dimSizePlusChunkMinusOne, chunks);
return decomposeSplits(rewriter, op, splitSize, dim, chunksInt);
} else if ("aten.unbind.int" == name) {
int64_t chunksInt = -1;
for (Operation* user : op.getResult(0).getUsers()) {
if (mlir::isa<PrimListUnpackOp>(user)) {
chunksInt = user->getNumResults();
break;
}
}
auto loc = op.getLoc();
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
return decomposeSplits(
rewriter, op, one, op.getOperand(1), chunksInt, /*keepDim*/ false);
}

return failure();
Expand Down Expand Up @@ -359,6 +468,9 @@ class DiscDecomposeComplexOpsPass
"aten.div_inplace.Tensor",
"aten.mul_inplace.Tensor",
"aten.sub_inplace.Tensor",
"aten.split.Tensor",
"aten.chunk",
"aten.unbind.int",
};

if (illegalSet.find(op.name().str()) != illegalSet.end()) {
Expand Down
22 changes: 19 additions & 3 deletions pytorch_blade/tests/disc/ops/test_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def func(x, y):
annotations = [([4, -1, 256], dtype), ([4, -1, 256], dtype)]
self._test_cvt_to_disc(func, test_data, annotations)

@skipIfEnableTorchMlir()
def test_unbind(self):
x = torch.randn([4, 64, 256], device=self.device)
y = torch.randn([1, 4, 256], device=self.device)
Expand All @@ -124,20 +123,37 @@ def func(x, y):
with tools.trust_tracing_shape():
self._test_cvt_to_disc(func, test_data, annotations)

@skipIfEnableTorchMlir()
def test_chunk(self):

@torch.jit.script
def func(x):
z1, z2, z3, z4, z5, z6 = torch.chunk(x, 6, -1)
return z1, z2, z3, z4, z5, z6

print(func.graph)
x = torch.randn([4, 64, 11], device=self.device)
self._test_slice(func, x=x)

x = torch.randn([4, 64, 12], device=self.device)
self._test_slice(func, x=x)

def test_split(self):
    """Run torch.split (split size 2, last dim) through DISC for inputs
    whose last dim is 11 and 12, each under two shape annotations."""

    @torch.jit.script
    def func(x):
        p0, p1, p2, p3, p4, p5 = torch.split(x, 2, -1)
        return p0, p1, p2, p3, p4, p5

    # (input shape, annotated shapes to test with that input)
    cases = [
        ([4, 64, 11], [[-1, -1, 11], [-1, -1, -1]]),
        ([4, 64, 12], [[4, -1, 12], [4, -1, -1]]),
    ]
    for shape, annotated_shapes in cases:
        x = torch.randn(shape, device=self.device)
        for annotated in annotated_shapes:
            self._test_disc(func, [(annotated, torch.float)], (x, ))

if __name__ == "__main__":
unittest.main()
87 changes: 0 additions & 87 deletions pytorch_blade/tests/mhlo/mem_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,90 +25,3 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !tor
return %0 : !torch.vtensor<[2,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.roll(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<1> : tensor<2xi32>
// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C3_I64:.*]] = arith.constant 3 : i64
// CHECK: %[[C:.*]]-9_i64 = arith.constant -9 : i64
// CHECK: %[[T0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[T0]] : index to i64
// CHECK: %[[T2:.*]] = arith.subi %[[T1]], %[[C3_I64]] : i64
// CHECK: %[[T3:.*]] = arith.remsi %[[T2]], %[[T1]] : i64
// CHECK: %[[T4:.*]] = arith.subi %[[C0_I64]], %[[T1]] : i64
// CHECK: %[[T5:.*]] = arith.maxsi %[[T4]], %[[T3]] : i64
// CHECK: %[[T6:.*]] = arith.minsi %[[T1]], %[[T5]] : i64
// CHECK: %[[T7:.*]] = arith.addi %[[T1]], %[[T6]] : i64
// CHECK: %[[T8:.*]] = arith.cmpi sge, %[[T6]], %[[C0_I64]] : i64
// CHECK: %[[T9:.*]] = arith.select %[[T8]], %[[T6]], %[[T7]] : i64
// CHECK: %[[T10:.*]] = arith.trunci %[[T9]] : i64 to i32
// CHECK: %[[T11:.*]] = arith.trunci %[[T1]] : i64 to i32
// CHECK: %[[T12:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[T13:.*]] = arith.index_cast %[[T12]] : index to i32
// CHECK: %[[T14:.*]] = arith.index_cast %[[T0]] : index to i32
// CHECK: %[[T15:.*]] = arith.cmpi eq, %[[T11]], %[[C0_I32]] : i32
// CHECK: %[[T16:.*]] = arith.select %[[T15]], %[[T14]], %[[T11]] : i32
// CHECK: %[[T17:.*]] = tensor.from_elements %[[C0_I32]], %[[T10]] : tensor<2xi32>
// CHECK: %[[T18:.*]] = tensor.from_elements %[[T13]], %[[T16]] : tensor<2xi32>
// CHECK: %[[T19:.*]] = mhlo.real_dynamic_slice %[[ARG0]], %[[T17]], %[[T18]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
// CHECK: %[[T20:.*]] = arith.maxsi %[[T4]], %[[C0_I64]] : i64
// CHECK: %[[T21:.*]] = arith.minsi %[[T1]], %[[T20]] : i64
// CHECK: %[[T22:.*]] = arith.addi %[[T1]], %[[T21]] : i64
// CHECK: %[[T23:.*]] = arith.cmpi sge, %[[T21]], %[[C0_I64]] : i64
// CHECK: %[[T24:.*]] = arith.select %[[T23]], %[[T21]], %[[T22]] : i64
// CHECK: %[[T25:.*]] = arith.trunci %[[T24]] : i64 to i32
// CHECK: %[[T26:.*]] = arith.cmpi eq, %[[T10]], %[[C0_I32]] : i32
// CHECK: %[[T27:.*]] = arith.select %[[T26]], %[[T14]], %[[T10]] : i32
// CHECK: %[[T28:.*]] = tensor.from_elements %[[C0_I32]], %[[T25]] : tensor<2xi32>
// CHECK: %[[T29:.*]] = tensor.from_elements %[[T13]], %[[T27]] : tensor<2xi32>
// CHECK: %[[T30:.*]] = mhlo.real_dynamic_slice %[[ARG0]], %[[T28]], %[[T29]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
// CHECK: %[[T31:.*]] = "mhlo.concatenate"(%[[T19]], %[[T30]]) {dimension = 1 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T32:.*]] = tensor.dim %[[T31]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[T33:.*]] = arith.index_cast %[[T32]] : index to i64
// CHECK: %[[T34:.*]] = arith.subi %[[T33]], %[[C]]-9_i64 : i64
// CHECK: %[[T35:.*]] = arith.remsi %[[T34]], %[[T33]] : i64
// CHECK: %[[T36:.*]] = arith.subi %[[C0_I64]], %[[T33]] : i64
// CHECK: %[[T37:.*]] = arith.maxsi %[[T36]], %[[T35]] : i64
// CHECK: %[[T38:.*]] = arith.minsi %[[T33]], %[[T37]] : i64
// CHECK: %[[T39:.*]] = arith.addi %[[T33]], %[[T38]] : i64
// CHECK: %[[T40:.*]] = arith.cmpi sge, %[[T38]], %[[C0_I64]] : i64
// CHECK: %[[T41:.*]] = arith.select %[[T40]], %[[T38]], %[[T39]] : i64
// CHECK: %[[T42:.*]] = arith.trunci %[[T41]] : i64 to i32
// CHECK: %[[T43:.*]] = arith.trunci %[[T33]] : i64 to i32
// CHECK: %[[T44:.*]] = arith.index_cast %[[T32]] : index to i32
// CHECK: %[[T45:.*]] = tensor.dim %[[T31]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T46:.*]] = arith.index_cast %[[T45]] : index to i32
// CHECK: %[[T47:.*]] = arith.cmpi eq, %[[T43]], %[[C0_I32]] : i32
// CHECK: %[[T48:.*]] = arith.select %[[T47]], %[[T44]], %[[T43]] : i32
// CHECK: %[[T49:.*]] = tensor.from_elements %[[T42]], %[[C0_I32]] : tensor<2xi32>
// CHECK: %[[T50:.*]] = tensor.from_elements %[[T48]], %[[T46]] : tensor<2xi32>
// CHECK: %[[T51:.*]] = mhlo.real_dynamic_slice %[[T31]], %[[T49]], %[[T50]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
// CHECK: %[[T52:.*]] = arith.maxsi %[[T36]], %[[C0_I64]] : i64
// CHECK: %[[T53:.*]] = arith.minsi %[[T33]], %[[T52]] : i64
// CHECK: %[[T54:.*]] = arith.addi %[[T33]], %[[T53]] : i64
// CHECK: %[[T55:.*]] = arith.cmpi sge, %[[T53]], %[[C0_I64]] : i64
// CHECK: %[[T56:.*]] = arith.select %[[T55]], %[[T53]], %[[T54]] : i64
// CHECK: %[[T57:.*]] = arith.trunci %[[T56]] : i64 to i32
// CHECK: %[[T58:.*]] = arith.cmpi eq, %[[T42]], %[[C0_I32]] : i32
// CHECK: %[[T59:.*]] = arith.select %[[T58]], %[[T44]], %[[T42]] : i32
// CHECK: %[[T60:.*]] = tensor.from_elements %[[T57]], %[[C0_I32]] : tensor<2xi32>
// CHECK: %[[T61:.*]] = tensor.from_elements %[[T59]], %[[T46]] : tensor<2xi32>
// CHECK: %[[T62:.*]] = mhlo.real_dynamic_slice %[[T31]], %[[T60]], %[[T61]], %[[CST]] : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
// CHECK: %[[T63:.*]] = "mhlo.concatenate"(%[[T51]], %[[T62]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: return %[[T63]] : tensor<?x?xf32>
func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int-9 = torch.constant.int -9
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int3, %int-9 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
return %2 : !torch.vtensor<[?,?],f32>
}

0 comments on commit 71bb36e

Please sign in to comment.