diff --git a/frontends/pytorch/e2e_testing/torchscript/elementwise.py b/frontends/pytorch/e2e_testing/torchscript/elementwise.py
index 8cfc69f8fe06..b647c7f9b44d 100644
--- a/frontends/pytorch/e2e_testing/torchscript/elementwise.py
+++ b/frontends/pytorch/e2e_testing/torchscript/elementwise.py
@@ -149,7 +149,6 @@ def forward(self, a, b):
 def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(6), tu.rand())
 
-
 # ==============================================================================
 
 
@@ -169,3 +168,24 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: ElementwiseReluModule())
 def ElementwiseReluModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2) - 0.5)
+
+# ==============================================================================
+
+
+class ElementwiseSigmoidModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.sigmoid(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSigmoidModule())
+def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
+
diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
index e911f77a2972..d52e9ef01218 100644
--- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
+++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
@@ -436,6 +436,7 @@ def emit_with_mutating_variants(key, **kwargs):
     for key in [
             "aten::tanh : (Tensor) -> (Tensor)",
             "aten::relu : (Tensor) -> (Tensor)",
+            "aten::sigmoid : (Tensor) -> (Tensor)",
            "aten::sin : (Tensor) -> (Tensor)",
            "aten::exp : (Tensor) -> (Tensor)",
            "aten::cos : (Tensor) -> (Tensor)",
diff --git a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
index be3ec788d443..9e5fc090b99a 100644
--- a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -71,6 +71,34 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
   let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
 }
 
+def Torch_AtenSigmoidOp : Torch_Op<"aten.sigmoid", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::sigmoid : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
+def Torch_AtenSigmoid_Op : Torch_Op<"aten.sigmoid_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::sigmoid_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
 def Torch_AtenSinOp : Torch_Op<"aten.sin", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index d12e1858538f..f2386b39211b 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -654,6 +654,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     ArrayRef<Value> operands) {
   if (isa<AtenTanhOp>(op))
     return b.create<math::TanhOp>(loc, payloadArgs[0]);
+  if (isa<AtenSigmoidOp>(op)) {
+    Type elementType = payloadArgs[0].getType();
+    auto one = b.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+    auto negate = b.create<NegFOp>(loc, payloadArgs[0]);
+    auto exp = b.create<math::ExpOp>(loc, negate);
+    auto added = b.create<AddFOp>(loc, exp, one);
+    return b.create<DivFOp>(loc, one, added);
+  }
   if (auto relu = dyn_cast<AtenReluOp>(op)) {
     if (!relu.getType()
              .cast<ValueTensorType>()
@@ -775,7 +783,8 @@ struct ConvertElementwiseOp : ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     if (!isa<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
-             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp>(op))
+             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
+             AtenSigmoidOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1137,7 +1146,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
     target
         .addIllegalOp<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
-                      AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp>();
+                      AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
+                      AtenSigmoidOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp<AtenUnsqueezeOp>();
     patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 063f6abf7b75..b1eda002b5ed 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -175,7 +175,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
         AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, AtenFmodScalarOp,
         AtenFloorDivideScalarOp, AtenEqScalarOp, AtenGeScalarOp,
         AtenNeScalarOp, AtenBitwiseNotOp, AtenToDtypeOp, AtenExpOp,
-        AtenSinOp, AtenCosOp, DerefineOp>(op)) {
+        AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
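
For reference (not part of the patch): the linalg payload added in createLinalgPayloadCalculationForElementwiseOp lowers aten.sigmoid via the decomposition sigmoid(x) = 1 / (1 + exp(-x)). A minimal standalone sketch checking that identity against torch.sigmoid, in the style of the e2e tests above:

import torch

# Sanity-check the decomposition used by the lowering:
# sigmoid(x) = 1 / (1 + exp(-x)), evaluated element-wise.
x = torch.rand(3, 5) - 0.5
manual = 1.0 / (1.0 + torch.exp(-x))
assert torch.allclose(manual, torch.sigmoid(x))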