This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit f5a8a93
aicompiler rebase 20220830
fix float width

fix divide_floor & export promoteTypes api (#9)
Tanyo Kwok committed Sep 1, 2022
1 parent a924de3 commit f5a8a93
Showing 14 changed files with 135 additions and 45 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -160,6 +160,7 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };

ScalarType promoteTypes(ScalarType a, ScalarType b);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir
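With promoteTypes now declared in the public header, downstream conversion code can query PyTorch's dtype-promotion rules directly. A minimal sketch of a hypothetical caller (not part of this commit), assuming torch_upstream::ScalarType mirrors c10::ScalarType (Float, Double, ...):

// Hypothetical caller, illustrative only: query the c10 promotion lattice
// to pick a result element type for a binary op.
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"

using namespace mlir::torch::torch_upstream;

ScalarType promotedDtype() {
  // Float (f32) combined with Double (f64) promotes to Double in c10.
  return promoteTypes(ScalarType::Float, ScalarType::Double);
}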
include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td
@@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

@@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
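Setting hasCanonicalizer = 1 on both ops only declares a getCanonicalizationPatterns hook; the patterns it registers live in the C++ sources and are not shown in this hunk. A plausible sketch of the kind of pattern such a hook could register, folding a from_builtin_tensor of a to_builtin_tensor back to the original value (names and logic here are illustrative assumptions, not the commit's actual implementation):

// Hypothetical canonicalization pattern, illustrative only.
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch::TorchConversion;

struct FoldFromToBuiltinTensor : public OpRewritePattern<FromBuiltinTensorOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(FromBuiltinTensorOp op,
                                PatternRewriter &rewriter) const override {
    // Match from_builtin_tensor(to_builtin_tensor(x)) where the torch tensor
    // type round-trips unchanged, and replace the pair with x.
    auto toOp = op.operand().getDefiningOp<ToBuiltinTensorOp>();
    if (!toOp || toOp.operand().getType() != op.getType())
      return failure();
    rewriter.replaceOp(op, toOp.operand());
    return success();
  }
};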

6 changes: 5 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -364,13 +364,17 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);

target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
24 changes: 14 additions & 10 deletions lib/Conversion/TorchToMhlo/Basic.cpp
@@ -29,6 +29,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo;

bool skipMultiplyAlpha(Value alphaValue) {
@@ -409,10 +410,8 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
Value dValue = shape[i];
Value newD;
int64_t dInt;
if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
return op->emitError("element of desired shape must be a scalar");
}
if (i >= leadingRank && dInt == -1) {
if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) &&
dInt == -1) {
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
i - leadingRank);
} else {
@@ -602,7 +601,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, outType, lhs, zeroTensor);
return success();
}

@@ -628,7 +631,11 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, input, halfMul);
return success();
}

@@ -984,7 +991,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
#undef INSERT_UNARY_FPONLY_PATTERN

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@@ -1040,14 +1047,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);

INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenErfOp);

INSERT_ATENOP_PATTERN(AtenCatOp);

INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenNumelOp);
33 changes: 33 additions & 0 deletions lib/Conversion/TorchToMhlo/Linear.cpp
@@ -71,6 +71,35 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
return result.getResult();
}

void castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs,
Value &rhs, int64_t lhsContractingDim,
int64_t rhsContractingDim) {
auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();

auto oldLhsShape = lhsTy.getShape();
auto oldRhsShape = rhsTy.getShape();
SmallVector<int64_t> lhsShape;
SmallVector<int64_t> rhsShape;
lhsShape.append(oldLhsShape.begin(), oldLhsShape.end());
rhsShape.append(oldRhsShape.begin(), oldRhsShape.end());
auto lhsContractingDimSize = lhsShape[lhsContractingDim];
auto rhsContractingDimSize = rhsShape[rhsContractingDim];
if (lhsContractingDimSize != rhsContractingDimSize) {
if (lhsContractingDimSize == ShapedType::kDynamicSize &&
rhsContractingDimSize >= 0) {
lhsShape[lhsContractingDim] = rhsContractingDimSize;
auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType());
lhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, lhs);
} else if (rhsContractingDimSize == ShapedType::kDynamicSize &&
lhsContractingDimSize >= 0) {
rhsShape[rhsContractingDim] = lhsContractingDimSize;
auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType());
rhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, rhs);
}
}
}

void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
Value &inpRhs, int64_t leadingRank,
size_t dimSizeIndexBits) {
@@ -199,6 +228,8 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
->convertType(op.getType())
.template cast<RankedTensorType>();

castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim,
rhsContractingDim);
output = rewriter
.create<mhlo::DotGeneralOp>(op->getLoc(), resultTy, lhs, rhs,
dotDimensionNumbers, nullptr)
@@ -358,6 +389,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
auto lhsContractingDim = nBatchDims + 1;
auto rhsContractingDim = nBatchDims;

castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim,
rhsContractingDim);
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
11 changes: 9 additions & 2 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -381,6 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();

llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

@@ -438,6 +439,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

@@ -458,7 +460,9 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
}

rewriter.replaceOp(op, mhloReduceOp.getResults());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
return success();
}
} // namespace
@@ -522,6 +526,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();

llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

@@ -566,7 +571,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
mhloReduceOp.getResult(0), outShapeTensor);
return success();
}
rewriter.replaceOp(op, mhloReduceOp.getResults());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
return success();
}
} // namespace
42 changes: 20 additions & 22 deletions lib/Conversion/TorchToMhlo/ViewLike.cpp
@@ -178,12 +178,18 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {

auto loc = op.getLoc();
auto newRank = dimSizes.size();
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType());

if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
SmallVector<int64_t, 1> newShape(newRank, 1);
Value output = rewriter.create<mhlo::ReshapeOp>(
loc,
RankedTensorType::get(
newShape,
outTy.template dyn_cast<RankedTensorType>().getElementType()),
adaptor.self());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, output);
return success();
}

@@ -207,28 +213,19 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {

Value numel = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
for (auto d : dimSizes) {
numel = rewriter.create<arith::MulIOp>(loc, numel, d);
auto rank = rankType.getRank();
for (size_t d = 0; d < rank; ++d) {
Value dimSize = rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, adaptor.self(), d));
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
}
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
numel);

if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
loc, mhloShape.getType(), numel, mhloShape);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self(), computedShape);
op, outTy, adaptor.self(), computedShape);
return success();
}
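The rewritten loop derives numel from the input tensor's dimensions (via tensor.dim) rather than from dimSizes, so a -1 placeholder in the requested view shape no longer corrupts the element count that mhlo.compute_reshape_shape uses to resolve it. A standalone worked example of that computation, outside MLIR (illustrative only):

// Worked example: viewing a 2x6 tensor as [-1, 4]. The element count comes
// from the input dims, mirroring the tensor.dim loop above.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> inputDims{2, 6};
  std::vector<int64_t> targetShape{-1, 4};
  int64_t numel = 1;
  for (int64_t d : inputDims) numel *= d;        // 12
  int64_t knownProduct = 1;
  for (int64_t d : targetShape)
    if (d != -1) knownProduct *= d;              // 4
  // mhlo::ComputeReshapeShapeOp resolves the -1 the same way: 12 / 4 = 3.
  std::printf("resolved shape: %lld x 4\n",
              static_cast<long long>(numel / knownProduct));
  return 0;
}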

Expand Down Expand Up @@ -357,7 +354,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "the size of the dimension being squeezed is can't be unknown");

rewriter.replaceOp(op, adaptor.self());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self());
return success();
}

@@ -400,8 +398,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor");

rewriter.replaceOp(op, *unsqzTensorInfo);
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), *unsqzTensorInfo);
return success();
}

4 changes: 4 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -356,6 +356,10 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dtype.dyn_cast<mlir::FloatType>()) {
return dtype;
} else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
if (integerType.isUnsignedInteger()) {
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Unsigned);
}
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless);
}
26 changes: 18 additions & 8 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -738,26 +738,29 @@ class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
auto self = op.self();
auto selfTy = self.getType().cast<BaseTensorType>();
// roll(input, shift, dim) = cat({
// slice(input, dim, -shift, none),
// slice(input, dim, 0, -shift)}, dim)
// slice(input, dim, (dimSize-shift)%dimSize, none),
// slice(input, dim, 0, (dimSize-shift)%dimSize}, dim)
auto imitateRoll = [&](Value input, Value shift, Value dim,
int64_t cstDim) {
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
Value shiftPlus = rewriter.create<AtenSubIntOp>(loc, dimSize, shift);
Value splitPos =
rewriter.create<AtenRemainderIntOp>(loc, shiftPlus, dimSize);
ArrayRef<int64_t> inputShape = selfTy.getSizes();
SmallVector<int64_t> sizes;
sizes.append(inputShape.begin(), inputShape.end());
sizes[cstDim] = ShapedType::kDynamicSize;
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
selfTy.getDtype());
Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, negShift, constNone, constOne);
loc, sliceTy, input, dim, splitPos, constNone, constOne);
Value slice1 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, constZero, negShift, constOne);
loc, sliceTy, input, dim, constZero, splitPos, constOne);

Type listType = Torch::ListType::get(sliceTy);
Value slices = rewriter.create<PrimListConstructOp>(
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
return rewriter.create<AtenCatOp>(loc, op.getType(), slices, dim);
};
int rank = getTensorRank(self);
if (rank < 0)
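The updated comment and code replace the old negative-shift slicing with an explicit split position (dimSize - shift) % dimSize, computed via aten.sub.int and aten.remainder.int. A self-contained illustration of why that split reproduces roll, using plain C++ on a size-5 dimension (not part of the torch-mlir sources):

// Standalone illustration of the (dimSize - shift) % dimSize split used above.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> v{0, 1, 2, 3, 4};          // one dimension of size 5
  int dimSize = static_cast<int>(v.size());
  int shift = 2;
  int splitPos = (dimSize - shift) % dimSize; // 3 here
  // cat(slice[splitPos:], slice[:splitPos]) is the same as roll(v, 2).
  std::rotate(v.begin(), v.begin() + splitPos, v.end());
  for (int x : v) std::printf("%d ", x);      // prints: 3 4 0 1 2
  return 0;
}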
@@ -1525,6 +1528,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
Value input, Value prob,
Value &output) {
auto inputType = input.getType().cast<BaseTensorType>();
auto inputDtype = inputType.getDtype();
auto probType = prob.getType().cast<BaseTensorType>();
// Both the `input` and `prob` must be ranked tensors.
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
@@ -1540,8 +1544,14 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,

// Since the `aten.rand_like` op expects float-type operand, create a
// float-type tensor with the same shape as that of the `input`.
Type floatDtype = rewriter.getF64Type();
if (inputDtype.isa<mlir::FloatType>() &&
inputDtype.cast<mlir::FloatType>().getWidth() < 64) {
floatDtype = rewriter.getF32Type();
}

Value floatTensor =
convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
convertTensorToDtype(rewriter, loc, input, floatDtype);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value randomVal = rewriter.create<AtenRandLikeOp>(
loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
Expand All @@ -1555,7 +1565,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,

// As the `output` is expected to be of the `input` type, convert the boolean
// tensor `lessThanP` to a `input` type tensor.
output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
output = convertTensorToDtype(rewriter, loc, lessThanP, inputDtype);
return success();
}
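This dtype selection appears to correspond to the "fix float width" line in the commit message: inputs whose float width is below 64 bits now get their random comparison tensor in f32 instead of always widening to f64. A minimal restatement as a standalone helper (the name chooseRandLikeDtype is hypothetical, for illustration only):

// Sketch of the dtype-width choice above as a standalone helper.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::Type chooseRandLikeDtype(mlir::OpBuilder &b, mlir::Type inputDtype) {
  if (auto floatTy = inputDtype.dyn_cast<mlir::FloatType>())
    if (floatTy.getWidth() < 64)
      return b.getF32Type(); // e.g. f16 or f32 input -> compute the mask in f32
  return b.getF64Type();     // f64 (or non-float) input -> keep the f64 path
}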

1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
@@ -294,6 +294,7 @@ static bool isInitialValueTransitivelySafeToInline(Value initialValue,
namespace {
class InlineGlobalSlotsPass
: public InlineGlobalSlotsBase<InlineGlobalSlotsPass> {

void runOnOperation() override {

ModuleOp module = getOperation();
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -377,6 +377,7 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis<
using BaseT =
dataflow::SparseDataFlowAnalysis<dataflow::Lattice<ValueKnowledge>>;
using BaseT::SparseDataFlowAnalysis;
void setToEntryState(dataflow::Lattice<ValueKnowledge> *lattice) override {}

// Compute the knowledge for the results of an op, based on the knowledge of
// the operands and any information intrinsic to `op`.
@@ -1108,7 +1109,7 @@ void TypeAnalysis::visitOperation(Operation *op,

// Otherwise, this is an unknown operation. Just mark all results as
// having reached a pessimistic fixpoint.
markAllPessimisticFixpoint(results);
setAllToEntryStates(results);
return;
}

Expand Down
Loading

0 comments on commit f5a8a93

Please sign in to comment.