Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
aicompiler rebase 20220830
Browse files Browse the repository at this point in the history
fix float width

fix divide_floor & export promoteTypes api (#9)
  • Loading branch information
Tanyo Kwok authored and bladedisc committed Sep 23, 2022
1 parent b0b2b3a commit b005b02
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 46 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };

ScalarType promoteTypes(ScalarType a, ScalarType b);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand All @@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand Down
6 changes: 5 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,17 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);

target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
Expand Down
17 changes: 12 additions & 5 deletions lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo;

bool skipMultiplyAlpha(Value alphaValue) {
Expand Down Expand Up @@ -723,7 +724,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, outType, lhs, zeroTensor);
return success();
}

Expand All @@ -749,7 +754,11 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
auto outType = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<TensorType>();

rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, input, halfMul);
return success();
}

Expand Down Expand Up @@ -1203,7 +1212,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
#undef INSERT_UNARY_FPONLY_PATTERN

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
Expand Down Expand Up @@ -1259,12 +1268,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);

INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenErfOp);

INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
Expand Down
11 changes: 9 additions & 2 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();

llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

Expand Down Expand Up @@ -438,6 +439,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

Expand All @@ -458,7 +460,9 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
}

rewriter.replaceOp(op, mhloReduceOp.getResults());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
return success();
}
} // namespace
Expand Down Expand Up @@ -522,6 +526,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();

llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

Expand Down Expand Up @@ -566,7 +571,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
mhloReduceOp.getResult(0), outShapeTensor);
return success();
}
rewriter.replaceOp(op, mhloReduceOp.getResults());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
return success();
}
} // namespace
Expand Down
42 changes: 20 additions & 22 deletions lib/Conversion/TorchToMhlo/ViewLike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,18 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {

auto loc = op.getLoc();
auto newRank = dimSizes.size();
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType());

if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
SmallVector<int64_t, 1> newShape(newRank, 1);
Value output = rewriter.create<mhlo::ReshapeOp>(
loc,
RankedTensorType::get(
newShape,
outTy.template dyn_cast<RankedTensorType>().getElementType()),
adaptor.self());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, output);
return success();
}

Expand All @@ -207,28 +213,19 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {

Value numel = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
for (auto d : dimSizes) {
numel = rewriter.create<arith::MulIOp>(loc, numel, d);
auto rank = rankType.getRank();
for (size_t d = 0; d < rank; ++d) {
Value dimSize = rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, adaptor.self(), d));
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
}
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
numel);

if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
loc, mhloShape.getType(), numel, mhloShape);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self(), computedShape);
op, outTy, adaptor.self(), computedShape);
return success();
}

Expand Down Expand Up @@ -357,7 +354,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "the size of the dimension being squeezed is can't be unknown");

rewriter.replaceOp(op, adaptor.self());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self());
return success();
}

Expand Down Expand Up @@ -400,8 +398,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor");

rewriter.replaceOp(op, *unsqzTensorInfo);
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), *unsqzTensorInfo);
return success();
}

Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dtype.dyn_cast<mlir::FloatType>()) {
return dtype;
} else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
if (integerType.isUnsignedInteger()) {
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Unsigned);
}
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless);
}
Expand Down
Loading

0 comments on commit b005b02

Please sign in to comment.