From 4d9dac4a4ac8a3eb5d8abaf462f4b719fd8b6a23 Mon Sep 17 00:00:00 2001 From: Luyang Date: Mon, 18 Jul 2022 23:31:52 +0800 Subject: [PATCH] handle non-contiguous input (#8665) * handle non-contiguous input * refine * auto format by CI Co-authored-by: oneflow-ci-bot --- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 2 +- python/oneflow/test/modules/test_cast.py | 3 ++- python/oneflow/test/modules/test_tensor_ops.py | 11 +++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 4ea3cda8269..8a3d9053140 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -8405,7 +8405,7 @@ def OneFlow_BernoulliOp : OneFlow_BaseOp<"bernoulli", [NoSideEffect, NoGrad, Cpu let has_data_type_infer_fn = 1; } -def OneFlow_CastOp : OneFlow_BaseOp<"cast", [NoSideEffect, SupportNonContiguous, DeclareOpInterfaceMethods]> { +def OneFlow_CastOp : OneFlow_BaseOp<"cast", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in ); diff --git a/python/oneflow/test/modules/test_cast.py b/python/oneflow/test/modules/test_cast.py index c6017cb17c4..8689f3860d8 100644 --- a/python/oneflow/test/modules/test_cast.py +++ b/python/oneflow/test/modules/test_cast.py @@ -51,7 +51,8 @@ def _test_cast_with_non_contiguous_input(test_case, device, shape): output = flow.cast(input, flow.float32) np_out = np_arr.astype(np.float32).transpose(permute_dims) test_case.assertTrue(np.array_equal(output.numpy(), np_out)) - test_case.assertTrue(input.stride() == output.stride()) + # TODO:when cast kernel support stride + # test_case.assertTrue(input.stride() == output.stride()) def _test_cast_backward(test_case, device, shape): diff --git a/python/oneflow/test/modules/test_tensor_ops.py b/python/oneflow/test/modules/test_tensor_ops.py index 07d3252a614..88e2479f013 100644 --- a/python/oneflow/test/modules/test_tensor_ops.py +++ b/python/oneflow/test/modules/test_tensor_ops.py @@ -18,6 +18,7 @@ from collections import OrderedDict import numpy as np +from random import shuffle from oneflow.test_utils.test_util import GenArgList import oneflow as flow @@ -154,6 +155,16 @@ def test_long_0dim(test_case): y = x.long() return y + @autotest(n=5, auto_backward=False) + def test_long_with_non_contiguous_input(test_case): + device = random_device() + permute_list = list(range(4)) + shuffle(permute_list) + input = random_tensor(ndim=4).to(device) + x = input.permute(permute_list) + y = x.long() + return y + @autotest(n=20, auto_backward=False, rtol=1e-4, atol=1e-4, check_graph=True) def test_int(test_case): device = random_device()