[Torch Dialect] Support aten.cosine_similarity (#2364)
As the title says: add support for aten.cosine_similarity by decomposing it into elementwise multiply, sum, vector norm, clamp, and divide, broadcasting inputA/inputB to a common shape first.
JianzheXiao committed Nov 8, 2023
1 parent 026cb31 commit a42d4c1
Showing 12 changed files with 298 additions and 83 deletions.
7 changes: 0 additions & 7 deletions include/torch-mlir/Conversion/Utils/Utils.h
@@ -94,13 +94,6 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize);

// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void computeBroadcastShape(ConversionPatternRewriter &rewriter, Location loc,
Value inputA, Value inputB,
SmallVector<int64_t> &resultShape,
SmallVector<Value> &resultShapeValue);

} // namespace Torch
} // namespace torch
} // namespace mlir
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5123,6 +5123,32 @@ def Torch_AtenMvOp : Torch_Op<"aten.mv", [
}];
}

def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$x1,
AnyTorchTensorType:$x2,
Torch_IntType:$dim,
Torch_FloatType:$eps
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCosineSimilarityOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenCosineSimilarityOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
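For context, GeneratedTorchOps.td is regenerated rather than written by hand; a registration along these lines in the ODS generator would produce the definition above (hedged sketch — `emit` below is a stand-in printer, not the real torch_ods_gen.py helper):

# Hedged sketch of how the op signature is fed to the TableGen generator.
def emit(signature: str) -> None:
    print(f"would generate a Torch dialect op for: {signature}")

emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)")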

def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
AllowsTypeRefinement,
HasValueSemantics,
8 changes: 8 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -54,6 +54,14 @@ Type getBuiltInTypeForTorchScalar(Type type);

Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype);

// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void computeBroadcastShape(PatternRewriter &rewriter, Location loc,
Value inputA, Value inputB,
SmallVector<int64_t> &resultShape,
SmallVector<Value> &resultShapeValue);

// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
Type dtype);
76 changes: 0 additions & 76 deletions lib/Conversion/Utils/Utils.cpp
@@ -355,82 +355,6 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
return castIntToIndex(rewriter, loc, boundedByDimSize);
}

// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void computeBroadcastShape(ConversionPatternRewriter &rewriter, Location loc,
Value inputA, Value inputB,
SmallVector<int64_t> &resultShape,
SmallVector<Value> &resultShapeValue) {
SmallVector<int64_t> shapeA{
inputA.getType().cast<BaseTensorType>().getSizes()};
SmallVector<int64_t> shapeB{
inputB.getType().cast<BaseTensorType>().getSizes()};
unsigned rankA = shapeA.size();
unsigned rankB = shapeB.size();
unsigned minRank = rankA > rankB ? rankB : rankA;
// Check whether the shapes of the tensors are broadcastable or not.
// Two tensors are “broadcastable” if the following rules hold:
// 1.) Each tensor has at least one dimension.
// 2.) When iterating over the dimension sizes, starting at the trailing
// dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist.
for (unsigned i = 0; i < minRank; i++) {
Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankA - i - 1));
Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankB - i - 1));
Value sizeInputA =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
Value sizeInputB =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value cmpSizeAEqualsSizeB =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, sizeInputB);
Value cmpSizeAEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, torchCstOne);
Value cmpSizeBEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputB, torchCstOne);
Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
cmpSizeBEqualsOne});
Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
rewriter.create<Torch::RuntimeAssertOp>(
loc, cmp, "tensors are not broadcast compatible");
}
// If we reach here then it means both the shapes are broadcast compatible.
resultShape = rankA >= rankB ? shapeA : shapeB;
Value shapeTensor = rankA >= rankB ? inputA : inputB;
for (unsigned i = 0; i < resultShape.size(); i++) {
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
resultShapeValue.push_back(
rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
}

unsigned resultRank = resultShape.size();
for (unsigned i = 0; i < minRank; i++) {
Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankA - i - 1));
Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankB - i - 1));
Value sizeInputA =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
Value sizeInputB =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
resultShapeValue[resultRank - i - 1] =
rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
if (shapeA[rankA - i - 1] == kUnknownSize ||
shapeB[rankB - i - 1] == kUnknownSize) {
resultShape[resultRank - i - 1] = kUnknownSize;
} else {
resultShape[resultRank - i - 1] =
std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);
}
}
}

} // namespace Torch
} // namespace torch
} // namespace mlir
38 changes: 38 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6290,6 +6290,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %1 = torch.aten.slice.t %0, %none, %arg2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %2 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.slice.t %0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8542,6 +8552,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
" %int7 = torch.constant.int 7\n"
" %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.ListConstruct %int15, %int5, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ceil\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
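For reference, the two library entries added above are generated from Python; hedged source-level equivalents (reconstructed from the generated IR, not copied from abstract_interp_lib_gen.py) look roughly like this:

from typing import List, Tuple

# Shape rule: broadcast the two input shapes, then drop the reduced dimension
# `dim` (mirrors the slice.t/add.t sequence in the IR). The real library
# delegates broadcasting to torch.jit._shape_functions.broadcast; max() here
# is a simplification that skips compatibility checks.
def cosine_similarity_shape(x1: List[int], x2: List[int], dim: int, eps: float) -> List[int]:
    broadcast: List[int] = []
    for i in range(max(len(x1), len(x2))):
        a = x1[-1 - i] if i < len(x1) else 1
        b = x2[-1 - i] if i < len(x2) else 1
        broadcast.insert(0, max(a, b))
    return broadcast[:dim] + broadcast[dim + 1:]

# Dtype rule: both operands must share a floating-point dtype. 5, 6, 7, and 15
# are the torch scalar-type codes for half, float, double, and bfloat16.
HALF, FLOAT, DOUBLE, BFLOAT16 = 5, 6, 7, 15

def cosine_similarity_dtype(x1_rank_dtype: Tuple[int, int],
                            x2_rank_dtype: Tuple[int, int],
                            dim: int, eps: float) -> int:
    _, x1_dtype = x1_rank_dtype
    _, x2_dtype = x2_rank_dtype
    assert x1_dtype == x2_dtype, "operands must have the same dtype"
    assert x1_dtype in (BFLOAT16, HALF, FLOAT, DOUBLE), "dtype must be floating point"
    return x1_dtype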
64 changes: 64 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3920,6 +3920,69 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
};
} // namespace

namespace {
class DecomposeAtenCosineSimilarityOp
: public OpRewritePattern<AtenCosineSimilarityOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCosineSimilarityOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value x1 = op.getX1();
Value x2 = op.getX2();
Value dim = op.getDim();

// Broadcast x1 and x2 to the same shape
SmallVector<int64_t> indexBroadcastShapeInt;
SmallVector<Value> indexBroadcastShapeValue;
computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
indexBroadcastShapeValue);
Type dtype = x1.getType().cast<BaseTensorType>().getOptionalDtype();
Type broadcastType = ValueTensorType::get(
op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype);
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
indexBroadcastShapeValue);
x1 = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, x1,
indexBroadcastShapeTorchList);
x2 = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, x2,
indexBroadcastShapeTorchList);

// Compute the elementwise product of x1 and x2, then sum it over `dim`.
Value dotProduct =
rewriter.create<AtenMulTensorOp>(loc, broadcastType, x1, x2);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
ValueRange{dim});
Value sumDotProduct = rewriter.create<Torch::AtenSumDimIntListOp>(
loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList,
/*keepdim=*/cstFalse,
/*dtype=*/cstNone);

// Compute the L2 norms of x1 and x2 along `dim`.
Value ord = rewriter.create<Torch::ConstantFloatOp>(loc,
rewriter.getF64FloatAttr(2.0));
Value normA = rewriter.create<AtenLinalgVectorNormOp>(
loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse,
/*dtype=*/cstNone);
Value normB = rewriter.create<AtenLinalgVectorNormOp>(
loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse,
/*dtype=*/cstNone);

// Compute the product of the norms
Value normProduct =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), normA, normB);
Value normProductClamp = rewriter.create<AtenClampOp>(
loc, op.getType(), normProduct, op.getEps(), /*max=*/cstNone);
// Compute the final cosine similarity by division
rewriter.replaceOpWithNewOp<AtenDivTensorOp>(
op, op.getType(), sumDotProduct, normProductClamp);
return success();
}
};
} // namespace
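At the PyTorch level, the decomposition above corresponds roughly to the following Python sketch (hedged; it uses public torch APIs to mirror the pattern and is not the reference implementation of the op):

import torch

# Hedged Python analogue of DecomposeAtenCosineSimilarityOp: broadcast the
# inputs, sum the elementwise product over `dim`, and divide by the product
# of the L2 norms clamped from below by `eps`.
def cosine_similarity_decomposed(x1: torch.Tensor, x2: torch.Tensor,
                                 dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
    x1, x2 = torch.broadcast_tensors(x1, x2)
    dot = (x1 * x2).sum(dim)
    norms = (torch.linalg.vector_norm(x1, ord=2, dim=dim) *
             torch.linalg.vector_norm(x2, ord=2, dim=dim))
    return dot / norms.clamp(min=eps)

# Sanity check against the builtin (values should agree closely):
# a, b = torch.randn(3, 4), torch.randn(4)
# torch.testing.assert_close(cosine_similarity_decomposed(a, b),
#                            torch.nn.functional.cosine_similarity(a, b, dim=1))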

namespace {
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
// `aten.add.Tensor` op.
@@ -5535,6 +5598,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -484,6 +484,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRandnGeneratorOp>();
target.addIllegalOp<AtenRandnLikeOp>();
target.addIllegalOp<AtenVarMeanOp>();
target.addIllegalOp<AtenCosineSimilarityOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>();
target.addIllegalOp<AtenEmptyStridedOp>();
target.addIllegalOp<AtenBucketizeTensorOp>();
76 changes: 76 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -325,6 +325,82 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
return unsqueezed;
}

// Checks whether `inputA` and `inputB` are broadcast compatible. If yes,
// computes the final broadcast shape.
void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
Value inputA, Value inputB,
SmallVector<int64_t> &resultShape,
SmallVector<Value> &resultShapeValue) {
SmallVector<int64_t> shapeA{
inputA.getType().cast<BaseTensorType>().getSizes()};
SmallVector<int64_t> shapeB{
inputB.getType().cast<BaseTensorType>().getSizes()};
unsigned rankA = shapeA.size();
unsigned rankB = shapeB.size();
unsigned minRank = rankA > rankB ? rankB : rankA;
// Check whether the shapes of the tensors are broadcastable or not.
// Two tensors are “broadcastable” if the following rules hold:
// 1.) Each tensor has at least one dimension.
// 2.) When iterating over the dimension sizes, starting at the trailing
// dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist.
for (unsigned i = 0; i < minRank; i++) {
Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankA - i - 1));
Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankB - i - 1));
Value sizeInputA =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
Value sizeInputB =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value cmpSizeAEqualsSizeB =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, sizeInputB);
Value cmpSizeAEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, torchCstOne);
Value cmpSizeBEqualsOne =
rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputB, torchCstOne);
Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
cmpSizeBEqualsOne});
Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
rewriter.create<Torch::RuntimeAssertOp>(
loc, cmp, "tensors are not broadcast compatible");
}
// If we reach here then it means both the shapes are broadcast compatible.
resultShape = rankA >= rankB ? shapeA : shapeB;
Value shapeTensor = rankA >= rankB ? inputA : inputB;
for (unsigned i = 0; i < resultShape.size(); i++) {
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
resultShapeValue.push_back(
rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
}

unsigned resultRank = resultShape.size();
for (unsigned i = 0; i < minRank; i++) {
Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankA - i - 1));
Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(rankB - i - 1));
Value sizeInputA =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
Value sizeInputB =
rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
resultShapeValue[resultRank - i - 1] =
rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
if (shapeA[rankA - i - 1] == kUnknownSize ||
shapeB[rankB - i - 1] == kUnknownSize) {
resultShape[resultRank - i - 1] = kUnknownSize;
} else {
resultShape[resultRank - i - 1] =
std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);
}
}
}
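The broadcasting rule spelled out in the comments above can be stated compactly for static shapes (hedged Python sketch; -1 stands in for kUnknownSize):

# Hedged static counterpart of computeBroadcastShape. For each trailing
# dimension the sizes must be equal, be 1, or be absent; unknown sizes (-1)
# propagate to the result.
def compute_broadcast_shape(shape_a, shape_b):
    result = list(shape_a) if len(shape_a) >= len(shape_b) else list(shape_b)
    for i in range(1, min(len(shape_a), len(shape_b)) + 1):
        a, b = shape_a[-i], shape_b[-i]
        if -1 not in (a, b) and a != b and a != 1 and b != 1:
            raise ValueError("tensors are not broadcast compatible")
        result[-i] = -1 if -1 in (a, b) else max(a, b)
    return result

# e.g. compute_broadcast_shape([3, 1, 5], [4, 5]) == [3, 4, 5]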

bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
for (Operation *parentOp = block->getParentOp(); parentOp;
parentOp = parentOp->getParentOp()) {
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -454,6 +454,8 @@
"BucketizeTensorStaticModule_basic",
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"CosineSimilarityStaticModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"DetachModule_basic",
"ElementwiseIsnanModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
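The two test names added above refer to e2e modules defined in the Python test suite; a minimal sketch of what such a module typically looks like (the harness's register_test_case/annotate_args decorators are assumed and omitted here):

import torch

# Hedged sketch of an e2e test module in the style of the torch-mlir suite.
class CosineSimilarityStaticModule(torch.nn.Module):
    def forward(self, x1, x2):
        return torch.nn.functional.cosine_similarity(x1, x2, dim=1)

# The harness would invoke it roughly as:
#   CosineSimilarityStaticModule()(torch.rand(2, 3), torch.rand(2, 3))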
