From d294d47d089720238ce4a1c58f623efe10f77009 Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Fri, 22 Jul 2022 11:33:18 +0800
Subject: [PATCH] rebase, remove chlo and clang-format

---
 CMakeLists.txt                                |   3 +
 lib/Conversion/TorchToMhlo/CMakeLists.txt     |   2 -
 lib/Conversion/TorchToMhlo/PopulatePatterns.h |   5 +-
 lib/Conversion/TorchToMhlo/SliceLikeOps.cpp   | 249 ------------
 lib/Conversion/TorchToMhlo/ViewLikeOps.cpp    | 214 +++++++++-
 test/Conversion/TorchToMhlo/slice_like.mlir   | 298 --------------
 test/Conversion/TorchToMhlo/view_like.mlir    | 382 ++++++++++++++++--
 7 files changed, 561 insertions(+), 592 deletions(-)
 delete mode 100644 lib/Conversion/TorchToMhlo/SliceLikeOps.cpp
 delete mode 100644 test/Conversion/TorchToMhlo/slice_like.mlir

diff --git a/CMakeLists.txt b/CMakeLists.txt
index eb43eeaf0963..10432d1e9c8f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,6 +39,9 @@ endmacro()
 option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
 if(TORCH_MLIR_ENABLE_MHLO)
   add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
+  # i64 calculations are much slower than i32 on some devices, such as NVIDIA GPUs.
+  # One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
+  # the range of i32 (4 GiB).
   option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
     "Enable truncate dimension size from i64 to i32(unsafely)" OFF)
   if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
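For context on TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 above: the truncation
is only sound while every dimension size fits in a signed 32-bit integer. A
minimal standalone sketch of a guard a consumer of this option could add (the
helper is hypothetical and not part of this patch):

    // Hypothetical guard, not part of this patch: reject dimension sizes
    // that the i64 -> i32 truncation would silently corrupt.
    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    static int32_t checkedNarrowDimSize(int64_t dimSize) {
      if (dimSize > std::numeric_limits<int32_t>::max())
        throw std::out_of_range("dimension size exceeds the i32 range");
      return static_cast<int32_t>(dimSize);
    }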
diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt
index 4d18a5ae611c..e1f544293045 100644
--- a/lib/Conversion/TorchToMhlo/CMakeLists.txt
+++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
 
   DEPENDS
   MhloDialect
-  ChloDialect
   TorchMLIRConversionPassIncGen
 
   LINK_COMPONENTS
@@ -18,7 +17,6 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
   MLIRIR
   MLIRPass
   MhloDialect
-  ChloDialect
   TorchMLIRTorchDialect
 )

diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
index 8580dc3ecd14..97bb8602882d 100644
--- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h
+++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -20,9 +20,8 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         ConversionTarget &target);
 void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
-                                          RewritePatternSet &patterns,
-                                          ConversionTarget &target);
-
+                                           RewritePatternSet &patterns,
+                                           ConversionTarget &target);
 } // namespace torch_to_mhlo
 } // namespace torch

diff --git a/lib/Conversion/TorchToMhlo/SliceLikeOps.cpp b/lib/Conversion/TorchToMhlo/SliceLikeOps.cpp
deleted file mode 100644
index 37450f100339..000000000000
--- a/lib/Conversion/TorchToMhlo/SliceLikeOps.cpp
+++ /dev/null
@@ -1,249 +0,0 @@
-//===----------------------------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// Also available under a BSD-style license. See LICENSE.
-//
-//===----------------------------------------------------------------------===//
-
-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
-
-#include "../PassDetail.h"
-#include "./PopulatePatterns.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "torch-mlir/Conversion/Utils/Utils.h"
-#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
-#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
-#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-
-using namespace mlir;
-using namespace mlir::torch;
-using namespace mlir::torch::Torch;
-
-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-static constexpr size_t kMhloDimSizeBits = 32;
-#else
-static constexpr size_t kMhloDimSizeBits = 64;
-#endif
-
-namespace {
-
-SmallVector<Value> getDimSizesOfTensor(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value value) {
-  auto valueTy = value.getType().dyn_cast<RankedTensorType>();
-  if (!valueTy) {
-    op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
-    return {};
-  }
-
-  auto rank = valueTy.getRank();
-  if (rank == 0) {
-    return {};
-  }
-
-  SmallVector<Value> dimSizes;
-  dimSizes.reserve(rank);
-  auto loc = op->getLoc();
-  for (auto d = 0; d < rank; ++d) {
-    dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
-        loc,
-        rewriter.getIntegerType(kMhloDimSizeBits),
-        rewriter.create<tensor::DimOp>(loc, value, d)));
-  }
-  return dimSizes;
-}
-
-// A dimension index from the Torch dialect might be outside the range
-// [0, dimSize]. This function normalizes the input index into that range.
-Value getNormalizedDimSizeInternal(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value index,
-    Value dimSize) {
-  auto loc = op->getLoc();
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
-
-  // To normalize index into range [-dimSize, dimSize]:
-  // index = min(max(-dimSize, index), dimSize)
-  auto negDimSize = rewriter.create<arith::SubIOp>(loc, zero, dimSize);
-  index = rewriter.create<arith::MaxSIOp>(loc, negDimSize, index);
-  index = rewriter.create<arith::MinSIOp>(loc, dimSize, index);
-
-  auto dimSizePlusIndex = rewriter.create<arith::AddIOp>(loc, dimSize, index);
-  auto indexPositive = rewriter.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::sge, index, zero);
-  // get positive index: (index >= 0) ? index : index + dimSize
-  return rewriter.create<arith::SelectOp>(
-      loc, indexPositive, index, dimSizePlusIndex);
-}
-
-Value getDynamicSliceInternal(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value input,
-    Value startIndex,
-    Value endIndex,
-    Value step,
-    size_t dimIndex,
-    ArrayRef<Value> dimSizes) {
-  auto loc = op->getLoc();
-  // startIndex & endIndex have been normalized into range [0, dimSize]
-  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(intType, 0));
-  Value one = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIntegerAttr(intType, 1));
-
-  SmallVector<Value, 4> startIndices;
-  SmallVector<Value, 4> endIndices;
-  SmallVector<Value, 4> strides;
-
-  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
-  size_t rank = inputTy.getRank();
-  startIndices.reserve(rank);
-  endIndices.reserve(rank);
-  strides.reserve(rank);
-
-  auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::eq, endIndex, zero);
-  endIndex = rewriter.create<arith::SelectOp>(
-      loc, endIndexIsZero, dimSizes[dimIndex], endIndex);
-
-  for (size_t r = 0; r < rank; ++r) {
-    if (r == dimIndex) {
-      startIndices.push_back(startIndex);
-      endIndices.push_back(endIndex);
-      strides.push_back(step);
-    } else {
-      startIndices.push_back(zero);
-      endIndices.push_back(dimSizes[r]);
-      strides.push_back(one);
-    }
-  }
-
-  auto startTensor =
-      rewriter.create<tensor::FromElementsOp>(loc, startIndices).getResult();
-  auto endTensor =
-      rewriter.create<tensor::FromElementsOp>(loc, endIndices).getResult();
-  auto stridesTensor =
-      rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
-
-  auto inputShape = inputTy.getShape();
-  SmallVector<int64_t, 4> sliceShape(inputShape.begin(), inputShape.end());
-  sliceShape[dimIndex] = ShapedType::kDynamicSize;
-  auto sliceoutputTy =
-      RankedTensorType::get(sliceShape, inputTy.getElementType());
-  return rewriter.create<mhlo::RealDynamicSliceOp>(
-      loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
-}
-
-// Get a dynamic slice of the tensor from startIndex to endIndex with stride
-// step on the specified dimension. The inputs startIndex (default 0),
-// endIndex (default dimSize), and step (default 1) are all optional.
-Value getDynamicSlice(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value input,
-    llvm::Optional<Value> startIndexOpt,
-    llvm::Optional<Value> endIndexOpt,
-    llvm::Optional<Value> stepOpt,
-    int64_t dim) {
-  auto loc = op->getLoc();
-  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
-  auto rank = inputTy.getRank();
-
-  dim = (dim + rank) % rank;
-  Value dimSize = rewriter.create<arith::IndexCastOp>(
-      loc,
-      rewriter.getI64Type(),
-      rewriter.create<tensor::DimOp>(loc, input, dim));
-
-  Value normStartIndex = startIndexOpt
-      ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
-      : rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
-  Value normEndIndex = endIndexOpt
-      ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
-      : dimSize;
-  Value step = stepOpt
-      ? *stepOpt
-      : rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
-
-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-  auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
-  normStartIndex =
-      rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
-  normEndIndex =
-      rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
-  step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
-#endif
-  auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
-
-  return getDynamicSliceInternal(
-      rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes);
-}
-
-template <typename AtenOpT>
-class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
-public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-template <>
-LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
-    AtenSliceTensorOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
-  auto self = adaptor.self();
-  auto selfTy = self.getType().template cast<RankedTensorType>();
-  if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in MHLO slice");
-  int64_t dim;
-  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
-    return rewriter.notifyMatchFailure(
-        op, "Only constant dim is currently supported");
-
-  auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
-    if (val.getType().isa<Torch::NoneType>()) {
-      return llvm::None;
-    } else {
-      return val;
-    }
-  };
-
-  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
-  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
-  llvm::Optional<Value> step = getOptionalVal(adaptor.step());
-
-  Value sliced =
-      getDynamicSlice(rewriter, op, self, start, end, step, dim);
-  rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
-      op, getTypeConverter()->convertType(op.getType()), sliced);
-
-  return success();
-}
-} // namespace
-
-void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
-  MLIRContext *context = patterns.getContext();
-
-#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
-  target.addIllegalOp<AtenOp>();                                               \
-  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
-  INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
-#undef INSERT_ATENOP_PATTERN
-
-}
diff --git a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
index 78b62b501da5..0ecd96bf6293 100644
--- a/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
+++ b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
@@ -13,7 +13,6 @@
 #include "./PopulatePatterns.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -29,9 +28,200 @@ using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 using namespace mlir::torch::TorchConversion;
 
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+static constexpr size_t kMhloDimSizeBits = 32;
+#else
+static constexpr size_t kMhloDimSizeBits = 64;
+#endif
 namespace {
 
+SmallVector<Value> getDimSizesOfTensor(PatternRewriter &rewriter,
+                                       Operation *op, Value value) {
+  auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+  if (!valueTy) {
+    op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
+    return {};
+  }
+
+  auto rank = valueTy.getRank();
+  if (rank == 0) {
+    return {};
+  }
+
+  SmallVector<Value> dimSizes;
+  dimSizes.reserve(rank);
+  auto loc = op->getLoc();
+  for (auto d = 0; d < rank; ++d) {
+    dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIntegerType(kMhloDimSizeBits),
+        rewriter.create<tensor::DimOp>(loc, value, d)));
+  }
+  return dimSizes;
+}
+
+// A dimension index from the Torch dialect might be outside the range
+// [0, dimSize]. This function normalizes the input index into that range.
+Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
+                                   Value index, Value dimSize) {
+  auto loc = op->getLoc();
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+
+  // To normalize index into range [-dimSize, dimSize]:
+  // index = min(max(-dimSize, index), dimSize)
+  auto negDimSize = rewriter.create<arith::SubIOp>(loc, zero, dimSize);
+  index = rewriter.create<arith::MaxSIOp>(loc, negDimSize, index);
+  index = rewriter.create<arith::MinSIOp>(loc, dimSize, index);
+
+  auto dimSizePlusIndex = rewriter.create<arith::AddIOp>(loc, dimSize, index);
+  auto indexPositive = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, zero);
+  // get positive index: (index >= 0) ? index : index + dimSize
+  return rewriter.create<arith::SelectOp>(loc, indexPositive, index,
+                                          dimSizePlusIndex);
+}
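+
+// Worked example for the normalization above (illustrative values, not tied
+// to the tests): with dimSize = 4, index = -1 clamps to
+// min(max(-4, -1), 4) = -1, and since -1 < 0 the select yields -1 + 4 = 3,
+// i.e. the last element, matching PyTorch's negative indexing; an
+// out-of-range index such as 10 clamps to min(max(-4, 10), 4) = 4, the end
+// of the dimension.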
+
+Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
+                              Value input, Value startIndex, Value endIndex,
+                              Value step, size_t dimIndex,
+                              ArrayRef<Value> dimSizes) {
+  auto loc = op->getLoc();
+  // startIndex & endIndex have been normalized into range [0, dimSize]
+  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 0));
+  Value one = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+
+  SmallVector<Value, 4> startIndices;
+  SmallVector<Value, 4> endIndices;
+  SmallVector<Value, 4> strides;
+
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  size_t rank = inputTy.getRank();
+  startIndices.reserve(rank);
+  endIndices.reserve(rank);
+  strides.reserve(rank);
+
+  auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, endIndex, zero);
+  endIndex = rewriter.create<arith::SelectOp>(loc, endIndexIsZero,
+                                              dimSizes[dimIndex], endIndex);
+
+  for (size_t r = 0; r < rank; ++r) {
+    if (r == dimIndex) {
+      startIndices.push_back(startIndex);
+      endIndices.push_back(endIndex);
+      strides.push_back(step);
+    } else {
+      startIndices.push_back(zero);
+      endIndices.push_back(dimSizes[r]);
+      strides.push_back(one);
+    }
+  }
+
+  auto startTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, startIndices).getResult();
+  auto endTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, endIndices).getResult();
+  auto stridesTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
+
+  auto inputShape = inputTy.getShape();
+  SmallVector<int64_t, 4> sliceShape(inputShape.begin(), inputShape.end());
+  sliceShape[dimIndex] = ShapedType::kDynamicSize;
+  auto sliceoutputTy =
+      RankedTensorType::get(sliceShape, inputTy.getElementType());
+  return rewriter.create<mhlo::RealDynamicSliceOp>(
+      loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
+}
+
+// Get a dynamic slice of the tensor from startIndex to endIndex with stride
+// step on the specified dimension. The inputs startIndex (default 0),
+// endIndex (default dimSize), and step (default 1) are all optional.
+Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
+                      llvm::Optional<Value> startIndexOpt,
+                      llvm::Optional<Value> endIndexOpt,
+                      llvm::Optional<Value> stepOpt, int64_t dim) {
+  auto loc = op->getLoc();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  auto rank = inputTy.getRank();
+
+  dim = (dim + rank) % rank;
+  Value dimSize = rewriter.create<arith::IndexCastOp>(
+      loc, rewriter.getI64Type(),
+      rewriter.create<tensor::DimOp>(loc, input, dim));
+
+  Value normStartIndex =
+      startIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
+          : rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+  Value normEndIndex =
+      endIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
+          : dimSize;
+  Value step =
+      stepOpt ? *stepOpt
+              : rewriter.create<arith::ConstantOp>(
+                    loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+  auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
+  normStartIndex =
+      rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
+  normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
+  step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
+#endif
+  auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
+
+  return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
+                                 normEndIndex, step, dim, dimSizes);
+}
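+
+// For instance (illustrative only, mirroring the tests in view_like.mlir
+// below): torch.aten.slice.Tensor with dim = 1, start = none, end = none,
+// step = 2 becomes a slice from 0 to dimSize(1) with stride 2, while
+// start = -1, end = 0, step = 1 selects just the last element, because an
+// end of 0 is rewritten to dimSize by the select in getDynamicSliceInternal.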
+
+template <typename AtenOpT>
+class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+template <>
+LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
+    AtenSliceTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in MHLO slice");
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant dim is currently supported");
+
+  auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
+    if (val.getType().isa<Torch::NoneType>()) {
+      return llvm::None;
+    } else {
+      return val;
+    }
+  };
+
+  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
+  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
+  llvm::Optional<Value> step = getOptionalVal(adaptor.step());
+
+  Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
+      op, getTypeConverter()->convertType(op.getType()), sliced);
+
+  return success();
+}
+
 // This defines a template to construct ops whose legalizations are
 // specialized.
 template <typename AtenOpT>
@@ -81,13 +271,23 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
       });
 #endif
 
+    Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+    Value numel = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(intType, 1));
+    for (auto d : dimSizes) {
+      numel = rewriter.create<arith::MulIOp>(loc, numel, d);
+    }
+    numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                numel);
+
     Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
-    rewriter.replaceOpWithNewOp<chlo::DynamicReshapeOp>(
+    Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
+        loc, mhloShape.getType(), numel, mhloShape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
         op,
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()),
-        adaptor.self(),
-        mhloShape);
+        adaptor.self(), computedShape);
 
     return success();
   }
@@ -123,6 +323,12 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
     ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
 
+#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
+  INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
+#undef INSERT_ATENOP_PATTERN
+
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
diff --git a/test/Conversion/TorchToMhlo/slice_like.mlir b/test/Conversion/TorchToMhlo/slice_like.mlir
deleted file mode 100644
index 4963eca14c20..000000000000
--- a/test/Conversion/TorchToMhlo/slice_like.mlir
+++ /dev/null
@@ -1,298 +0,0 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
-// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]]
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
-// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64
-// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64
-// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T1]] : i64
-// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64
-// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64
-// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64
-// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64
-// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64
-// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64
-// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T3]] : i64
-// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64
-// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64
-// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64
-// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64
-// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0_1]] : tensor<?x?x?xf32>
-// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64
-// CHECK:
%[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 -// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 -// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T19]], %[[T17]] : i64 -// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64> -// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> -// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T30:.*]] = mhlo.convert %[[T29]] : tensor -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[?,?,?],f32> -func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?],f32> -} - -// CHECK-LABEL: func.func @torch.aten.slice.strided.static$slice_like( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 -// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 -// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T1]] : i64 -// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 -// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 -// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 -// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 -// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 -// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 -// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T3]] : i64 -// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 -// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 -// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 -// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 -// CHECK: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0_1]] : tensor<4x65x256xf32> -// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1]] : 
tensor<4x65x256xf32> -// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> -// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 -// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 -// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T19]], %[[T17]] : i64 -// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64> -// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> -// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor) -> tensor<2x65x256xf32> -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[2,65,256],f32> -func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,65,256],f32> - return %0 : !torch.vtensor<[2,65,256],f32> -} - - -// CHECK-LABEL: func.func @torch.aten.slice.last$slice_like( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 -// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 -// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T3]] : i64 -// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 -// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 -// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 -// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 -// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 -// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 -// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T1]] : i64 -// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 -// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 -// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 -// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 -// CHECK: %[[C1_1:.*]] = 
arith.constant 1 : index -// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1_1]] : tensor -// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 -// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 -// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T21]], %[[T17]] : i64 -// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> -// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> -// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor) -> tensor -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor -> !torch.vtensor<[?,1,?],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[?,1,?],f32> -func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int-1 = torch.constant.int -1 - %0 = torch.aten.slice.Tensor %arg0, %int1, %int-1, %int0, %int1 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1,?],f32> - return %0 : !torch.vtensor<[?,1,?],f32> -} - - -// CHECK-LABEL: func.func @torch.aten.slice.last.static$slice_like( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 -// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 -// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T3]] : i64 -// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 -// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 -// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 -// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 -// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 -// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 -// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T1]] : i64 -// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 -// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 -// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 -// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> -// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 -// CHECK: %[[C1_1:.*]] = arith.constant 1 : 
index -// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1_1]] : tensor<4x65x256xf32> -// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> -// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 -// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 -// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T21]], %[[T17]] : i64 -// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> -// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> -// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32> -// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<4x?x256xf32>) -> tensor<4x1x256xf32> -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[4,1,256],f32> -func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int-1 = torch.constant.int -1 - %0 = torch.aten.slice.Tensor %arg0, %int1, %int-1, %int0, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1,256],f32> - return %0 : !torch.vtensor<[4,1,256],f32> -} - - -// CHECK-LABEL: func.func @torch.aten.slice.none$slice_like( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[C1_0:.*]] = arith.constant 1 : index -// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1_0]] : tensor -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64 -// CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T10:.*]] = arith.cmpi eq, %[[T3]], %[[C0_I64_1]] : i64 -// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T7]], %[[T3]] : i64 -// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64> -// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64> -// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T15:.*]] = 
"mhlo.real_dynamic_slice"(%[[T0]], %[[T12]], %[[T13]], %[[T14]]) : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T16:.*]] = mhlo.convert %[[T15]] : tensor -// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[T17]] : !torch.vtensor<[?,?,?],f32> -func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %none = torch.constant.none - %0 = torch.aten.slice.Tensor %arg0, %int1, %none, %none, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.vtensor<[?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?],f32> -} - -// CHECK-LABEL: func.func @torch.aten.slice.none.static$slice_like( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 -// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[C1_0:.*]] = arith.constant 1 : index -// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1_0]] : tensor<4x65x256xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> -// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64 -// CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[T10:.*]] = arith.cmpi eq, %[[T3]], %[[C0_I64_1]] : i64 -// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T7]], %[[T3]] : i64 -// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64> -// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64> -// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T15:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T12]], %[[T13]], %[[T14]]) : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32> -// CHECK: %[[T16:.*]] = mhlo.convert(%[[T15]]) : (tensor<4x?x256xf32>) -> tensor<4x33x256xf32> -// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> -// CHECK: return %[[T17]] : !torch.vtensor<[4,33,256],f32> -func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %none = torch.constant.none - %0 = torch.aten.slice.Tensor %arg0, %int1, %none, %none, %int2 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.vtensor<[4,33,256],f32> - return %0 : !torch.vtensor<[4,33,256],f32> -} diff --git a/test/Conversion/TorchToMhlo/view_like.mlir 
b/test/Conversion/TorchToMhlo/view_like.mlir index e8b2a2e021e0..2e6394a76192 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -1,18 +1,320 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s -// CHECK-LABEL: func.func @torch.aten.view$view_like( +// CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T1]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T3]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0_1]] : tensor +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T19]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T30:.*]] = mhlo.convert %[[T29]] : tensor +// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T31]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> 
!torch.vtensor<[?,?,?],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.slice.strided.static$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T1]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T3]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0_1]] : tensor<4x65x256xf32> +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> +// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T19]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor) -> tensor<2x65x256xf32> +// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<2x65x256xf32> -> 
!torch.vtensor<[2,65,256],f32> +// CHECK: return %[[T31]] : !torch.vtensor<[2,65,256],f32> +func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int9223372036854775807, %int2 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,65,256],f32> + return %0 : !torch.vtensor<[2,65,256],f32> +} + + +// CHECK-LABEL: func.func @torch.aten.slice.last$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T3]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T1]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1_1]] : tensor +// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T21]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor) -> tensor +// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : 
tensor -> !torch.vtensor<[?,1,?],f32> +// CHECK: return %[[T31]] : !torch.vtensor<[?,1,?],f32> +func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %0 = torch.aten.slice.Tensor %arg0, %int1, %int-1, %int0, %int1 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1,?],f32> + return %0 : !torch.vtensor<[?,1,?],f32> +} + + +// CHECK-LABEL: func.func @torch.aten.slice.last.static$slice_like( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[T6:.*]] = arith.subi %[[C0_I64]], %[[T5]] : i64 +// CHECK: %[[T7:.*]] = arith.maxsi %[[T6]], %[[T3]] : i64 +// CHECK: %[[T8:.*]] = arith.minsi %[[T5]], %[[T7]] : i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T5]], %[[T8]] : i64 +// CHECK: %[[T10:.*]] = arith.cmpi sge, %[[T8]], %[[C0_I64]] : i64 +// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T8]], %[[T9]] : i64 +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[T12:.*]] = arith.subi %[[C0_I64_0]], %[[T5]] : i64 +// CHECK: %[[T13:.*]] = arith.maxsi %[[T12]], %[[T1]] : i64 +// CHECK: %[[T14:.*]] = arith.minsi %[[T5]], %[[T13]] : i64 +// CHECK: %[[T15:.*]] = arith.addi %[[T5]], %[[T14]] : i64 +// CHECK: %[[T16:.*]] = arith.cmpi sge, %[[T14]], %[[C0_I64_0]] : i64 +// CHECK: %[[T17:.*]] = arith.select %[[T16]], %[[T14]], %[[T15]] : i64 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T18:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> +// CHECK: %[[T19:.*]] = arith.index_cast %[[T18]] : index to i64 +// CHECK: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[T20:.*]] = tensor.dim %[[T0]], %[[C1_1]] : tensor<4x65x256xf32> +// CHECK: %[[T21:.*]] = arith.index_cast %[[T20]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T22:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32> +// CHECK: %[[T23:.*]] = arith.index_cast %[[T22]] : index to i64 +// CHECK: %[[C0_I64_2:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[T24:.*]] = arith.cmpi eq, %[[T17]], %[[C0_I64_2]] : i64 +// CHECK: %[[T25:.*]] = arith.select %[[T24]], %[[T21]], %[[T17]] : i64 +// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> +// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> +// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> +// CHECK: %[[T29:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T26]], %[[T27]], %[[T28]]) : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32> +// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<4x?x256xf32>) -> tensor<4x1x256xf32> 
+// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
+// CHECK: return %[[T31]] : !torch.vtensor<[4,1,256],f32>
+func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int-1 = torch.constant.int -1
+  %0 = torch.aten.slice.Tensor %arg0, %int1, %int-1, %int0, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1,256],f32>
+  return %0 : !torch.vtensor<[4,1,256],f32>
+}
+
+
+// CHECK-LABEL: func.func @torch.aten.slice.none$slice_like(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
+// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
+// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
+// CHECK: %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1_0]] : tensor<?x?x?xf32>
+// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64
+// CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T10:.*]] = arith.cmpi eq, %[[T3]], %[[C0_I64_1]] : i64
+// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T7]], %[[T3]] : i64
+// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64>
+// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
+// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
+// CHECK: %[[T15:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T12]], %[[T13]], %[[T14]]) : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// CHECK: %[[T16:.*]] = mhlo.convert %[[T15]] : tensor<?x?x?xf32>
+// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
+// CHECK: return %[[T17]] : !torch.vtensor<[?,?,?],f32>
+func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %none = torch.constant.none
+  %0 = torch.aten.slice.Tensor %arg0, %int1, %none, %none, %int2 : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.vtensor<[?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.slice.none.static$slice_like(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT2:.*]] = torch.constant.int 2
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32>
+// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
+// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32>
+// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
+// CHECK: %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1_0]] : tensor<4x65x256xf32>
+// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T8:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<4x65x256xf32>
+// CHECK: %[[T9:.*]] = arith.index_cast %[[T8]] : index to i64
+// CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T10:.*]] = arith.cmpi eq, %[[T3]], %[[C0_I64_1]] : i64
+// CHECK: %[[T11:.*]] = arith.select %[[T10]], %[[T7]], %[[T3]] : i64
+// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64>
+// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
+// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
+// CHECK: %[[T15:.*]] = "mhlo.real_dynamic_slice"(%[[T0]], %[[T12]], %[[T13]], %[[T14]]) : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32>
+// CHECK: %[[T16:.*]] = mhlo.convert(%[[T15]]) : (tensor<4x?x256xf32>) -> tensor<4x33x256xf32>
+// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
+// CHECK: return %[[T17]] : !torch.vtensor<[4,33,256],f32>
+func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %none = torch.constant.none
+  %0 = torch.aten.slice.Tensor %arg0, %int1, %none, %none, %int2 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.vtensor<[4,33,256],f32>
+  return %0 : !torch.vtensor<[4,33,256],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.view$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
-// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
+// CHECK: %[[INTneg1:.*]] = torch.constant.int -1
 // CHECK: %[[INT224:.*]] = torch.constant.int 224
-// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
+// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INTneg1]], %[[INT224]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INTneg1]]
 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
-// CHECK: %[[T4:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T5:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T4]]) : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
-// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
-// CHECK: return %[[T6]] : !torch.vtensor<[?,224],f32>
-func.func @torch.aten.view$view_like(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T4:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
+// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
+// CHECK: %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
+// CHECK: %[[T8:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[T7]] : index, tensor<2xi64> -> tensor<2xi64>
+// CHECK: %[[T9:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T8]]) : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
+// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
+// CHECK: return %[[T10]] : !torch.vtensor<[?,224],f32>
+func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
   %int-1 = torch.constant.int -1
   %int224 = torch.constant.int 224
   %0 = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -20,24 +322,30 @@ func.func @torch.aten.view$view_like(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !t
   return %1 : !torch.vtensor<[?,224],f32>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.reshape$view_like(
+// CHECK-LABEL: func.func @torch.aten.reshape$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?,?],f32> -> tensor<?x?x?x?x?xf32>
-// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
+// CHECK: %[[INTneg1:.*]] = torch.constant.int -1
 // CHECK: %[[INT120:.*]] = torch.constant.int 120
 // CHECK: %[[INT4:.*]] = torch.constant.int 4
 // CHECK: %[[INT64:.*]] = torch.constant.int 64
-// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]120, %[[INT]]4, %[[INT]]64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
+// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INTneg1]], %[[INT120]], %[[INT4]], %[[INT64]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INTneg1]]
 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]]
 // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]]
 // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
-// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
-// CHECK: %[[T7:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T6]]) : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
-// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
-// CHECK: return %[[T8]] : !torch.vtensor<[?,120,4,64],f32>
-func.func @torch.aten.reshape$view_like(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T6:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
+// CHECK: %[[T7:.*]] = arith.muli %[[T6]], %[[T3]] : i64
+// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T4]] : i64
+// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
+// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
+// CHECK: %[[T11:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
+// CHECK: %[[T12:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[T11]] : index, tensor<4xi64> -> tensor<4xi64>
+// CHECK: %[[T13:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T12]]) : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
+// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
+// CHECK: return %[[T14]] : !torch.vtensor<[?,120,4,64],f32>
+func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
   %int-1 = torch.constant.int -1
   %int120 = torch.constant.int 120
   %int4 = torch.constant.int 4
@@ -47,24 +355,29 @@ func.func @torch.aten.reshape$view_like(%arg0: !torch.vtensor<[?,?,?,?,?],f32>)
   return %1 : !torch.vtensor<[?,120,4,64],f32>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.view.minus1$view_like(
+// CHECK-LABEL: func.func @torch.aten.view$minus1(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
-// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
+// CHECK: %[[INTneg1:.*]] = torch.constant.int -1
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[INT0:.*]] = torch.constant.int 0
 // CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
 // CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INT]]-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
 // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]
-// CHECK: %[[T6:.*]] = torch_c.to_i64 %[[INT]]-1
-// CHECK: %[[T7:.*]] = tensor.from_elements %[[T4]], %[[T5]], %[[T6]] : tensor<3xi64>
-// CHECK: %[[T8:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T7]]) : (tensor<2x3x?x?xf32>, tensor<3xi64>) -> tensor<2x3x?xf32>
-// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32>
-// CHECK: return %[[T9]] : !torch.vtensor<[2,3,?],f32>
-func.func @torch.aten.view.minus1$view_like(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
+// CHECK: %[[T6:.*]] = torch_c.to_i64 %[[INTneg1]]
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T7:.*]] = arith.muli %[[C1_I64]], %[[T4]] : i64
+// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T5]] : i64
+// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T6]] : i64
+// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
+// CHECK: %[[T11:.*]] = tensor.from_elements %[[T4]], %[[T5]], %[[T6]] : tensor<3xi64>
+// CHECK: %[[T12:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[T11]] : index, tensor<3xi64> -> tensor<3xi64>
+// CHECK: %[[T13:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T12]]) : (tensor<2x3x?x?xf32>, tensor<3xi64>) -> tensor<2x3x?xf32>
+// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32>
+// CHECK: return %[[T14]] : !torch.vtensor<[2,3,?],f32>
+func.func @torch.aten.view$minus1(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
   %int-1 = torch.constant.int -1
   %int1 = torch.constant.int 1
   %int0 = torch.constant.int 0
@@ -75,8 +388,7 @@ func.func @torch.aten.view.minus1$view_like(%arg0: !torch.vtensor<[2,3,?,?],f32>
   return %3 : !torch.vtensor<[2,3,?],f32>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.view.to_rank1$view_like(
+// CHECK-LABEL: func.func @torch.aten.view$to_rank1(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
@@ -84,22 +396,20 @@ func.func @torch.aten.view.minus1$view_like(%arg0: !torch.vtensor<[2,3,?,?],f32>
 // CHECK: %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<f32>) -> tensor<1xf32>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
 // CHECK: return %[[T3]] : !torch.vtensor<[1],f32>
-func.func @torch.aten.view.to_rank1$view_like(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
+func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
   %int1 = torch.constant.int 1
   %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
   %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32>
   return %1 : !torch.vtensor<[1],f32>
 }
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.view.to_rank0$view_like(
+// CHECK-LABEL: func.func @torch.aten.view$to_rank0(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<1xf32>) -> tensor<f32>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[T3]] : !torch.vtensor<[],f32>
-func.func @torch.aten.view.to_rank0$view_like(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
+func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
   %0 = torch.prim.ListConstruct : () -> !torch.list<int>
   %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
   return %1 : !torch.vtensor<[],f32>
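
A note to help review the new CHECK sequences: every dynamic `view`/`reshape` test above now pins down the same three-step recipe in place of the old single `chlo.dynamic_reshape` check. The pattern multiplies the requested sizes into an i64 element count, hands that count and the size vector to `mhlo.compute_reshape_shape` (a single -1 entry makes the running product negative, which lets that op infer the missing dimension), and finishes with `mhlo.dynamic_reshape`; fully static cases such as `$to_rank1`/`$to_rank0` instead fold to a plain `mhlo.reshape`. The C++ below is a minimal sketch of that emission for orientation only, not the verbatim ViewLikeOps.cpp code; the helper name `emitDynamicReshape` and the assumed `mhlo::ComputeReshapeShapeOp` builder operand order `(numElements, candidateShape)` are illustrative.

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

using namespace mlir;

// Sketch: emit the numel-product / compute_reshape_shape / dynamic_reshape
// sequence that the view_like.mlir CHECK lines above expect. Helper name and
// exact builder signatures are assumptions, not the patch's code.
static Value emitDynamicReshape(PatternRewriter &rewriter, Location loc,
                                Value self, ArrayRef<Value> dimSizes,
                                Type resultType) {
  // numel = 1 * d0 * ... * d{n-1}; a -1 ("infer me") entry makes the product
  // negative, which compute_reshape_shape uses to back out that dimension.
  Value numel = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64IntegerAttr(1));
  for (Value dimSize : dimSizes)
    numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
  numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                              numel);
  // Pack the requested sizes into a rank-1 shape operand.
  Value shape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
  // Resolve the possible -1 entry against the real element count.
  shape = rewriter.create<mhlo::ComputeReshapeShapeOp>(loc, shape.getType(),
                                                       numel, shape);
  return rewriter.create<mhlo::DynamicReshapeOp>(loc, resultType, self, shape);
}

Read against @torch.aten.view$minus1 above: %[[T7]]..%[[T9]] are the muli chain, %[[T10]] the index_cast, %[[T12]] the computed shape, and %[[T13]] the final dynamic_reshape.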