[Torch] Decompose AtenMaskedScatterOp (llvm#3353)
Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
2 people authored and Branko Trifkovic committed May 24, 2024
1 parent e507f30 commit 1d88e6f
Showing 6 changed files with 143 additions and 0 deletions.
3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -72,6 +72,9 @@ bool isBuiltInType(Type type);
// std::nullopt is returned if the tensorRank can't be determined.
std::optional<unsigned> getTensorRank(Value tensor);

// Helper function to get the total number of elements in a tensor.
// std::nullopt is returned if the sizes are unknown; ShapedType::kDynamic is
// returned if any dimension is dynamic.
std::optional<int64_t> getTensorNumel(Value tensor);

bool isViewLikeOp(Operation *op);

Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
99 changes: 99 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3506,6 +3506,104 @@ class DecomposeAtenMaskedFillScalarOp
};
} // namespace

// Decompose aten.masked_scatter:
// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor:
// mask_int = mask + torch.zeros_like(self, dtype=torch.int64)
// prefix_sum = torch.cumsum(mask_int.flatten(), dim=0)
// mask_prefix = torch.clamp(prefix_sum - 1, min=0)
// mask = mask.to(torch.bool)
// source = source.flatten()[mask_prefix].reshape(self.shape)
// return torch.where(mask, source, self)
namespace {
class DecomposeAtenMaskedScatterOp
: public OpRewritePattern<AtenMaskedScatterOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMaskedScatterOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();
Value mask = op.getMask();
Value source = op.getSource();
Value self = op.getSelf();

auto selfTy = cast<BaseTensorType>(self.getType());
auto resTy = cast<BaseTensorType>(op.getType());
auto sourceTy = cast<BaseTensorType>(source.getType());

if (!resTy || !resTy.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
if (!selfTy || !selfTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(
op, "unimplemented: self must have static sizes");
if (!sourceTy || !sourceTy.areAllSizesKnown() || !sourceTy.hasDtype())
return rewriter.notifyMatchFailure(
op, "unimplemented: source must have static sizes and a dtype");

int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes
int64_t sourceNumel =
getTensorNumel(source).value(); // as sourceTy has sizes
int64_t selfRank = selfTy.getSizes().size();
int64_t sourceRank = sourceTy.getSizes().size();

Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constNone = rewriter.create<ConstantNoneOp>(loc);
Value selfLastDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(selfRank - 1));
Value sourceLastDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(sourceRank - 1));

auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto selfIntType = selfTy.getWithSizesAndDtype(selfTy.getSizes(), si64Type);

Value zerosLike = rewriter.create<Torch::AtenZerosLikeOp>(
loc, selfIntType, self, int64Dtype, constNone, constNone, constNone,
constNone);
Value maskInt = rewriter.create<Torch::AtenAddTensorOp>(
loc, selfIntType, mask, zerosLike, constOne);

auto flattenMaskedType = selfTy.getWithSizesAndDtype(
/*optionalSizes=*/{selfNumel}, si64Type);
Value maskIntFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenMaskedType, maskInt, constZero, selfLastDim);
Value prefixSum = rewriter.create<Torch::AtenCumsumOp>(
loc, flattenMaskedType, maskIntFlatten,
/*dim=*/constZero, constNone);
Value prefixSumMinusOne = rewriter.create<Torch::AtenSubScalarOp>(
loc, flattenMaskedType, prefixSum, constOne, constOne);
Value maskPrefix = rewriter.create<Torch::AtenClampOp>(
loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero,
/*max=*/constNone);

auto sourceFlattenType = sourceTy.getWithSizesAndDtype(
/*optionalSizes=*/{sourceNumel}, sourceTy.getDtype());
Value sourceFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, sourceFlattenType, source, constZero, sourceLastDim);

auto selectSourceType = sourceTy.getWithSizesAndDtype(
/*optionalSizes=*/{selfNumel}, sourceTy.getDtype());
Value selectSource = rewriter.create<Torch::AtenIndexSelectOp>(
loc, selectSourceType, sourceFlatten, constZero, maskPrefix);

// Reshape the gathered source elements back to the shape of the input tensor.
auto selfShape = rewriter.create<AtenSizeOp>(
loc, Torch::ListType::get(IntType::get(context)), self);
Value sourceReshape = rewriter.create<Torch::AtenViewOp>(
loc, selfTy, selectSource, selfShape);
rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(op, resTy, mask,
sourceReshape, self);
return success();
}
};
} // namespace

// Decompose aten._convolution-like to aten.convolution
namespace {
template <typename ConvolutionLikeOp>
@@ -7974,6 +8072,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
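For reference, a minimal PyTorch sketch of the decomposition implemented by DecomposeAtenMaskedScatterOp above. It follows the Python pseudocode in the pattern's comment (the function and variable names here are illustrative and not part of the patch) and checks the result against torch.Tensor.masked_scatter:

import torch

def masked_scatter_decomposed(self, mask, source):
    # mask contributes 1 at True positions; int64 zeros fix the dtype for cumsum.
    mask_int = mask + torch.zeros_like(self, dtype=torch.int64)
    # prefix_sum[i] counts the True entries up to and including flat position i.
    prefix_sum = torch.cumsum(mask_int.flatten(), dim=0)
    # Index into the flattened source for each position (clamped to 0 at leading False positions).
    mask_prefix = torch.clamp(prefix_sum - 1, min=0)
    gathered = source.flatten()[mask_prefix].reshape(self.shape)
    # Keep the original value where mask is False, the gathered source value where True.
    return torch.where(mask.to(torch.bool), gathered, self)

x = torch.rand(4, 4)
mask = torch.rand(4, 4) > 0.5
y = torch.rand(8, 8)
assert torch.equal(masked_scatter_decomposed(x, mask, y), x.masked_scatter(mask, y))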
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -390,6 +390,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenMaskedFillScalarOp>();
target.addIllegalOp<AtenMaskedScatterOp>();
target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>();
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
13 changes: 13 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -209,6 +209,19 @@ std::optional<unsigned> Torch::getTensorRank(Value tensor) {
return tensorType.getSizes().size();
}

std::optional<int64_t> Torch::getTensorNumel(Value tensor) {
BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
if (!tensorType.hasSizes())
return std::nullopt;
int64_t numel = 1;
for (auto dim : tensorType.getSizes()) {
if (dim == ShapedType::kDynamic)
return ShapedType::kDynamic;
numel *= dim;
}
return numel;
}

bool Torch::isViewLikeOp(Operation *op) {
// AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1072,6 +1072,7 @@
"LinspaceTwoSizeModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedScatterStaticBasic_basic",
"Matmul4dStatic_basic",
"Matmul_2d",
"Matmul_dot",
@@ -2366,6 +2367,7 @@
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic",
25 changes: 25 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -12,6 +12,31 @@
# ==============================================================================


class MaskedScatterStaticBasic(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([4, 4], torch.float32, True),
([4, 4], torch.bool, True),
([8, 8], torch.float32, True),
]
)
def forward(self, x, mask, y):
return torch.masked_scatter(x, mask, y)


@register_test_case(module_factory=lambda: MaskedScatterStaticBasic())
def MaskedScatterStaticBasic_basic(module, tu: TestUtils):
x = torch.rand(4, 4)
mask = torch.rand(4, 4) > 0.5
y = torch.rand(8, 8)
module.forward(x, mask, y)

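A small illustrative snippet of the semantics under test (not part of the patch): masked_scatter fills the True positions of mask, in row-major order, with consecutive elements taken from the flattened source, so source only needs at least as many elements as there are True entries — which is why the 8x8 source above is valid for a 4x4 input:

import torch

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
src = torch.arange(1.0, 7.0)  # six elements; only the first three are consumed
print(x.masked_scatter(mask, src))
# tensor([[1., 0., 2.],
#         [0., 3., 0.]])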

class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()