diff --git a/oneflow/user/kernels/cum_backward_kernel.cpp b/oneflow/user/kernels/cum_backward_kernel.cpp index 6440c8f9e91..6e6967d1daa 100644 --- a/oneflow/user/kernels/cum_backward_kernel.cpp +++ b/oneflow/user/kernels/cum_backward_kernel.cpp @@ -45,11 +45,16 @@ void CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* i } for (size_t j = 0; j < down_space; j++) { - auto* cumsum_zeros_number_ptr = j + dx_ptr_base; + const auto* cur_output_ptr = output_ptr_base + j; + const auto* cur_input_ptr = input_ptr_base + j; + const auto* cur_dy_ptr = dy_ptr_base + j; + auto* cur_dx_ptr = dx_ptr_base + j; + const auto* cumsum_zeros_number_ptr = dx_ptr_base + j; + size_t first_zero_index = space; // Find index of first zero in input. for (size_t k = 0; k < space; k++) { - if (cumsum_zeros_number_ptr[j + k * down_space] == 1) { + if (cumsum_zeros_number_ptr[k * down_space] == 1) { first_zero_index = k; break; } @@ -58,24 +63,24 @@ void CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* i // for element which index is less than z grad is computed as below: T reverse_cumsum = 0; for (size_t k = 0; k < first_zero_index; k++) { - const size_t data_offset = j + (first_zero_index - k - 1) * down_space; - reverse_cumsum += output_ptr_base[data_offset] * dy_ptr_base[data_offset]; - dx_ptr_base[data_offset] = reverse_cumsum / input_ptr_base[data_offset]; + const size_t data_offset = (first_zero_index - k - 1) * down_space; + reverse_cumsum += cur_output_ptr[data_offset] * cur_dy_ptr[data_offset]; + cur_dx_ptr[data_offset] = reverse_cumsum / cur_input_ptr[data_offset]; } // For where index is z, its grad is computed as below: if (first_zero_index == space) { continue; } T cumprod = 1; T cumsum = 0; T cumprod_before_first_zero = - first_zero_index == 0 ? 1 : output_ptr_base[(first_zero_index - 1) * down_space]; + first_zero_index == 0 ? 1 : cur_output_ptr[(first_zero_index - 1) * down_space]; for (size_t k = first_zero_index; k < space; k++) { - const size_t data_offset = j + k * down_space; + const size_t data_offset = k * down_space; // Recover dx_ptr default value - if (dx_ptr_base[data_offset] >= 1) { dx_ptr_base[data_offset] = 0; } - if (k != first_zero_index) { cumprod *= input_ptr_base[data_offset]; } - cumsum += cumprod_before_first_zero * dy_ptr_base[data_offset] * cumprod; + if (cur_dx_ptr[data_offset] >= 1) { cur_dx_ptr[data_offset] = 0; } + if (k != first_zero_index) { cumprod *= cur_input_ptr[data_offset]; } + cumsum += cumprod_before_first_zero * cumprod * cur_dy_ptr[data_offset]; } - dx_ptr_base[j + first_zero_index * down_space] = cumsum; + cur_dx_ptr[first_zero_index * down_space] = cumsum; } } } diff --git a/python/oneflow/test/modules/test_cum_ops.py b/python/oneflow/test/modules/test_cum_ops.py index 6f366e37259..2088440a292 100644 --- a/python/oneflow/test/modules/test_cum_ops.py +++ b/python/oneflow/test/modules/test_cum_ops.py @@ -15,9 +15,11 @@ """ import unittest from collections import OrderedDict +import numpy as np import oneflow as flow import oneflow.unittest +import torch as ori_torch from oneflow.test_utils.automated_test_util import * @@ -64,6 +66,29 @@ def test_cumprod_with_user_dy(test_case): z = y * 2 return z + def test_cumprod_with_zero(test_case): + np_arr = np.ones((5, 5)) + np_arr_grad = np_arr + np_arr[2][3] = 0 + np_arr[4][3] = 0 + of_tensor = flow.tensor(np_arr, dtype=flow.float, requires_grad=True) + of_res = of_tensor.cumprod(dim=0) + of_res.backward(flow.tensor(np_arr_grad, dtype=flow.float)) + + torch_tensor = ori_torch.tensor( + np_arr, dtype=ori_torch.float, requires_grad=True + ) + torch_res = torch_tensor.cumprod(dim=0) + torch_res.backward(ori_torch.tensor(np_arr_grad, dtype=ori_torch.float)) + test_case.assertTrue( + np.allclose( + of_tensor.grad.numpy(), + torch_tensor.grad.numpy(), + rtol=0.0001, + atol=1e-05, + ) + ) + if __name__ == "__main__": unittest.main()