Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
aicompiler rebase 20220830
Browse files Browse the repository at this point in the history
fix float width

fix divide_floor & export promoteTypes api (#9)
  • Loading branch information
Tanyo Kwok committed Aug 31, 2022
1 parent e52e886 commit f93457b
Show file tree
Hide file tree
Showing 13 changed files with 102 additions and 45 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };

ScalarType promoteTypes(ScalarType a, ScalarType b);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand All @@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand Down
6 changes: 5 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,17 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);

target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
Expand Down
24 changes: 14 additions & 10 deletions lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;

bool skipMultiplyAlpha(Value alphaValue) {
double doubleValue;
Expand Down Expand Up @@ -413,10 +414,8 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
Value dValue = shape[i];
Value newD;
int64_t dInt;
if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
return op->emitError("element of desired shape must be a scalar");
}
if (i >= leadingRank && dInt == -1) {
if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) &&
dInt == -1) {
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
i - leadingRank);
} else {
Expand Down Expand Up @@ -637,7 +636,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, outType, lhs, zeroTensor);
return success();
}

Expand Down Expand Up @@ -665,7 +668,11 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, input, halfMul);
return success();
}
} // namespace
Expand Down Expand Up @@ -1039,7 +1046,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
#undef INSERT_UNARY_FPONLY_PATTERN

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
Expand Down Expand Up @@ -1091,14 +1098,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);

INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenErfOp);

INSERT_ATENOP_PATTERN(AtenCatOp);

INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenNumelOp);
Expand Down
11 changes: 9 additions & 2 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();

llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

Expand Down Expand Up @@ -426,6 +427,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

Expand All @@ -446,7 +448,9 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
}

rewriter.replaceOp(op, mhloReduceOp.getResults());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
return success();
}
} // namespace
Expand Down Expand Up @@ -510,6 +514,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();

llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

Expand Down Expand Up @@ -551,7 +556,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
mhloReduceOp.getResult(0), outShapeTensor);
return success();
}
rewriter.replaceOp(op, mhloReduceOp.getResults());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
return success();
}
} // namespace
Expand Down
42 changes: 20 additions & 22 deletions lib/Conversion/TorchToMhlo/ViewLike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,18 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {

auto loc = op.getLoc();
auto newRank = dimSizes.size();
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType());

if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
SmallVector<int64_t, 1> newShape(newRank, 1);
Value output = rewriter.create<mhlo::ReshapeOp>(
loc,
RankedTensorType::get(
newShape,
outTy.template dyn_cast<RankedTensorType>().getElementType()),
adaptor.self());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, output);
return success();
}

Expand All @@ -250,28 +256,19 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
Value numel = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
for (auto d : dimSizes) {
numel = rewriter.create<arith::MulIOp>(loc, numel, d);
auto rank = rankType.getRank();
for (size_t d = 0; d < rank; ++d) {
Value dimSize = rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, adaptor.self(), d));
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
}
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
numel);

if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
loc, mhloShape.getType(), numel, mhloShape);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self(), computedShape);
op, outTy, adaptor.self(), computedShape);
return success();
}

Expand Down Expand Up @@ -360,7 +357,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "the size of the dimension being squeezed is can't be unknown");

rewriter.replaceOp(op, adaptor.self());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self());
return success();
}

Expand Down Expand Up @@ -402,8 +400,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor");

rewriter.replaceOp(op, *unsqzTensorInfo);
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), *unsqzTensorInfo);
return success();
}
} // namespace
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dtype.dyn_cast<mlir::FloatType>()) {
return dtype;
} else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
if (integerType.isUnsignedInteger()) {
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Unsigned);
}
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless);
}
Expand Down
26 changes: 18 additions & 8 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,26 +738,29 @@ class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
auto self = op.self();
auto selfTy = self.getType().cast<BaseTensorType>();
// roll(input, shift, dim) = cat({
// slice(input, dim, -shift, none),
// slice(input, dim, 0, -shift)}, dim)
// slice(input, dim, (dimSize-shift)%dimSize, none),
// slice(input, dim, 0, (dimSize-shift)%dimSize}, dim)
auto imitateRoll = [&](Value input, Value shift, Value dim,
int64_t cstDim) {
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
Value shiftPlus = rewriter.create<AtenSubIntOp>(loc, dimSize, shift);
Value splitPos =
rewriter.create<AtenRemainderIntOp>(loc, shiftPlus, dimSize);
ArrayRef<int64_t> inputShape = selfTy.getSizes();
SmallVector<int64_t> sizes;
sizes.append(inputShape.begin(), inputShape.end());
sizes[cstDim] = ShapedType::kDynamicSize;
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
selfTy.getDtype());
Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, negShift, constNone, constOne);
loc, sliceTy, input, dim, splitPos, constNone, constOne);
Value slice1 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, constZero, negShift, constOne);
loc, sliceTy, input, dim, constZero, splitPos, constOne);

Type listType = Torch::ListType::get(sliceTy);
Value slices = rewriter.create<PrimListConstructOp>(
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
return rewriter.create<AtenCatOp>(loc, op.getType(), slices, dim);
};
int rank = getTensorRank(self);
if (rank < 0)
Expand Down Expand Up @@ -1525,6 +1528,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
Value input, Value prob,
Value &output) {
auto inputType = input.getType().cast<BaseTensorType>();
auto inputDtype = inputType.getDtype();
auto probType = prob.getType().cast<BaseTensorType>();
// Both the `input` and `prob` must be ranked tensors.
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
Expand All @@ -1540,8 +1544,14 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,

// Since the `aten.rand_like` op expects float-type operand, create a
// float-type tensor with the same shape as that of the `input`.
Type floatDtype = rewriter.getF64Type();
if (inputDtype.isa<mlir::FloatType>() &&
inputDtype.cast<mlir::FloatType>().getWidth() < 64) {
floatDtype = rewriter.getF32Type();
}

Value floatTensor =
convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
convertTensorToDtype(rewriter, loc, input, floatDtype);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value randomVal = rewriter.create<AtenRandLikeOp>(
loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
Expand All @@ -1555,7 +1565,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,

// As the `output` is expected to be of the `input` type, convert the boolean
// tensor `lessThanP` to a `input` type tensor.
output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
output = convertTensorToDtype(rewriter, loc, lessThanP, inputDtype);
return success();
}

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ static bool isInitialValueTransitivelySafeToInline(Value initialValue,
namespace {
class InlineGlobalSlotsPass
: public InlineGlobalSlotsBase<InlineGlobalSlotsPass> {

void runOnOperation() override {

ModuleOp module = getOperation();
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis<
using BaseT =
dataflow::SparseDataFlowAnalysis<dataflow::Lattice<ValueKnowledge>>;
using BaseT::SparseDataFlowAnalysis;
void setToEntryState(dataflow::Lattice<ValueKnowledge> *lattice) override {}

// Compute the knowledge for the results of an op, based on the knowledge of
// the operands and any information intrinsic to `op`.
Expand Down Expand Up @@ -1108,7 +1109,7 @@ void TypeAnalysis::visitOperation(Operation *op,

// Otherwise, this is an unknown operation. Just mark all results as
// having reached a pessimistic fixpoint.
markAllPessimisticFixpoint(results);
setAllToEntryStates(results);
return;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/TorchUpstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static inline bool isQIntType(ScalarType t) {
// Type promotion related code are copied from
// aten/src/ATen/native/TypeProperties.*.
//===----------------------------------------------------------------------===//
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
ScalarType promoteTypes(ScalarType a, ScalarType b) {
// This is generated according to NumPy's promote_types
constexpr auto u1 = ScalarType::Byte;
constexpr auto i1 = ScalarType::Char;
Expand Down
23 changes: 23 additions & 0 deletions lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
return success();
}

void ToBuiltinTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](ToBuiltinTensorOp op, PatternRewriter &rewriter) {
auto fromBuiltinTensorOp =
op.getOperand().getDefiningOp<FromBuiltinTensorOp>();
if (!fromBuiltinTensorOp)
return rewriter.notifyMatchFailure(op, "operand not FromBuiltinTensorOp");
rewriter.replaceOp(op, fromBuiltinTensorOp.getOperand());
return success();
});
}

//===----------------------------------------------------------------------===//
// FromBuiltinTensorOp
//===----------------------------------------------------------------------===//
Expand All @@ -71,6 +83,17 @@ LogicalResult FromBuiltinTensorOp::verify() {
return success();
}

void FromBuiltinTensorOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](FromBuiltinTensorOp op, PatternRewriter &rewriter) {
auto toBuiltinTensorOp = op.getOperand().getDefiningOp<ToBuiltinTensorOp>();
if (!toBuiltinTensorOp)
return rewriter.notifyMatchFailure(op, "operand not ToBuiltinTensorOp");
rewriter.replaceOp(op, toBuiltinTensorOp.getOperand());
return success();
});
}

//===----------------------------------------------------------------------===//
// FromI64Op
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit f93457b

Please sign in to comment.