diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp index 0b3c367b5a5..d278728cc37 100644 --- a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp +++ b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp @@ -142,7 +142,7 @@ class BroadcastMul : public BroadcastBinaryGrad { in_grads->resize(2); if (ctx->x_requires_grad) { const auto& y = ctx->SavedTensors().at(ctx->y_index); - const auto& x_grad = JUST(functional::Mul(out_grads.at(0), y)); + const auto& x_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(y)))); if (ctx->broadcast_x) { const auto& x = ctx->SavedTensors().at(ctx->x_index); in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x)); @@ -152,7 +152,7 @@ class BroadcastMul : public BroadcastBinaryGrad { } if (ctx->y_requires_grad) { const auto& x = ctx->SavedTensors().at(ctx->x_index); - const auto& y_grad = JUST(functional::Mul(out_grads.at(0), x)); + const auto& y_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(x)))); if (ctx->broadcast_y) { const auto& y = ctx->SavedTensors().at(ctx->y_index); in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y)); diff --git a/oneflow/core/framework/dtype.cpp b/oneflow/core/framework/dtype.cpp index 33169d013c5..cd0aca6b751 100644 --- a/oneflow/core/framework/dtype.cpp +++ b/oneflow/core/framework/dtype.cpp @@ -226,24 +226,24 @@ Symbol promoteTypes(const Symbol a, const Symbol b) { static const Symbol _promoteTypesLookup[DataType_ARRAYSIZE][DataType_ARRAYSIZE] = { /* iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf b1 u2 u4 u8 u16 i2 i16 cp4 cp8 cp16 */ /* iv */ {iv, c1, f4, f8, i1, i4, i8, u1, re, f2, bu, bf, b1, u2, u4, u8, u16, i2, i16, cp4, cp8, cp16}, - /* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16}, - /* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp4, cp16}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp4, cp16}, - /* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp4, cp16}, - /* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp4, cp16}, - /* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp4, cp16}, - /* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16}, + /* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16}, + /* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp8, cp16}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp8, cp16}, + /* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp8, cp16}, + /* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp8, cp16}, + /* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp8, cp16}, + /* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16}, /* re */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv}, - /* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp4, cp16}, + /* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp8, cp16}, /* bu */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, bu, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv}, - /* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp4, cp16}, - /* b1 */ {b1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16}, - /* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp4, cp16}, - /* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp4, cp16}, - /* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp4, cp16}, - /* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp4, cp16}, - /* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp4, cp16}, - /* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp4, cp16}, + /* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp8, cp16}, + /* b1 */ {b1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16}, + /* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp8, cp16}, + /* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp8, cp16}, + /* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp8, cp16}, + /* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp8, cp16}, + /* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp8, cp16}, + /* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp8, cp16}, /* cp4 */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, cp4, cp8, cp16}, /* cp8 */ {cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, iv, cp8, iv, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp16}, /* cp16 */ {cp16,cp16,cp16,cp16,cp16,cp16,cp16,cp16,iv, cp16,iv, cp16,cp16,cp16,cp16,cp16,cp16, cp16,cp16, cp16, cp16, cp16}}; diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index 291a68613e9..275757fbb27 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -85,7 +85,7 @@ class ScalarMathBaseFunctor { "int_operand", "has_int_operand"); TensorProcessor tensor_processor; Symbol lowest_dtype; - if (scalar.IsFloatingPoint()) { + if (scalar.IsFloatingPoint() || scalar.IsComplex()) { attrs.SetAllAttrs(scalar.As(), true, NullOpt, false); // Only promote type to Float32 when tensor is Int type but scalar is float type. if (DType::priority_order[x->dtype()->data_type()] @@ -797,8 +797,9 @@ class ReduceMeanWholeFunctor { ReduceMeanWholeFunctor() {} Maybe operator()(const std::shared_ptr& x) const { // ReduceMean only calculate floating values. - CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type())) - << "RuntimeError: Can only calculate the mean of floating types."; + CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type()) + || IsComplexDataType(x->dtype()->data_type())) + << "RuntimeError: Can only calculate the mean of floating types or complex types."; size_t reduce_count = 1; reduce_count = x->shape()->Count(0); const auto& sum = JUST(functional::ReduceSumWhole(x, NullOpt)); diff --git a/oneflow/core/functional/tensor_processor.cpp b/oneflow/core/functional/tensor_processor.cpp index e242003a40f..b13ebc57507 100644 --- a/oneflow/core/functional/tensor_processor.cpp +++ b/oneflow/core/functional/tensor_processor.cpp @@ -35,7 +35,10 @@ Symbol ComputeCommonDType(const TensorTuple& tensor_tuple) { [](const std::shared_ptr& tensor) { return tensor->shape()->NumAxes() == 0; }); for (auto& tensor_ptr : tensor_tuple) { // skip scalar tensor - if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0) { continue; } + if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0 + && !(tensor_ptr->dtype()->is_complex())) { + continue; + } common_dtype = promoteTypes(tensor_ptr->dtype(), common_dtype); } return common_dtype; @@ -114,7 +117,9 @@ Maybe TensorProcessor::Apply() { } JUST(CastToSameType(tensor_tuple_, common_dtype_)); } else { - if (tensor_tuple_.size() == 1 && !tensor_tuple_[0]->dtype()->is_floating_point()) { + if (tensor_tuple_.size() == 1 + && !((tensor_tuple_[0]->dtype()->is_floating_point()) + || tensor_tuple_[0]->dtype()->is_complex())) { Symbol cast_dtype = (inputs_lowest_dtype_vec_[0] == DType::InvalidDataType()) ? DType::Float() : inputs_lowest_dtype_vec_[0]; diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp index bb202a8fa96..b16cb2c3b73 100644 --- a/oneflow/user/kernels/reduce_like_kernels.cpp +++ b/oneflow/user/kernels/reduce_like_kernels.cpp @@ -124,6 +124,17 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel, public user_op::Cu OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), COMPLEX_DATA_TYPE_SEQ); +#if defined(WITH_CUDA) +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), + OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64)); +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, + OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA), + OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128)); +#endif // WITH_CUDA + #if defined(WITH_CUDA) namespace { diff --git a/python/oneflow/test/modules/test_add.py b/python/oneflow/test/modules/test_add.py index 931149556dd..3788f893489 100644 --- a/python/oneflow/test/modules/test_add.py +++ b/python/oneflow/test/modules/test_add.py @@ -16,7 +16,7 @@ import unittest from collections import OrderedDict - +import torch as torch_original import numpy as np from oneflow.test_utils.test_util import GenArgList @@ -190,7 +190,7 @@ def test_add(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(n=5) + @autotest(n=10) def test_0_size_add(test_case): device = random_device() x = random_tensor(2, 0, 3).to(device) @@ -198,7 +198,7 @@ def test_0_size_add(test_case): out = x + y return out - @autotest(n=3, auto_backward=False) + @autotest(n=6, auto_backward=False) def test_0dim_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 3, requires_grad=False).to(device) @@ -206,7 +206,7 @@ def test_0dim_inplace_add(test_case): x += y.mean() return x - @autotest(n=5) + @autotest(n=10) def test_0dim_two_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 3).to(device).mean() @@ -214,7 +214,7 @@ def test_0dim_two_inplace_add(test_case): x += y.mean() return x - @autotest(n=3) + @autotest(n=6) def test_add_with_alpha(test_case): device = random_device() x1 = random_tensor(2, 2, 3).to(device).mean() @@ -260,7 +260,7 @@ def test_0dim_two_inplace_add(test_case): return x x += y.mean().to(torch.bool) - @autotest(n=3) + @autotest(n=6) def test_add_with_alpha_0dim(test_case): device = random_device() x1 = random_tensor(ndim=0).to(device).mean() @@ -279,7 +279,7 @@ def profile_add(test_case): torch.add(torch.ones(100), 20) torch.add(torch.ones(100), torch.ones(100, 1), alpha=10) - @autotest(n=3) + @autotest(n=6) def test_non_contiguous_inplace_add(test_case): device = random_device() x = random_tensor(2, 2, 4).to(device) @@ -288,7 +288,7 @@ def test_non_contiguous_inplace_add(test_case): y += random_tensor(2, 2, 2).to(device) return y - @autotest(n=5) + @autotest(n=10) def test_scalar_add_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() diff --git a/python/oneflow/test/modules/test_cast.py b/python/oneflow/test/modules/test_cast.py index 48c1134ef1f..0023d482c2f 100644 --- a/python/oneflow/test/modules/test_cast.py +++ b/python/oneflow/test/modules/test_cast.py @@ -19,6 +19,7 @@ from collections import OrderedDict import numpy as np +import torch as torch_original import oneflow as flow import oneflow.unittest @@ -174,6 +175,13 @@ def test_cast_with_scalar_input(test_case): z = y.to(dtype=torch.int8, device=device) return z + @autotest(n=5, auto_backward=True, include_complex=False, atol=1e-5, rtol=1e-5) + def test_cast_with_complex_float2complex(test_case): + device = random_device() + x = random_tensor().to(dtype=torch.float32, device=device) + y = x.to(torch.complex64) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_constant_pad.py b/python/oneflow/test/modules/test_constant_pad.py index d713196c7d7..593ea1dd090 100644 --- a/python/oneflow/test/modules/test_constant_pad.py +++ b/python/oneflow/test/modules/test_constant_pad.py @@ -112,7 +112,7 @@ def test_constantpad3d_with_random_data(test_case): return y @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False) - def test_constantpad3d_with_random_data(test_case): + def test_constantpad3d_with_random_int_data(test_case): dtype = choice([bool, int]) value = random(0, 2).to(bool) if dtype is bool else random().to(int) m = torch.nn.ConstantPad3d(padding=random(1, 6).to(_size_6_t), value=value,) diff --git a/python/oneflow/test/modules/test_equal.py b/python/oneflow/test/modules/test_equal.py index c5e6ee0c42d..b6daf7dc5b1 100644 --- a/python/oneflow/test/modules/test_equal.py +++ b/python/oneflow/test/modules/test_equal.py @@ -18,6 +18,7 @@ from collections import OrderedDict import numpy as np +import torch as torch_original from oneflow.test_utils.test_util import GenArgList import oneflow as flow @@ -28,7 +29,7 @@ @flow.unittest.skip_unless_1n1d() class TestEqual(flow.unittest.TestCase): - @autotest(n=5, auto_backward=False, check_graph=False) + @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) def test_eq_with_0_size_data(test_case): device = random_device() x = random_tensor(3, 2, 0, 3).to(device) @@ -75,6 +76,15 @@ def test_flow_equal_with_same_random_data(test_case): x = random_tensor(len(shape), *shape, requires_grad=False).to(device) return torch.equal(x, x) + @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) + def test_flow_equal_complex_with_same_random_data(test_case): + device = random_device() + shape = random_tensor().oneflow.shape + x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( + device + ) + return torch.equal(x, x) + @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_equal_bool_with_random_data(test_case): device = random_device() @@ -87,6 +97,30 @@ def test_flow_equal_bool_with_random_data(test_case): ) return torch.equal(x, y) + @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) + def test_flow_equal_complex_with_random_data(test_case): + device = random_device() + shape = random_tensor().oneflow.shape + x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( + device=device + ) + y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( + device=device + ) + return torch.equal(x, y) + + @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True) + def test_flow_not_equal_complex_with_random_data(test_case): + device = random_device() + shape = random_tensor().oneflow.shape + x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( + device=device + ) + y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to( + device=device + ) + return torch.not_equal(x, y) + @autotest(n=5, auto_backward=False, check_graph=False) def test_flow_equal_with_same_random_0d_data(test_case): device = random_device() diff --git a/python/oneflow/test/modules/test_mul.py b/python/oneflow/test/modules/test_mul.py index 3aa65fd5698..3dcaee390e3 100644 --- a/python/oneflow/test/modules/test_mul.py +++ b/python/oneflow/test/modules/test_mul.py @@ -18,6 +18,7 @@ from collections import OrderedDict import numpy as np +import torch as torch_original from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList @@ -208,7 +209,7 @@ def test_broadcast_mul(test_case): x.mul_(y) return x - @autotest(n=3) + @autotest(n=6) def test_non_contiguous_inplace_mul(test_case): device = random_device() x = random_tensor(2, 2, 4).to(device) @@ -217,7 +218,7 @@ def test_non_contiguous_inplace_mul(test_case): y *= random_tensor(2, 2, 2).to(device) return y - @autotest(n=5) + @autotest(n=10) def test_scalar_mul_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() diff --git a/python/oneflow/test/modules/test_sub.py b/python/oneflow/test/modules/test_sub.py index 3c70ce725c7..2650a162778 100644 --- a/python/oneflow/test/modules/test_sub.py +++ b/python/oneflow/test/modules/test_sub.py @@ -18,6 +18,7 @@ from collections import OrderedDict import numpy as np +import torch as torch_original from oneflow.test_utils.automated_test_util import * from oneflow.test_utils.test_util import GenArgList @@ -156,7 +157,7 @@ def test_sub_with_alpha(test_case): z3 = torch.sub(s, x3, alpha=alpha) return z1, z2, z3 - @autotest(n=3) + @autotest(n=5) def test_non_contiguous_inplace_sub(test_case): device = random_device() x = random_tensor(2, 2, 4).to(device) @@ -166,7 +167,7 @@ def test_non_contiguous_inplace_sub(test_case): return y @unittest.skip("skip for now, becase it failed 2 times in past week") - @autotest(n=5) + @autotest(n=5, include_complex=True) def test_scalar_sub_with_random_devices(test_case): x1_device = random_device() x2_device = random_device() diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py index 3d4cf278c33..975dfda75f7 100644 --- a/python/oneflow/test/modules/test_sum.py +++ b/python/oneflow/test/modules/test_sum.py @@ -97,8 +97,29 @@ def test_sum_dtype(test_case): ) return y + @autotest( + n=10, + check_graph=False, + auto_backward=True, + include_complex=True, + atol=1e-2, + rtol=1e-5, + ) + def test_sum_complex_dtype(test_case): + device = random_device() + x = random_tensor(4, dtype=complex, requires_grad=True).to( + device=device, dtype=random_dtype(["complex"]) + ) + y = torch.sum( + x, + dim=np.random.randint(0, 3), + keepdim=random_bool(), + dtype=random_dtype(["complex"]), + ) + return y + @autotest(check_graph=True, auto_backward=False) - def test_sum_whole_dtype(test_case): + def test_sum_arithmetic_dtype(test_case): device = random_device() x = random_tensor(4, requires_grad=False).to(device) y = torch.sum(x, dtype=random_dtype(["arithmetic"])) diff --git a/python/oneflow/test/tensor/test_complex.py b/python/oneflow/test/tensor/test_complex.py index b5f3b29eaa4..6ea6401af62 100644 --- a/python/oneflow/test/tensor/test_complex.py +++ b/python/oneflow/test/tensor/test_complex.py @@ -523,8 +523,12 @@ def test_mul_cpu(self): # backward flow_ret.sum().backward() - compare_result(flow_x.grad.numpy(), flow_y.numpy(), self.rtol, self.atol) - compare_result(flow_y.grad.numpy(), flow_x.numpy(), self.rtol, self.atol) + compare_result( + flow_x.grad.numpy(), flow_y.numpy().conjugate(), self.rtol, self.atol + ) + compare_result( + flow_y.grad.numpy(), flow_x.numpy().conjugate(), self.rtol, self.atol + ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") def test_mul_cuda(self): @@ -549,10 +553,16 @@ def test_mul_cuda(self): # backward flow_ret.sum().backward() compare_result( - flow_x.grad.cpu().detach().numpy(), flow_y.numpy(), self.rtol, self.atol + flow_x.grad.cpu().detach().numpy(), + flow_y.numpy().conjugate(), + self.rtol, + self.atol, ) compare_result( - flow_y.grad.cpu().detach().numpy(), flow_x.numpy(), self.rtol, self.atol + flow_y.grad.cpu().detach().numpy(), + flow_x.numpy().conjugate(), + self.rtol, + self.atol, ) def test_sum_cpu(self): diff --git a/python/oneflow/test_utils/automated_test_util/generators.py b/python/oneflow/test_utils/automated_test_util/generators.py index 24c22159a6e..db90e09622e 100644 --- a/python/oneflow/test_utils/automated_test_util/generators.py +++ b/python/oneflow/test_utils/automated_test_util/generators.py @@ -261,6 +261,10 @@ def _generate(self, annotation): val = float(rng.random() * (high - low) + low) elif annotation == bool: val = random_util.choice([True, False]) + elif annotation == complex: + val_real = float(rng.random() * (high - low) + low) + val_imag = float(rng.random() * (high - low) + low) + val = val_real + 1.0j * val_imag elif annotation is None: val = None elif annotation is NoneType: @@ -425,6 +429,7 @@ class random_pytorch_dtype(generator): floating_dtype_seq = [torch.float, torch.double] half_dtype_seq = [torch.half] bfloat16_dtype_seq = [torch.bfloat16] + complex_dtype_seq = [torch.complex64, torch.complex128] signed_int_dtype_seq = [torch.int8, torch.int32, torch.int64] unsigned_int_dtype_seq = [torch.uint8] int_dtype_seq = [torch.int8, torch.int32, torch.int64] @@ -440,6 +445,7 @@ class random_pytorch_dtype(generator): "float": floating_dtype_seq, "half": half_dtype_seq, "bfloat16": bfloat16_dtype_seq, + "complex": complex_dtype_seq, "signed": signed_int_dtype_seq, "unsigned": unsigned_int_dtype_seq, "int": int_dtype_seq, diff --git a/python/oneflow/test_utils/test_util.py b/python/oneflow/test_utils/test_util.py index 8474517b737..a817f42c910 100644 --- a/python/oneflow/test_utils/test_util.py +++ b/python/oneflow/test_utils/test_util.py @@ -73,6 +73,8 @@ def __repr__(self): "uint8": flow.uint8, "half": flow.half, "bfloat16": flow.bfloat16, + "complex64": flow.complex64, + "complex128": flow.complex128, } type_name_to_np_type = { "float16": np.float16, @@ -82,6 +84,8 @@ def __repr__(self): "int32": np.int32, "int64": np.int64, "uint8": np.uint8, + "complex64": np.complex64, + "complex128": np.complex128, }