diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 5727eca2e5e..89be8e51b93 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -321,9 +321,11 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.self(); - if (!op.dtype().getType().isa()) - return rewriter.notifyMatchFailure( - op, "Unimplemented non-None dtype for softmax"); + + // Do not need check dtype args here, since dtype have been infered in op.getType() + // if (!op.dtype().getType().isa()) + // return rewriter.notifyMatchFailure( + // op, "Unimplemented non-None dtype for softmax"); BaseTensorType tensorType = self.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa())