From 0b23af27d338fa005e2bd1c93a61e88b40edca23 Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Wed, 3 Aug 2022 08:16:31 +0800
Subject: [PATCH] [MHLO] support non-constant torch scalar in BasicOps (#1134)

See RFC https://github.com/llvm/torch-mlir/issues/999

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
---
 lib/Conversion/TorchToMhlo/BasicOp.cpp       |  42 +-
 .../TorchToMhlo/MhloLegalizeUtils.cpp        |  98 +---
 .../TorchToMhlo/MhloLegalizeUtils.h          |  13 +-
 test/Conversion/TorchToMhlo/basic.mlir       |  16 +-
 test/Conversion/TorchToMhlo/elementwise.mlir | 543 +++++++++++-------
 5 files changed, 362 insertions(+), 350 deletions(-)

diff --git a/lib/Conversion/TorchToMhlo/BasicOp.cpp b/lib/Conversion/TorchToMhlo/BasicOp.cpp
index c3586b9782c..ecec4882c55 100644
--- a/lib/Conversion/TorchToMhlo/BasicOp.cpp
+++ b/lib/Conversion/TorchToMhlo/BasicOp.cpp
@@ -159,23 +159,15 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
     }
 
     if (!rhsType) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               outElemTy, {})))
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
     }
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
 
     if (!skipMultiplyAlpha(op.alpha())) {
-      Value alpha;
-      if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(),
-                                              op.alpha(), alpha, outElemTy, {},
-                                              /*checkForUnity=*/false))) {
-        return op.emitError("currently only scalar constants are supported for "
-                            "alpha in conversion to MHLO operation");
-      }
+      Value alpha =
+          mhlo::scalarToMhloTensor(rewriter, op, adaptor.alpha(), outElemTy);
       DenseIntElementsAttr bcastDimensions;
       rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
                                                   bcastDimensions);
@@ -216,13 +208,13 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
       return op.emitError(
           "only floating-point or integer datatype legalization supported");
     }
-    if (!rhsType) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               outElemTy, {})))
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
-    }
+    Value lhsTensor = lhs;
+    if (std::is_same<AtenOpT, AtenSquareOp>()) {
+      rhs = lhs;
+    } else if (!rhsType) {
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
+    }
     DenseIntElementsAttr bcastDimensions;
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
@@ -263,11 +255,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     }
 
     if (!rhsTy) {
-      if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
-                                               lhsElemTy, {}))) {
-        return op.emitError("currently only scalar constants are supported for "
-                            "conversion in MHLO operation");
-      }
+      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), lhsElemTy);
     }
 
     // TODO: what is the PyTorch default type promotion?
@@ -569,12 +557,8 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
                       .cast<RankedTensorType>();
   auto outputShape = outputType.getShape();
   auto outputElemType = outputType.getElementType();
-  Value mhloTensor;
-  if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor,
-                                           outputElemType, outputShape,
-                                           false))) {
-    return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO");
-  }
+  Value mhloTensor =
+      mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
   rewriter.replaceOp(op, mhloTensor);
   return success();
 }
@@ -1020,4 +1004,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
 #undef INSERT_ATENOP_PATTERN
-}
\ No newline at end of file
+}
diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
index 4a6c333ffc4..8d646a73f6b 100644
--- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
@@ -174,93 +174,15 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
   return const_op.getResult();
 }
 
-// TODO: Support for variable scalar.
-LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
-                                      Operation *op, Value torchScalarValue,
-                                      Value &mhloTensor, Type dtype,
-                                      llvm::ArrayRef<int64_t> dshape,
-                                      bool doBroadcast) {
-  // Retrieve a const float or int value but create the out Tensor with dtype.
-  double doubleValue;
-  auto isFloat =
-      matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));
-
-  int64_t intValue;
-  auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));
-
-  if (!isFloat && !isInt)
-    return op->emitError("Unable to extract the scalar constant");
-
-  if (dtype.isa<mlir::FloatType>()) {
-    if (doBroadcast) {
-      mhloTensor = getSplatConstTensor<float>(
-          rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
-    } else {
-      mhloTensor = mhlo::getConstTensor<float>(
-                       rewriter, op, (isFloat ? doubleValue : intValue), dshape)
-                       .getValue();
-    }
-  } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
-    auto w = intType.getWidth();
-    if (w != 32 && w != 64)
-      return op->emitError("Unsupported integer type") << intType;
-
-    if (w == 32) {
-      if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
-        return op->emitError("Supplied value of scalar constant exceeds limits "
-                             "of destination type");
-      }
-      int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
-                          : static_cast<int32_t>(intValue);
-      if (doBroadcast) {
-        mhloTensor =
-            getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
-      } else {
-        mhloTensor =
-            mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
-      }
-    } else if (w == 64) {
-      if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
-        return op->emitError("Supplied value of scalar constant exceeds limits "
-                             "of destination type");
-      }
-      int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
-      if (doBroadcast) {
-        mhloTensor =
-            getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
-      } else {
-        mhloTensor =
-            mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
-      }
-    }
-  } else
-    return op->emitError("Usupported element type");
-
-  return success();
-}
-
-LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
-                                     Operation *op, Value alphaScalar,
-                                     Value &alphaTensor, Type dtype,
-                                     llvm::ArrayRef<int64_t> dshape,
-                                     bool checkForUnity) {
-  if (succeeded(torchScalarToMhloTensor(rewriter, op, alphaScalar, alphaTensor,
-                                        dtype, dshape)))
-    return success();
-
-  // `alpha` has not been specified.
-  int64_t alphaValue;
-  if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
-    return op->emitError("Currently only scalar constants are supported for "
-                         "alpha in MHLO operation");
-  // When no alpha has been specified, this must be 1.
-  if (checkForUnity && alphaValue != 1)
-    return op->emitError("Unsupported integer value for alpha");
-
-  alphaTensor =
-      mlir::mhlo::getMhloConstTensorSingleF32(rewriter, op, alphaValue);
-
-  return success();
+Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                         Value scalarValue, Type dtype) {
+  auto tensor = rewriter.create<tensor::FromElementsOp>(
+      op->getLoc(), ArrayRef<Value>{scalarValue});
+  auto dtype_tensor =
+      rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
+  return rewriter.create<mhlo::ReshapeOp>(
+      op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
+      dtype_tensor);
 }
 
 Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
@@ -439,4 +361,4 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
       .getResult();
 }
 } // namespace mhlo
-} // namespace mlir
\ No newline at end of file
+} // namespace mlir
diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
index 850c7b75a08..5875d7baea2 100644
--- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
+++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
@@ -47,17 +47,8 @@ template <typename T>
 Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                           T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
 
-LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
-                                      Operation *op, Value torchScalarValue,
-                                      Value &mhloTensor, Type dtype,
-                                      llvm::ArrayRef<int64_t> dshape,
-                                      bool doBroadcast = true);
-
-LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
-                                     Operation *op, Value alphaScalar,
-                                     Value &alphaTensor, Type dtype,
-                                     llvm::ArrayRef<int64_t> dshape,
-                                     bool checkForUnity);
+Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                         Value scalarValue, Type dtype);
 
 Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
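Note on the utility change above: the two constant-only entry points are collapsed
into one value-based helper. Patterns now hand over the adaptor operand, i.e. the
scalar after type conversion, which is a builtin i64/f64 SSA value (materialized as
torch_c.to_i64 / torch_c.to_f64 in the tests), so it exists even when the Python-side
scalar is a function argument rather than a literal. A minimal sketch of the call
shape every pattern in BasicOp.cpp now uses; `lowerScalarToRhs` is a hypothetical
wrapper shown only for illustration, not part of this patch:

    // Sketch, assuming the names used in the hunks above. Emits
    // tensor.from_elements -> mhlo.convert -> mhlo.reshape (rank-0 result).
    Value lowerScalarToRhs(Operation *op, Value adaptorScalar, Type outElemTy,
                           ConversionPatternRewriter &rewriter) {
      // No matchPattern(..., m_TorchConstant*) needed: the value may be a
      // block argument, a constant, or any other i64/f64 SSA value.
      return mhlo::scalarToMhloTensor(rewriter, op, adaptorScalar, outElemTy);
    }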
diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir
index ee7da07e10d..74d9bfda85a 100644
--- a/test/Conversion/TorchToMhlo/basic.mlir
+++ b/test/Conversion/TorchToMhlo/basic.mlir
@@ -41,11 +41,15 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
 
 // -----
 
-// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> {
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
-// CHECK:         %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
-// CHECK:         return %[[VAL_1]] : !torch.vtensor<[],si64>
+// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic(
+// CHECK-SAME:    ) -> !torch.vtensor<[],si64> {
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
+// CHECK:         %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
+// CHECK:         %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = "mhlo.reshape"(%[[T2]]) : (tensor<1xi64>) -> tensor<i64>
+// CHECK:         %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
+// CHECK:         return %[[T4]] : !torch.vtensor<[],si64>
 func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
   %int1 = torch.constant.int 1
   %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64>
@@ -251,4 +255,4 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
   %2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
   %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list<int>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32>
   return %result0 : !torch.vtensor<[3,7,4,5],f32>
-}
\ No newline at end of file
+}
diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir
index 643c44b118b..65e462d441b 100644
--- a/test/Conversion/TorchToMhlo/elementwise.mlir
+++ b/test/Conversion/TorchToMhlo/elementwise.mlir
@@ -1,10 +1,9 @@
 // RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
 
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.gelu(
 // CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[STR:.*]] = torch.constant.str "none"
 // CHECK:         %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:         %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:         %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
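For reviewers cross-checking the gelu constants above: the 1.0, 2.0, and 0.5 splats
are the coefficients of the erf-based GELU expansion (the remaining CHECK lines of
that test are elided by the hunk context),

    \mathrm{GELU}(x) = x\,\Phi(x)
                     = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)

where the 2.0 splat presumably feeds an rsqrt to form x/sqrt(2), and 1.0 and 0.5 are
the outer affine terms.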
@@ -22,13 +21,14 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
   return %0 : !torch.vtensor<[?,?],f32>
 }
 
+// -----
 
-// CHECK-LABEL: func.func @torch.aten.tanh$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_2:.*]] = mhlo.tanh %[[VAL_1]] : tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.tanh$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T2]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
@@ -36,12 +36,12 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.log$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.log$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = mhlo.log %[[T0]] : tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T2]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
@@ -49,43 +49,44 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.exp$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.exp$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T2]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
-
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.neg$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_2:.*]] = mhlo.negate %[[VAL_1]] : tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.neg$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = mhlo.negate %[[T0]] : tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T2]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
-
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int9 = torch.constant.int 9
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_3:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT9:.*]] = torch.constant.int 9
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_add %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int9 = torch.constant.int 9
   %int1 = torch.constant.int 1
@@ -95,17 +96,23 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int9 = torch.constant.int 9
-// CHECK:         %int2 = torch.constant.int 2
-// CHECK:         %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK:         %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT9:.*]] = torch.constant.int 9
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK:         %[[INT2:.*]] = torch.constant.int 2
+// CHECK:         %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK:         %[[T3:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T6:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK:         %[[T7:.*]] = mhlo.convert(%[[T6]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T8:.*]] = "mhlo.reshape"(%[[T7]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T9:.*]] = chlo.broadcast_multiply %[[T5]], %[[T8]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:         %[[T10:.*]] = chlo.broadcast_add %[[T0]], %[[T9]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T11]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int9 = torch.constant.int 9
   %int2 = torch.constant.int 2
@@ -115,15 +122,14 @@ func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T2:.*]] = chlo.broadcast_add %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int1 = torch.constant.int 1
   %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
@@ -132,17 +138,19 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int2 = torch.constant.int 2
-// CHECK:         %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_6:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT2:.*]] = torch.constant.int 2
+// CHECK:         %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK:         %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK:         %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T8]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int2 = torch.constant.int 2
   %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
@@ -151,16 +159,15 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
-// CHECK:         %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
-// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
-// CHECK:         return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
+// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T2:.*]] = mhlo.convert(%[[T0]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK:         %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
+// CHECK:         %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
+// CHECK:         return %[[T4]] : !torch.vtensor<[?,?],si64>
 func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
   %int1 = torch.constant.int 1
   %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
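A note on the alpha cases above: the add/sub patterns still call `skipMultiplyAlpha`
(see the BasicOp.cpp hunk) to fold away the `chlo.broadcast_multiply` when alpha is
statically 1, so only non-unit or non-constant alphas reach `scalarToMhloTensor`.
For reference, a sketch of what that predicate presumably looks like; it is an
existing helper, not part of this patch, and the matcher names follow the removed
code:

    static bool skipMultiplyAlpha(Value alphaValue) {
      double doubleValue;
      int64_t intValue;
      bool isFloat =
          matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue));
      bool isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue));
      // Skip the multiply only when alpha is provably the constant 1; a
      // variable alpha always materializes the multiply.
      return (isFloat && doubleValue == 1.0) || (isInt && intValue == 1);
    }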
@@ -169,15 +176,18 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int9 = torch.constant.int 9
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_3:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT9:.*]] = torch.constant.int 9
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_subtract %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int9 = torch.constant.int 9
   %int1 = torch.constant.int 1
@@ -187,17 +197,23 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int9 = torch.constant.int 9
-// CHECK:         %int2 = torch.constant.int 2
-// CHECK:         %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK:         %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT9:.*]] = torch.constant.int 9
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK:         %[[INT2:.*]] = torch.constant.int 2
+// CHECK:         %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK:         %[[T3:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T6:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK:         %[[T7:.*]] = mhlo.convert(%[[T6]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T8:.*]] = "mhlo.reshape"(%[[T7]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T9:.*]] = chlo.broadcast_multiply %[[T5]], %[[T8]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:         %[[T10:.*]] = chlo.broadcast_subtract %[[T0]], %[[T9]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T11]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int9 = torch.constant.int 9
   %int2 = torch.constant.int 2
@@ -207,15 +223,14 @@ func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T2:.*]] = chlo.broadcast_subtract %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int1 = torch.constant.int 1
   %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
@@ -224,17 +239,19 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int2 = torch.constant.int 2
-// CHECK:         %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_6:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT2:.*]] = torch.constant.int 2
+// CHECK:         %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK:         %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
+// CHECK:         %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T7:.*]] = chlo.broadcast_subtract %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T8]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int2 = torch.constant.int 2
   %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
@@ -243,16 +260,15 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
-// CHECK:         %int1 = torch.constant.int 1
-// CHECK:         %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
-// CHECK:         %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
-// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
-// CHECK:         return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
+// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T2:.*]] = mhlo.convert(%[[T0]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK:         %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
+// CHECK:         %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
+// CHECK:         return %[[T4]] : !torch.vtensor<[?,?],si64>
 func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
   %int1 = torch.constant.int 1
   %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
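The $promote cases above exercise `mhlo::promoteType`, which the patterns call on
both operands after the scalar has been materialized. For orientation only, a
sketch of the presumed shape of that existing utility (it is unchanged by this
patch and this is not the verbatim source): insert an `mhlo.convert` only when the
element types differ.

    Value promoteType(PatternRewriter &rewriter, Value input,
                      TensorType outType) {
      TensorType inType = input.getType().cast<TensorType>();
      if (inType.getElementType() == outType.getElementType())
        return input; // already the right element type; nothing emitted
      TensorType promoted =
          inType.cloneWith(inType.getShape(), outType.getElementType());
      return rewriter.create<mhlo::ConvertOp>(input.getLoc(), input, promoted);
    }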
@@ -261,14 +277,17 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int9 = torch.constant.int 9
-// CHECK:         %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_3:.*]] = chlo.broadcast_multiply %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT9:.*]] = torch.constant.int 9
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_multiply %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int9 = torch.constant.int 9
   %0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
@@ -277,14 +296,13 @@ func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.multensor$basic(
-// CHECK-SAME:    %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VLA_4:.*]] = chlo.broadcast_multiply %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:         %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.multensor$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = chlo.broadcast_multiply %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
@@ -292,14 +310,17 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int9 = torch.constant.int 9
-// CHECK:         %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK:         %[[VAL_3:.*]] = chlo.broadcast_divide %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT9:.*]] = torch.constant.int 9
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_divide %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int9 = torch.constant.int 9
   %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
@@ -308,14 +329,13 @@ func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
-// CHECK-SAME:    %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VLA_4:.*]] = chlo.broadcast_divide %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:         %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
@@ -323,14 +343,17 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.gt.scalar(
-// CHECK-SAME:    %arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK:         %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %int3 = torch.constant.int 3
-// CHECK:         %1 = mhlo.constant dense<3.000000e+00> : tensor<f32>
-// CHECK:         %2 = chlo.broadcast_compare %0, %1 {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
-// CHECK:         %3 = torch_c.from_builtin_tensor %2 : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK:         return %3 : !torch.vtensor<[?,?],i1>
+// CHECK-LABEL: func.func @torch.aten.gt.scalar(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT3:.*]] = torch.constant.int 3
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
   %int3 = torch.constant.int 3
   %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
@@ -339,14 +362,13 @@ func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.gt.tensor(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK:         return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK-LABEL: func.func @torch.aten.gt.tensor(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK:         %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
   %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
   return %0 : !torch.vtensor<[?,?],i1>
@@ -354,14 +376,13 @@ func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.lt.tensor(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK:         return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK-LABEL: func.func @torch.aten.lt.tensor(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK:         %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
   %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
   return %0 : !torch.vtensor<[?,?],i1>
@@ -369,14 +390,13 @@ func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.eq.tensor(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK:         return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK-LABEL: func.func @torch.aten.eq.tensor(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK:         %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
   %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
   return %0 : !torch.vtensor<[?,?],i1>
@@ -384,14 +404,13 @@ func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.ne.tensor(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK:         %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK:         return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK-LABEL: func.func @torch.aten.ne.tensor(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK:         %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
   %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
   return %0 : !torch.vtensor<[?,?],i1>
@@ -399,15 +418,15 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.permute$basic(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
-// CHECK:         %[[VAL_2:.*]] = torch.constant.int 0
-// CHECK:         %[[VAL_3:.*]] = torch.constant.int 1
-// CHECK:         %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK:         %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
-// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
-// CHECK:         return %[[VAL_6]] : !torch.vtensor<[64,4],f32>
+// CHECK-LABEL: func.func @torch.aten.permute$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
+// CHECK:         %[[INT0:.*]] = torch.constant.int 0
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[64,4],f32>
 func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -418,14 +437,106 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.relu(
-// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:         %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:         %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<?x?xf32>
-// CHECK:         %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK:         return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+// CHECK-LABEL: func.func @torch.aten.relu(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addscalar$variable(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_f64 %[[ARG1]]
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
+// CHECK:         %[[T6:.*]] = mhlo.convert(%[[T5]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK:         %[[T7:.*]] = "mhlo.reshape"(%[[T6]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T8:.*]] = chlo.broadcast_multiply %[[T4]], %[[T7]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:         %[[T9:.*]] = chlo.broadcast_add %[[T0]], %[[T8]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T10]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.add.Scalar %arg0, %arg1, %arg1: !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addtensor$variable(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG2:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T2:.*]] = torch_c.to_f64 %[[ARG2]]
+// CHECK:         %[[T3:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64>
+// CHECK:         %[[T4:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK:         %[[T5:.*]] = "mhlo.reshape"(%[[T4]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T6:.*]] = chlo.broadcast_multiply %[[T1]], %[[T5]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T8]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.float) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.mulscalar$variable(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_multiply %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.mul.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.divscalar$variable(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_divide %[[T0]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.div.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.gt.scalar$variable(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],i1> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.gt.Scalar %arg0, %arg1: !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
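One behavioral consequence worth flagging alongside the new $variable tests: the
deleted `torchScalarToMhloTensor` rejected scalar constants that exceeded the
destination integer type (via the removed `isInValidRange<...>` checks), but a
non-constant scalar cannot be validated at compile time, so the new path simply
lets the lowered `mhlo.convert` narrow the value at runtime. A standalone C++
illustration of the narrowing the old guard used to reject (plain C++, not MLIR
API code):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // A value like this, arriving as a !torch.int (i64) scalar, would
      // previously have been rejected when the destination element type was
      // i32; the equivalent of this cast now happens inside mhlo.convert.
      int64_t scalar = (1LL << 40) + 7; // exceeds int32_t limits
      int32_t narrowed = static_cast<int32_t>(scalar);
      std::printf("%lld -> %d\n", static_cast<long long>(scalar), narrowed);
      return 0;
    }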