diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h
index d561c2101173..516954b88fbc 100644
--- a/include/torch-mlir/Conversion/Utils/Utils.h
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -94,13 +94,6 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
                          Value torchOptionalInt, Value builtinInt,
                          Value defaultValue, Value dimSize);
 
-// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
-// yes, then computes the final broadcast shape.
-void computeBroadcastShape(ConversionPatternRewriter &rewriter, Location loc,
-                           Value inputA, Value inputB,
-                           SmallVector<int64_t> &resultShape,
-                           SmallVector<Value> &resultShapeValue);
-
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c0ac6a9419e8..a51a3c4d9710 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5123,6 +5123,32 @@ def Torch_AtenMvOp : Torch_Op<"aten.mv", [
   }];
 }
 
+def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$x1,
+    AnyTorchTensorType:$x2,
+    Torch_IntType:$dim,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCosineSimilarityOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenCosineSimilarityOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 3fa871b3370a..14622f654139 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -54,6 +54,14 @@ Type getBuiltInTypeForTorchScalar(Type type);
 
 Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                               Type dtype);
+
+// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
+// yes, then computes the final broadcast shape.
+void computeBroadcastShape(PatternRewriter &rewriter, Location loc,
+                           Value inputA, Value inputB,
+                           SmallVector<int64_t> &resultShape,
+                           SmallVector<Value> &resultShapeValue);
+
 // Helper to convert a tensor to a specific scalar type.
 Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
                            Type dtype);
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 89b17a50be99..3df9da94b735 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -355,82 +355,6 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
   return castIntToIndex(rewriter, loc, boundedByDimSize);
 }
 
-// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
-// yes, then computes the final broadcast shape.
-void computeBroadcastShape(ConversionPatternRewriter &rewriter, Location loc,
-                           Value inputA, Value inputB,
-                           SmallVector<int64_t> &resultShape,
-                           SmallVector<Value> &resultShapeValue) {
-  SmallVector<int64_t> shapeA{
-      inputA.getType().cast<BaseTensorType>().getSizes()};
-  SmallVector<int64_t> shapeB{
-      inputB.getType().cast<BaseTensorType>().getSizes()};
-  unsigned rankA = shapeA.size();
-  unsigned rankB = shapeB.size();
-  unsigned minRank = rankA > rankB ? rankB : rankA;
-  // Check whether the shapes of the tensors are broadcastable or not.
-  // Two tensors are “broadcastable” if the following rules hold:
-  // 1.) Each tensor has at least one dimension.
-  // 2.) When iterating over the dimension sizes, starting at the trailing
-  // dimension, the dimension sizes must either be equal, one of them is 1, or
-  // one of them does not exist.
-  for (unsigned i = 0; i < minRank; i++) {
-    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
-    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
-    Value sizeInputA =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
-    Value sizeInputB =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
-    Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(1));
-    Value cmpSizeAEqualsSizeB =
-        rewriter.create<AtenEqIntOp>(loc, sizeInputA, sizeInputB);
-    Value cmpSizeAEqualsOne =
-        rewriter.create<AtenEqIntOp>(loc, sizeInputA, torchCstOne);
-    Value cmpSizeBEqualsOne =
-        rewriter.create<AtenEqIntOp>(loc, sizeInputB, torchCstOne);
-    Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
-        SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
-                           cmpSizeBEqualsOne});
-    Value cmp = rewriter.create<AtenAnyBoolOp>(loc, anyBoolOpList);
-    rewriter.create<RuntimeAssertOp>(
-        loc, cmp, "tensors are not broadcast compatible");
-  }
-  // If we reach here then it means both the shapes are broadcast compatible.
-  resultShape = rankA >= rankB ? shapeA : shapeB;
-  Value shapeTensor = rankA >= rankB ? inputA : inputB;
-  for (unsigned i = 0; i < resultShape.size(); i++) {
-    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(i));
-    resultShapeValue.push_back(
-        rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
-  }
-
-  unsigned resultRank = resultShape.size();
-  for (unsigned i = 0; i < minRank; i++) {
-    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
-    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
-    Value sizeInputA =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
-    Value sizeInputB =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
-    resultShapeValue[resultRank - i - 1] =
-        rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
-    if (shapeA[rankA - i - 1] == kUnknownSize ||
-        shapeB[rankB - i - 1] == kUnknownSize) {
-      resultShape[resultRank - i - 1] = kUnknownSize;
-    } else {
-      resultShape[resultRank - i - 1] =
-          std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);
-    }
-  }
-}
-
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 7ffa776e2d41..6f578d1ad5f3 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6290,6 +6290,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    %1 = torch.aten.slice.t %0, %none, %arg2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
+"    %2 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.slice.t %0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
+"    %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.hardtanh\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -8542,6 +8552,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
+"    %int7 = torch.constant.int 7\n"
+"    %int6 = torch.constant.int 6\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int15 = torch.constant.int 15\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.prim.ListConstruct %int15, %int5, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
+"    %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.ceil\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e315bab00e68..d3e7a47c4aa5 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3920,6 +3920,69 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenCosineSimilarityOp
+    : public OpRewritePattern<AtenCosineSimilarityOp> {
+  using OpRewritePattern<AtenCosineSimilarityOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenCosineSimilarityOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value x1 = op.getX1();
+    Value x2 = op.getX2();
+    Value dim = op.getDim();
+
+    // Broadcast x1 and x2 to the same shape
+    SmallVector<int64_t> indexBroadcastShapeInt;
+    SmallVector<Value> indexBroadcastShapeValue;
+    computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
+                          indexBroadcastShapeValue);
+    Type dtype = x1.getType().cast<BaseTensorType>().getOptionalDtype();
+    Type broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype);
+    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        indexBroadcastShapeValue);
+    x1 = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, x1,
+                                            indexBroadcastShapeTorchList);
+    x2 = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, x2,
+                                            indexBroadcastShapeTorchList);
+
+    // Compute the mul of A and B
+    Value dotProduct =
+        rewriter.create<AtenMulTensorOp>(loc, broadcastType, x1, x2);
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
+        ValueRange{dim});
+    Value sumDotProduct = rewriter.create<AtenSumDimIntListOp>(
+        loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList,
+        /*keepdim=*/cstFalse,
+        /*dtype=*/cstNone);
+
+    // Compute the norm of A and B
+    Value ord = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(2.0));
+    Value normA = rewriter.create<AtenLinalgVectorNormOp>(
+        loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse,
+        /*dtype=*/cstNone);
+    Value normB = rewriter.create<AtenLinalgVectorNormOp>(
+        loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse,
+        /*dtype=*/cstNone);
+
+    // Compute the product of the norms
+    Value normProduct =
+        rewriter.create<AtenMulTensorOp>(loc, op.getType(), normA, normB);
+    Value normProductClamp = rewriter.create<AtenClampOp>(
+        loc, op.getType(), normProduct, op.getEps(), /*max=*/cstNone);
+    // Compute the final cosine similarity by division
+    rewriter.replaceOpWithNewOp<AtenDivTensorOp>(
+        op, op.getType(), sumDotProduct, normProductClamp);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
 // `aten.add.Tensor` op.
@@ -5535,6 +5598,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index d40f99066fc9..5cdcfbc5d231 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -484,6 +484,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
+    target.addIllegalOp<AtenCosineSimilarityOp>();
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 14a264ada342..067cc410ddd3 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -325,6 +325,82 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   return unsqueezed;
 }
 
+// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
+// yes, then computes the final broadcast shape.
+void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
+                                  Value inputA, Value inputB,
+                                  SmallVector<int64_t> &resultShape,
+                                  SmallVector<Value> &resultShapeValue) {
+  SmallVector<int64_t> shapeA{
+      inputA.getType().cast<BaseTensorType>().getSizes()};
+  SmallVector<int64_t> shapeB{
+      inputB.getType().cast<BaseTensorType>().getSizes()};
+  unsigned rankA = shapeA.size();
+  unsigned rankB = shapeB.size();
+  unsigned minRank = rankA > rankB ? rankB : rankA;
+  // Check whether the shapes of the tensors are broadcastable or not.
+  // Two tensors are “broadcastable” if the following rules hold:
+  // 1.) Each tensor has at least one dimension.
+  // 2.) When iterating over the dimension sizes, starting at the trailing
+  // dimension, the dimension sizes must either be equal, one of them is 1, or
+  // one of them does not exist.
+  for (unsigned i = 0; i < minRank; i++) {
+    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
+    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
+    Value sizeInputA =
+        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
+    Value sizeInputB =
+        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
+    Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value cmpSizeAEqualsSizeB =
+        rewriter.create<AtenEqIntOp>(loc, sizeInputA, sizeInputB);
+    Value cmpSizeAEqualsOne =
+        rewriter.create<AtenEqIntOp>(loc, sizeInputA, torchCstOne);
+    Value cmpSizeBEqualsOne =
+        rewriter.create<AtenEqIntOp>(loc, sizeInputB, torchCstOne);
+    Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
+        SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
+                           cmpSizeBEqualsOne});
+    Value cmp = rewriter.create<AtenAnyBoolOp>(loc, anyBoolOpList);
+    rewriter.create<RuntimeAssertOp>(
+        loc, cmp, "tensors are not broadcast compatible");
+  }
+  // If we reach here then it means both the shapes are broadcast compatible.
+  resultShape = rankA >= rankB ? shapeA : shapeB;
+  Value shapeTensor = rankA >= rankB ? inputA : inputB;
+  for (unsigned i = 0; i < resultShape.size(); i++) {
+    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(i));
+    resultShapeValue.push_back(
+        rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
+  }
+
+  unsigned resultRank = resultShape.size();
+  for (unsigned i = 0; i < minRank; i++) {
+    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
+    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
+    Value sizeInputA =
+        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
+    Value sizeInputB =
+        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
+    resultShapeValue[resultRank - i - 1] =
+        rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
+    if (shapeA[rankA - i - 1] == kUnknownSize ||
+        shapeB[rankB - i - 1] == kUnknownSize) {
+      resultShape[resultRank - i - 1] = kUnknownSize;
+    } else {
+      resultShape[resultRank - i - 1] =
+          std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);
+    }
+  }
+}
+
 bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
   for (Operation *parentOp = block->getParentOp(); parentOp;
        parentOp = parentOp->getParentOp()) {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e671b9241287..78251b3862ea 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -454,6 +454,8 @@
     "BucketizeTensorStaticModule_basic",
     "CumsumStaticModule_basic",
     "CumsumStaticNegativeDimModule_basic",
+    "CosineSimilarityStaticModule_basic",
+    "CosineSimilarityStaticBroadcastModule_basic",
     "DetachModule_basic",
     "ElementwiseIsnanModule_basic",
     "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 9839c34be618..4450237bec4a 100644
--- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -98,6 +98,10 @@ def aten〇sin〡shape(self: List[int]) -> List[int]:
 def aten〇cos〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]:
+    broadcast = upstream_shape_functions.broadcast(x1, x2)
+    return broadcast[:dim] + broadcast[dim + 1:]
+
 def aten〇hardtanh〡shape(self: List[int], min_val: float = -1, max_val: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -1587,6 +1591,15 @@ def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, error_types={torch.complex128, torch.complex64, *all_integer_dtypes()}))
+def aten〇cosine_similarity〡dtype(x1_rank_dtype: Tuple[int, int], x2_rank_dtype: Tuple[int, int], dim: int = 1, eps: float = 1e-08) -> int:
+    x1_rank, x1_dtype = x1_rank_dtype
+    x2_rank, x2_dtype = x2_rank_dtype
+    assert x1_dtype == x2_dtype
+    assert not x1_dtype not in [torch.bfloat16, torch.float16, torch.float32, torch.float64]
+    return x1_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 7042ab5756e6..d4ac74fe53b0 100644
--- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -393,6 +393,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
     emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)")
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 3bd9b91ebf99..431c5efec686 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -4609,6 +4609,75 @@ def Add_Module_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class CosineSimilarityStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3], torch.float32, True),
+        ([2, 3], torch.float32, True),
+    ])
+    def forward(self, x1, x2):
+        return torch.ops.aten.cosine_similarity(x1, x2)
+
+
+@register_test_case(module_factory=lambda: CosineSimilarityStaticModule())
+def CosineSimilarityStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.rand(2, 3))
+
+
+# ==============================================================================
+
+
+class CosineSimilarityStaticBroadcastModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 2, 3], torch.float32, True),
+        ([4, 5, 1, 1], torch.float32, True),
+    ])
+    def forward(self, x1, x2):
+        return torch.ops.aten.cosine_similarity(x1, x2)
+
+
+@register_test_case(module_factory=lambda: CosineSimilarityStaticBroadcastModule())
+def CosineSimilarityStaticBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 3), tu.rand(4, 5, 1, 1))
+
+
+# ==============================================================================
+
+
+class CosineSimilarityModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x1, x2):
+        return torch.ops.aten.cosine_similarity(x1, x2)
+
+
+@register_test_case(module_factory=lambda: CosineSimilarityModule())
+def CosineSimilarityModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.rand(2, 3))
+
+
+# ==============================================================================
+
+
 class IscloseStaticModule(torch.nn.Module):
 
     def __init__(self):
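
Reviewer note (not part of the patch): the decomposition above lowers `aten.cosine_similarity` into broadcast, `aten.mul.Tensor`, `aten.sum.dim_IntList`, two `aten.linalg_vector_norm` ops, an `aten.clamp` with `min=eps`, and `aten.div.Tensor`. As a sanity check, here is a minimal NumPy sketch of the same computation; the function name and the NumPy calls are illustrative only and are not part of the torch-mlir code.

import numpy as np

def cosine_similarity_reference(x1, x2, dim=1, eps=1e-8):
    # Broadcast both inputs to a common shape, mirroring
    # computeBroadcastShape + aten.broadcast_to in the decomposition.
    x1, x2 = np.broadcast_arrays(x1, x2)
    # aten.mul.Tensor followed by aten.sum.dim_IntList over `dim`.
    dot = (x1 * x2).sum(axis=dim)
    # aten.linalg_vector_norm with ord=2.0 over `dim`, once per input.
    norm1 = np.linalg.norm(x1, ord=2, axis=dim)
    norm2 = np.linalg.norm(x2, ord=2, axis=dim)
    # aten.clamp with min=eps guards the division against zero norms.
    denom = np.clip(norm1 * norm2, eps, None)
    # aten.div.Tensor yields the final result.
    return dot / denom

# For the shapes used in the e2e tests above, this matches
# torch.ops.aten.cosine_similarity, e.g. inputs of shape (5, 2, 3) and
# (4, 5, 1, 1) broadcast to (4, 5, 2, 3) and reduce over dim=1 to (4, 2, 3).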