Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
BladeDISC related patches
Browse files Browse the repository at this point in the history
* Fix float width
* Fix divide_floor & export promoteTypes api (#9)
* To comply with the old pytorch versions
* Add native_dropout_backward & native_layer_norm_backward decomposition (#15)
* Add native_dropout and related ops pattern (llvm#1211)
* [MHLO] fix dot general contract
* Fix batch_norm, div.Tensor_mode and folder (#21)
* Reimplement linear lowering
* Reimplement 2-D rhs for mutmul
* Add torchdynamo
* Decompose torch.slice_scatter (llvm#1622)
* Fix i64 torch.tensor dtype
* Add more mhlo basic converters
* Alleviate softmax datatype check (#24)
* Fix decompose native_batch_norm (#27)
* Support group_norm lowering (#25)
* Decompose torch.ones/zeros (#28)
* Fix softmax output type
* Fix gather
* Fix some decompose patterns
* Not check assert at runtime (#31)
* Fix bool tensor attr conversion bug (#32)
* Fix mlirDenseElementsAttrBoolGet
  • Loading branch information
Tanyo Kwok authored and JamesTheZ committed Jul 19, 2023
1 parent 0caaf8d commit 38ece66
Show file tree
Hide file tree
Showing 23 changed files with 1,059 additions and 102 deletions.
128 changes: 110 additions & 18 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4929,6 +4929,97 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
}];
}

def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$running_mean,
AnyTorchOptionalTensorType:$running_var,
Torch_BoolType:$use_input_stats,
Torch_FloatType:$momentum,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenInstanceNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}

def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
Torch_IntType:$num_groups,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_FloatType:$eps,
Torch_BoolType:$cudnn_enabled
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
Torch_IntType:$N,
Torch_IntType:$C,
Torch_IntType:$HxW,
Torch_IntType:$group,
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNativeGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 3);
}
void AtenNativeGroupNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 3);
}
}];
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -7233,9 +7324,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
}

def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
Pure,
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
ReadOnly,
]> {
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
Expand Down Expand Up @@ -7742,53 +7834,53 @@ def Torch_AtenMaxOp : Torch_Op<"aten.max", [
}];
}

def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`";
let summary = "Generated op for `aten::amax : (Tensor, int[]?, bool) -> Tensor`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$values,
AnyTorchTensorType:$indices
AnyTorchTensorType:$results
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaxDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
void AtenAmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::amax : (Tensor, int[], bool) -> (Tensor)`";
let summary = "Generated op for `aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim,
Torch_IntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
AnyTorchTensorType:$values,
AnyTorchTensorType:$indices
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
ParseResult AtenMaxDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
}
void AtenAmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
void AtenMaxDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
}
Expand Down
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ enum ReductionType { MAX, MEAN, MIN, SUM, PROD };

ReductionType get_reduction_enum(const llvm::StringRef &reduce);

ScalarType promoteTypes(ScalarType a, ScalarType b);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand All @@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand All @@ -80,6 +82,7 @@ def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}

def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
Expand All @@ -98,6 +101,7 @@ def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
let assemblyFormat = [{
$operand attr-dict
}];
let hasFolder = 1;
}

def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [
Expand Down
9 changes: 8 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,20 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);

target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
// target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
// patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
// context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
AtenRemainderIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);

target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
Expand Down
16 changes: 7 additions & 9 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;
using namespace mlir::torch::TorchConversion;

LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
Expand Down Expand Up @@ -354,8 +355,8 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
auto loc = op.getLoc();
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
Expand Down Expand Up @@ -1547,12 +1548,10 @@ class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool condition;
if (!matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: condition must be a constant");
}
if (!condition) {
return op->emitError("condition must be true");
if (matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) {
if (!condition) {
return op->emitError("condition must be true");
}
}
rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -1679,7 +1678,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);

INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
Expand Down
55 changes: 31 additions & 24 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,50 +73,57 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
}

RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
Value &lhs, Value &rhs,
Value &lhs, Value &rhs, int64_t nBatchDims,
int64_t lhsResultDim, int64_t rhsResultDim,
int64_t lhsContractingDim,
int64_t rhsContractingDim) {
auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();

bool shouldCastLhs = false;
bool shouldCastRhs = false;
auto oldLhsShape = lhsTy.getShape();
auto oldRhsShape = rhsTy.getShape();
SmallVector<int64_t> lhsShape;
SmallVector<int64_t> rhsShape;
SmallVector<int64_t> outShape;

lhsShape.append(oldLhsShape.begin(), oldLhsShape.end());
rhsShape.append(oldRhsShape.begin(), oldRhsShape.end());
// set batch dims
for (auto k = 0; k < nBatchDims; ++k) {
if (lhsShape[k] == ShapedType::kDynamic && rhsShape[k] >= 0) {
lhsShape[k] = rhsShape[k];
shouldCastLhs = true;
}
if (rhsShape[k] == ShapedType::kDynamic && lhsShape[k] >= 0) {
rhsShape[k] = lhsShape[k];
shouldCastRhs = true;
}
outShape.push_back(lhsShape[k]);
}
// set contracting dims
auto lhsContractingDimSize = lhsShape[lhsContractingDim];
auto rhsContractingDimSize = rhsShape[rhsContractingDim];
if (lhsContractingDimSize != rhsContractingDimSize) {
if (lhsContractingDimSize == ShapedType::kDynamic &&
rhsContractingDimSize >= 0) {
lhsShape[lhsContractingDim] = rhsContractingDimSize;
auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType());
lhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, lhs);
shouldCastLhs = true;
} else if (rhsContractingDimSize == ShapedType::kDynamic &&
lhsContractingDimSize >= 0) {
rhsShape[rhsContractingDim] = lhsContractingDimSize;
auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType());
rhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, rhs);
shouldCastRhs = true;
}
}
SmallVector<int64_t> outShape;
// set batch dims, will skip invalid dimensions
for (int64_t k = 0; k < static_cast<int64_t>(lhsShape.size()); ++k) {
if (k == lhsResultDim || k == lhsContractingDim)
continue;
outShape.push_back(lhsShape[k]);
if (shouldCastLhs) {
auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType());
lhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, lhs);
}
for (int64_t k = 0, b = 0; k < static_cast<int64_t>(rhsShape.size()); ++k) {
if (b >= static_cast<int64_t>(outShape.size()))
break;
if (k == rhsResultDim || k == rhsContractingDim)
continue;
if (outShape[b] == ShapedType::kDynamic && rhsShape[k] >= 0) {
outShape[b] = rhsShape[k];
}
b++;

if (shouldCastRhs) {
auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType());
rhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, rhs);
}

// set result dimensions
Expand Down Expand Up @@ -379,10 +386,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank();

if (lhsRank != 2 && lhsRank != 3)
return op.emitError("aten.Linear called but input rank not 2 or 3");
if (rhsRank != 2 && rhsRank != 3)
return op.emitError("aten.Linear called but weight rank not 2 or 3");
if (lhsRank < 1)
return op.emitError("aten.Linear called but input rank 0");
if (rhsRank != 2)
return op.emitError("aten.Linear called but weight rank not 2");

return success();
}
Expand Down
Loading

0 comments on commit 38ece66

Please sign in to comment.