diff --git a/oneflow/core/autograd/gradient_funcs/fft.cpp b/oneflow/core/autograd/gradient_funcs/fft.cpp
new file mode 100644
index 00000000000..a0705b31862
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/fft.cpp
@@ -0,0 +1,201 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <vector>
+#include "oneflow/core/common/container_util.h"
+#include "oneflow/core/common/optional.h"
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+#include "oneflow/core/functional/functional_api.yaml.h"
+
+namespace oneflow {
+namespace one {
+
+struct FftR2CCaptureState : public AutoGradCaptureState {
+  bool requires_grad = false;
+  bool onesided = false;
+  std::vector<int64_t> dims;
+  DimVector input_shape_vec;
+  int32_t norm_mode = 0;
+};
+
+class FftR2C : public OpExprGradFunction<FftR2CCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(FftR2CCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`";
+    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+    ctx->onesided = JUST(attrs.GetAttr<bool>("onesided"));
+    ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>("dims"));
+    ctx->norm_mode = JUST(attrs.GetAttr<int32_t>("norm_mode"));
+    ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec();
+
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const FftR2CCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: assert `out_grads.size() == 1`";
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+    in_grads->resize(1);
+    if (!ctx->onesided) {
+      auto complex_grad = JUST(functional::FftC2C(JUST(oneflow::VectorAt(out_grads, 0)), NullOpt,
+                                                  ctx->dims, ctx->norm_mode,
+                                                  /*forward=*/false, /*normalized=*/false));
+      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_grad));
+    } else {
+      std::vector<int64_t> fft_dims = ctx->dims;
+      std::vector<int64_t> fft_shapes(fft_dims.size(), 0);
+      FOR_RANGE(size_t, i, 0, fft_dims.size()) {
+        fft_shapes[i] = ctx->input_shape_vec[fft_dims[i]];
+      }
+
+      // fill the last dim
+      bool must_copy = false;
+      auto x_sizes = JUST(oneflow::VectorAt(out_grads, 0))->shape()->dim_vec();
+      std::vector<int64_t> pad_amount(x_sizes.size() * 2, 0);
+      int64_t last_dim = ctx->dims.back();
+      if (x_sizes[last_dim] < ctx->input_shape_vec[last_dim]) {
+        must_copy = true;
+        auto pad_idx = pad_amount.size() - 2 * last_dim - 1;
+        pad_amount[pad_idx] = ctx->input_shape_vec[last_dim] - x_sizes[last_dim];
+      }
+      auto complex_full_grad =
+          must_copy
+              ? JUST(functional::ConstantPad(JUST(oneflow::VectorAt(out_grads, 0)), pad_amount, 0))
+              : JUST(oneflow::VectorAt(out_grads, 0));
+      complex_full_grad =
+          JUST(functional::FftC2C(complex_full_grad, NullOpt, ctx->dims, ctx->norm_mode,
+                                  /*forward=*/false, /*normalized=*/false));
+
+      JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_full_grad));
+    }
+
+    return Maybe<void>::Ok();
+  }
+};
+
+struct FftC2CCaptureState : public AutoGradCaptureState {
+  bool requires_grad = false;
+  bool forward = false;
+  std::vector<int64_t> dims;
+  int32_t norm_mode = 0;
+};
+
+class FftC2C : public OpExprGradFunction<FftC2CCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(FftC2CCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`";
+
+    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+    ctx->forward = JUST(attrs.GetAttr<bool>("forward"));
+    ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>("dims"));
+    ctx->norm_mode = JUST(attrs.GetAttr<int32_t>("norm_mode"));
+
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const FftC2CCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: assert `out_grads.size() == 1`";
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+    in_grads->resize(1);
+    JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::FftC2C(
+        JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode,
+        /*forward=*/!(ctx->forward), /*normalized=*/false));
+    return Maybe<void>::Ok();
+  }
+};
+
+struct FftC2RCaptureState : public AutoGradCaptureState {
+  bool requires_grad = false;
+  std::vector<int64_t> dims;
+  int32_t norm_mode = 0;
+  int64_t last_dim_size = 1;
+  DimVector input_shape_vec;
+};
+
+class FftC2R : public OpExprGradFunction<FftC2RCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(FftC2RCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`";
+    ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+    ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>("dims"));
+    ctx->norm_mode = JUST(attrs.GetAttr<int32_t>("norm_mode"));
+    ctx->last_dim_size = JUST(attrs.GetAttr<int64_t>("last_dim_size"));
+    ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec();
+
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const FftC2RCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: out_grads.size() == 1";
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
+    in_grads->resize(1);
+
+    // NOTE: set `forward` to true to prevent conjugating the result
+    auto complex_grad = JUST(functional::FftR2C(
+        JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode,
+        /*onesided=*/true, /*forward=*/true, /*normalized=*/false));  // no need to conj
+    Shape input_shape(ctx->input_shape_vec);
+    int64_t last_dim = ctx->dims.back();
+    auto double_length =
+        JUST(oneflow::VectorAt(out_grads, 0))->dim(last_dim) - complex_grad->dim(last_dim);
+    auto in_grad = complex_grad;
+
+    // Mul by 2, and slice
+    if (double_length > 0) {
+      in_grad = JUST(functional::Narrow(complex_grad, last_dim, 1,
+                                        double_length));  // will change the shape of in_grad
+      in_grad = JUST(functional::ScalarMul(in_grad, 2, /*inplace=*/true));
+    }
+
+    std::vector<int64_t> slice_st(input_shape.size(), 0);
+    std::vector<int64_t> slice_end(input_shape.begin(), input_shape.end());
+    std::vector<int64_t> slice_step(input_shape.size(), 1);
+    auto sliced_tensor =
+        JUST(functional::Slice(complex_grad, slice_st, slice_end, slice_step, false));
+
+    JUST(oneflow::VectorAt(*in_grads, 0)) = sliced_tensor;
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("fft_r2c", FftR2C);
+REGISTER_OP_EXPR_GRAD_FUNCTION("fft_c2c", FftC2C);
+REGISTER_OP_EXPR_GRAD_FUNCTION("fft_c2r", FftC2R);
+
+}  // namespace one
+
+}  // namespace oneflow
\ No newline at end of file
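Editor's note (illustrative, not part of the PR): the onesided branch of FftR2C::Apply zero-pads the gradient back to the full spectrum before the inverse C2C transform, because a onesided R2C output only stores n/2 + 1 of the n frequency bins. A minimal sketch of the same rule for a 1-D transform of original length n over the last dim, with a complex gradient `dy`:

    std::vector<int64_t> pad{0, n - (n / 2 + 1)};  // missing bins, right-padded on the last dim
    auto full = JUST(functional::ConstantPad(dy, pad, 0));
    auto dx = JUST(functional::Real(JUST(functional::FftC2C(
        full, NullOpt, /*dims=*/{-1}, /*norm_mode=*/0, /*forward=*/false, /*normalized=*/false))));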
diff --git a/oneflow/core/device/cuda_util.cpp b/oneflow/core/device/cuda_util.cpp
index bd1ba58ae5c..aff15ddfcb7 100644
--- a/oneflow/core/device/cuda_util.cpp
+++ b/oneflow/core/device/cuda_util.cpp
@@ -74,6 +74,28 @@ const char* CurandGetErrorString(curandStatus_t error) {
   }
 }
 
+const char* CuFFTGetErrorString(cufftResult_t error) {
+  switch (error) {
+    case CUFFT_SUCCESS: return "CUFFT_SUCCESS";
+    case CUFFT_INVALID_PLAN: return "CUFFT_INVALID_PLAN";
+    case CUFFT_ALLOC_FAILED: return "CUFFT_ALLOC_FAILED";
+    case CUFFT_INVALID_TYPE: return "CUFFT_INVALID_TYPE";
+    case CUFFT_INVALID_VALUE: return "CUFFT_INVALID_VALUE";
+    case CUFFT_INTERNAL_ERROR: return "CUFFT_INTERNAL_ERROR";
+    case CUFFT_EXEC_FAILED: return "CUFFT_EXEC_FAILED";
+    case CUFFT_SETUP_FAILED: return "CUFFT_SETUP_FAILED";
+    case CUFFT_INVALID_SIZE: return "CUFFT_INVALID_SIZE";
+    case CUFFT_UNALIGNED_DATA: return "CUFFT_UNALIGNED_DATA";
+    case CUFFT_INCOMPLETE_PARAMETER_LIST: return "CUFFT_INCOMPLETE_PARAMETER_LIST";
+    case CUFFT_INVALID_DEVICE: return "CUFFT_INVALID_DEVICE";
+    case CUFFT_PARSE_ERROR: return "CUFFT_PARSE_ERROR";
+    case CUFFT_NO_WORKSPACE: return "CUFFT_NO_WORKSPACE";
+    case CUFFT_NOT_IMPLEMENTED: return "CUFFT_NOT_IMPLEMENTED";
+    case CUFFT_NOT_SUPPORTED: return "CUFFT_NOT_SUPPORTED";
+    default: return "Unknown cufft status";
+  }
+}
+
 #if CUDA_VERSION >= 11000
 const char* CusovlerGetErrorString(cusolverStatus_t error) {
   switch (error) {
diff --git a/oneflow/core/device/cuda_util.h b/oneflow/core/device/cuda_util.h
index 67960f33689..19d1654cc62 100644
--- a/oneflow/core/device/cuda_util.h
+++ b/oneflow/core/device/cuda_util.h
@@ -31,6 +31,7 @@ limitations under the License.
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
 #include <cudnn.h>
+#include <cufft.h>
 #include <curand.h>
 #include <curand_kernel.h>
 #if CUDA_VERSION >= 11000
@@ -51,6 +52,8 @@ const char* CublasGetErrorString(cublasStatus_t error);
 
 const char* CurandGetErrorString(curandStatus_t error);
 
+const char* CuFFTGetErrorString(cufftResult_t error);
+
 #if CUDA_VERSION >= 11000
 const char* CusovlerGetErrorString(cusolverStatus_t error);
 #endif
@@ -78,6 +81,12 @@ const char* NvjpegGetErrorString(nvjpegStatus_t error);
   LOG(FATAL) << "Check failed: " #condition " : " << CublasGetErrorString(_of_cublas_check_status) \
              << " (" << _of_cublas_check_status << ") "
 
+#define OF_CUFFT_CHECK(condition)                                                                  \
+  for (cufftResult_t _of_cufft_check_status = (condition);                                         \
+       _of_cufft_check_status != CUFFT_SUCCESS;)                                                   \
+  LOG(FATAL) << "Check failed: " #condition " : " << CuFFTGetErrorString(_of_cufft_check_status)   \
+             << " (" << _of_cufft_check_status << ") "
+
 #if CUDA_VERSION >= 11000
 #define OF_CUSOLVER_CHECK(condition)                                                               \
   for (cusolverStatus_t _of_cusolver_check_status = (condition);                                   \
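Editor's note: a minimal usage sketch of the new macro (not from the PR); `plan` is a local handle and both calls are standard cuFFT APIs. Any status other than CUFFT_SUCCESS triggers LOG(FATAL) with the string from CuFFTGetErrorString:

    cufftHandle plan;
    OF_CUFFT_CHECK(cufftCreate(&plan));
    // ... make and execute the plan ...
    OF_CUFFT_CHECK(cufftDestroy(plan));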
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index e444999780f..7b13c3e33ef 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -3275,6 +3275,111 @@
     'Tensor (Tensor input, Int64 n_fft,Int64 hop_length=None, Int64 win_length=None, Tensor window=None,Bool center=True,String pad_mode="reflect",Bool normalized=False,Bool onesided=True,Bool return_complex=False) =>Stft'
   bind_python: True
 
+- name: "fft_c2c"
+  signature:
+    'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) => FftC2C'
+  bind_python: False
+
+- name: "fft_r2c"
+  signature:
+    'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool onesided=False, Bool forward=True, Bool normalized=False) => FftR2C'
+  bind_python: False
+
+- name: "fft_c2r"
+  signature:
+    'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) => FftC2R'
+  bind_python: False
+
+- name: "fft"
+  signature:
+    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => Fft'
+  bind_python: True
+
+- name: "ifft"
+  signature:
+    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IFft'
+  bind_python: True
+
+- name: "fft2"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => Fft2'
+  bind_python: True
+
+- name: "ifft2"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IFft2'
+  bind_python: True
+
+- name: "fftn"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => FftN'
+  bind_python: True
+
+- name: "ifftn"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IFftN'
+  bind_python: True
+
+- name: "rfft"
+  signature:
+    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => RFft'
+  bind_python: True
+
+- name: "irfft"
+  signature:
+    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IRFft'
+  bind_python: True
+
+- name: "rfft2"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => RFft2'
+  bind_python: True
+
+- name: "irfft2"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IRFft2'
+  bind_python: True
+
+- name: "rfftn"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => RFftN'
+  bind_python: True
+
+- name: "irfftn"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IRFftN'
+  bind_python: True
+
+- name: "hfft"
+  signature:
+    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => HFft'
+  bind_python: True
+
+- name: "ihfft"
+  signature:
+    'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IHFft'
+  bind_python: True
+
+- name: "hfft2"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => HFft2'
+  bind_python: True
+
+- name: "ihfft2"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IHFft2'
+  bind_python: True
+
+- name: "hfftn"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => HFftN'
+  bind_python: True
+
+- name: "ihfftn"
+  signature:
+    'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IHFftN'
+  bind_python: True
+
 - name: "isclose"
   signature: "Tensor (Tensor input, Tensor other, Float atol=1e-08, Float rtol=1e-05, Bool equal_nan=False) => IsClose"
   bind_python: True
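Editor's note: each `=> Name` above binds the YAML signature to the functor of that name in oneflow/core/functional/impl/math_functor.cpp; the three `bind_python: False` primitives are only reachable through the wrappers. A sketch of the C++ side of one such call (tensor `x` assumed):

    // complex input dispatches to fft_c2c; real input to fft_r2c with onesided=false
    auto y = JUST(functional::Fft(x, /*n=*/-1, /*dim=*/-1, /*norm=*/NullOpt));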
diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp
index 941441b76c9..693ba02db13 100644
--- a/oneflow/core/functional/impl/math_functor.cpp
+++ b/oneflow/core/functional/impl/math_functor.cpp
@@ -3900,6 +3900,994 @@ class InplaceAddCDivFunctor {
   }
 };
 
+namespace {
+constexpr int64_t cufft_max_ndim =
+    3;  // must be kept equal to `oneflow/user/kernels/cufft_plan_cache.h:max_rank`
+enum class fft_norm_mode {
+  none = 0,   // No normalization
+  by_root_n,  // Divide by sqrt(signal_size)
+  by_n,       // Divide by signal_size
+};
+
+bool use_optimized_cufft_path(const std::vector<int64_t>& fft_dims) {
+  // For performance reasons, do not use the optimized path when the dims start with (0, 1).
+  if (fft_dims.size() > cufft_max_ndim
+      || (fft_dims.size() >= 2 && fft_dims[0] == 0 && fft_dims[1] == 1)) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
+// Convert a NumPy-compatible normalization mode string to enum values.
+// In NumPy, "forward" translates to `by_n` for a forward transform and `none` for backward.
+static fft_norm_mode fft_norm_from_string(const Optional<std::string>& norm_op, bool forward) {
+  std::string norm_str = norm_op.value_or("backward");
+  if (norm_str == "backward") {
+    return forward ? fft_norm_mode::none : fft_norm_mode::by_n;
+  } else if (norm_str == "forward") {
+    return forward ? fft_norm_mode::by_n : fft_norm_mode::none;
+  } else if (norm_str == "ortho") {
+    return fft_norm_mode::by_root_n;
+  }
+
+  return fft_norm_mode::none;
+}
+
+template<typename T>
+static T fft_compute_fct(int64_t size, fft_norm_mode normalization) {
+  constexpr auto one = static_cast<T>(1);
+  switch (normalization) {
+    case fft_norm_mode::none: return one;
+    case fft_norm_mode::by_n: return one / static_cast<T>(size);
+    case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));
+  }
+  return static_cast<T>(0);
+}
+
+template<typename T>
+static T fft_compute_fct(const Shape& in_shape, const std::vector<int64_t>& dims,
+                         fft_norm_mode normalization) {
+  if (normalization == fft_norm_mode::none) { return static_cast<T>(1); }
+  int64_t n = 1;
+  for (int64_t idx : dims) { n *= in_shape.At(idx); }
+  return fft_compute_fct<T>(n, normalization);
+}
+}  // namespace
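// Editor's note (illustration, not part of the PR): how the two helpers above combine for a
// length-8 transform over one dim:
//   fft_norm_from_string("backward", /*forward=*/true)  -> none       -> fct = 1
//   fft_norm_from_string("backward", /*forward=*/false) -> by_n       -> fct = 1/8
//   fft_norm_from_string("ortho",    /*forward=*/true)  -> by_root_n  -> fct = 1/sqrt(8)
// so ifft(fft(x)) applies 1 * 1/n in total, matching NumPy's default "backward" convention.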
+class FftBaseFunctor {
+ public:
+  explicit FftBaseFunctor() {}
+  explicit FftBaseFunctor(std::string op_name) {
+    op_ = CHECK_JUST(one::OpBuilder(op_name).Input("input").Output("out").Build());
+  }
+  virtual ~FftBaseFunctor() = default;
+
+  Maybe<Tensor> resize_fft_input(const std::shared_ptr<one::Tensor>& x,
+                                 const std::vector<int64_t>& dims,
+                                 const std::vector<int64_t>& sizes) const {
+    CHECK_EQ_OR_THROW(dims.size(), sizes.size()) << "dims.size() != sizes.size().";
+    bool must_copy = false;
+    auto x_sizes = x->shape()->dim_vec();
+    std::vector<int64_t> pad_amount(x_sizes.size() * 2);
+    std::vector<int64_t> slice_st(x_sizes.size());
+    std::vector<int64_t> slice_end(x_sizes.size());
+    std::vector<int64_t> slice_step(x_sizes.size(), 1);
+
+    FOR_RANGE(int64_t, i, 0, x_sizes.size()) {
+      slice_st[i] = 0;
+      slice_end[i] = x_sizes[i];
+    }
+
+    FOR_RANGE(int64_t, i, 0, sizes.size()) {
+      if (sizes[i] == -1) { continue; }
+
+      if (x_sizes[dims[i]] < sizes[i]) {
+        must_copy = true;
+        auto pad_idx = pad_amount.size() - 2 * dims[i] - 1;
+        pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]];
+      }
+
+      if (x_sizes[dims[i]] > sizes[i]) {
+        // slice in dims[i]
+        slice_end[dims[i]] = sizes[i];
+      }
+    }
+
+    auto sliced_tensor = JUST(functional::Slice(x, slice_st, slice_end, slice_step, false));
+    return must_copy ? functional::ConstantPad(sliced_tensor, pad_amount, 0) : sliced_tensor;
+  }
+
+  Maybe<Symbol<DType>> promote_type_fft(Symbol<DType> type, bool require_complex = false) const {
+    if (type->is_complex()) { return type; }
+
+    if (!type->is_floating_point()) { type = GetDefaultDType(); }
+    CHECK_OR_RETURN(type->data_type() == kFloat || type->data_type() == kDouble)
+        << "Unsupported dtype " << type->name() << ", "
+        << "only kFloat and kDouble are supported";
+
+    if (!require_complex) { return type; }
+
+    switch (type->data_type()) {
+      // TODO: add kFloat16
+      case (kFloat): return CHECK_JUST(DType::Get(DataType::kComplex64));
+      case (kDouble): return CHECK_JUST(DType::Get(DataType::kComplex128));
+      default: CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be handled";
+    }
+    CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be handled";
+  }
+
+  Maybe<Tensor> promote_tensor_fft(const std::shared_ptr<one::Tensor>& x,
+                                   bool require_complex = false) const {
+    auto cur_type = x->dtype();
+    auto new_type = JUST(promote_type_fft(cur_type, require_complex));
+    if (cur_type->data_type() == new_type->data_type()) {
+      return x;
+    } else {
+      TensorProcessor tensor_processor;
+      JUST(tensor_processor.AddInputs({x}, {new_type}).Apply());
+      return JUST(oneflow::VectorAt(JUST(tensor_processor.GetInputs()), 0));
+    }
+  }
+
+  Maybe<void> maybe_wrap_dims(std::vector<int64_t>& dims, int64_t dim_post_expr,
+                              bool wrap_scalar = true) const {
+    if (dim_post_expr <= 0) {
+      if (!wrap_scalar) {
+        CHECK_OR_RETURN(false) << "RuntimeError: dimension specified as " << dims[0]
+                               << " but tensor has no dimensions";
+      }
+      dim_post_expr = 1;  // this will make the range [-1, 0]
+    }
+
+    int64_t min = -dim_post_expr;
+    int64_t max = dim_post_expr - 1;
+    for (auto& dim : dims) {
+      if (dim < min || dim > max) {
+        CHECK_OR_RETURN(false)
+            << "RuntimeError: Dimension out of range (expected to be in range of [" << min << ", "
+            << max << "], but got " << dim << ")";
+      }
+      if (dim < 0) dim += dim_post_expr;
+    }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> calculate_fftn_shape_and_dims(const std::shared_ptr<one::Tensor>& x,
+                                            const Optional<std::vector<int64_t>>& n,
+                                            const Optional<std::vector<int64_t>>& dims,
+                                            std::vector<int64_t>& fft_shape,
+                                            std::vector<int64_t>& fft_dims) const {
+    if (dims.has_value()) {
+      fft_dims = *JUST(dims);
+      JUST(maybe_wrap_dims(fft_dims, x->ndim()));
+      std::vector<int64_t> copy = fft_dims;
+      std::sort(copy.begin(), copy.end());
+      auto duplicate = std::adjacent_find(copy.begin(), copy.end());
+      CHECK_OR_RETURN(duplicate == copy.end()) << "RuntimeError: FFT dims must be unique";
+    } else {
+      fft_dims.resize(x->ndim());
+      for (int i = 0; i < x->ndim(); i++) { fft_dims[i] = i; }
+    }
+
+    if (!n.has_value()) {
+      fft_shape.resize(fft_dims.size());
+      for (int i = 0; i < fft_dims.size(); i++) { fft_shape[i] = x->dim(fft_dims[i]); }
+    } else {
+      fft_shape = *JUST(n);
+      if (dims.has_value()) {
+        // got n, also got dim
+        for (int i = 0; i < fft_dims.size(); i++) {
+          if (fft_shape[i] == -1) { fft_shape[i] = x->dim(fft_dims[i]); }
+        }
+      } else {
+        // got n, but did not get dim
+        fft_dims.resize(fft_shape.size());
+        FOR_RANGE(size_t, i, 0, fft_dims.size()) { fft_dims[i] = x->ndim() - fft_dims.size() + i; }
+      }
+    }
+
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> parse_input_n_and_dims(const std::shared_ptr<one::Tensor>& x,
+                                     const Optional<std::vector<int64_t>>& n,
+                                     const Optional<std::vector<int64_t>>& dims,
+                                     std::vector<int64_t>& fft_len,
+                                     std::vector<int64_t>& wrapped_dims) const {
+    if (n.has_value() && dims.has_value()) {
+      CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size())
+          << "RuntimeError: When dim and shape were both given, they must have the same length";
+    }
+    wrapped_dims.resize(x->ndim());
+    fft_len.resize(x->ndim());
+    if (dims.has_value() && (*JUST(dims)).size() == 1) {
+      // 1D discrete Fourier transform
+      wrapped_dims = *JUST(dims);
+      JUST(maybe_wrap_dims(wrapped_dims, x->ndim()));
+      fft_len.resize(wrapped_dims.size());
+      fft_len[0] = n.has_value() == true ? (*JUST(n))[0] : x->dim(wrapped_dims[0]);
+      if (fft_len[0] == -1) { fft_len[0] = x->dim(wrapped_dims[0]); }
+      CHECK_OR_RETURN(fft_len[0] >= 1) << "RuntimeError: Expected n >= 1, but got " << fft_len[0];
+    } else if (n.has_value() && JUST(n)->size() == 1) {
+      // 1D discrete Fourier transform
+      fft_len = *(JUST(n));
+      if (fft_len[0] == -1) { fft_len[0] = x->shape()->back(); }
+      CHECK_OR_RETURN(fft_len[0] >= 1) << "RuntimeError: Expected n >= 1, but got " << fft_len[0];
+      wrapped_dims.resize(1);
+      wrapped_dims[0] = x->ndim() - 1;
+    } else {
+      // ND discrete Fourier transform
+      JUST(calculate_fftn_shape_and_dims(x, n, dims, fft_len, wrapped_dims));
+    }
+
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<Tensor> permute_and_reshape(const std::shared_ptr<one::Tensor>& self,
+                                    const std::vector<int64_t>& out_sizes,
+                                    const std::vector<int64_t>& fft_dims,
+                                    std::vector<int64_t>& out_strides) const {
+    // Permute and reshape the `self` tensor; this maximizes data locality.
+    const int64_t ndim = self->ndim();
+    const int64_t fft_ndim = fft_dims.size();
+    const int64_t batch_dims = ndim - fft_ndim;
+    const auto& in_stride = JUST(self->stride());
+    // Permute dimensions so that the batch dims come first
+    std::vector<int32_t> dim_permute(ndim);
+    std::iota(dim_permute.begin(), dim_permute.end(), int32_t(0));
+    std::vector<bool> is_transformed_dim(ndim, false);
+    for (const auto& dim : fft_dims) { is_transformed_dim[dim] = true; }
+
+    auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
+                                    [&](int64_t d) { return !is_transformed_dim[d]; });
+    std::sort(dim_permute.begin(), batch_end,
+              [&](int64_t a, int64_t b) { return in_stride->at(a) > in_stride->at(b); });
+    std::copy(fft_dims.begin(), fft_dims.end(), batch_end);
+
+    // permute
+    auto input = JUST(functional::Permute(self, dim_permute));
+
+    std::vector<int64_t> batched_sizes(fft_ndim + 1);
+    batched_sizes[0] = -1;
+    std::copy(input->shape()->begin() + batch_dims, input->shape()->end(),
+              batched_sizes.begin() + 1);
+    // reshape
+    Shape batched_shape(batched_sizes);
+    input = JUST(functional::Reshape(input, batched_shape));
+
+    const auto batch_size = input->shape()->At(0);
+
+    batched_sizes[0] = batch_size;
+    std::vector<int64_t> batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
+    FOR_RANGE(int64_t, i, 0, fft_dims.size()) { batched_out_sizes[i + 1] = out_sizes[fft_dims[i]]; }
+
+    // Inplace reshaping to the original batch shape and inverting the dimension permutation
+    out_strides.resize(ndim, 0);
+
+    int64_t batch_numel = 1;
+    Stride contiguous_out_strides = Stride(batched_out_sizes);
+    for (int64_t i = batch_dims - 1; i >= 0; --i) {
+      out_strides[dim_permute[i]] = batch_numel * contiguous_out_strides[0];
+      batch_numel *= out_sizes[dim_permute[i]];
+    }
+    FOR_RANGE(int64_t, i, batch_dims, ndim) {
+      out_strides[dim_permute[i]] = contiguous_out_strides[1 + (i - batch_dims)];
+    }
+
+    // Judge whether the input needs to be cloned
+    int64_t signal_ndim = input->shape()->size() - 1;
+    const Stride& batched_input_strides = *(JUST(input->stride()));
+    auto last_stride = JUST(oneflow::VectorAt(batched_input_strides, signal_ndim));
+    bool must_clone_input = false;
+    if (JUST(oneflow::VectorAt(batched_input_strides, 0)) == 0) { must_clone_input = true; }
+    for (auto i = signal_ndim - 1; !must_clone_input && i > 0; i--) {
+      auto stride = JUST(oneflow::VectorAt(batched_input_strides, i));
+      if (JUST(oneflow::VectorAt(*(input->shape()), i)) == 1) {
+        continue;
+      } else if (stride > 0 && stride % last_stride == 0) {
+        last_stride = stride;
+      } else {
+        must_clone_input = true;
+      }
+    }
+
+    if (must_clone_input) { input = JUST(functional::ToContiguous(input)); }
+    return input;
+  }
+
+  Maybe<void> parse_c2r_input_n_and_dims(const std::shared_ptr<one::Tensor>& x,
+                                         const Optional<std::vector<int64_t>>& n,
+                                         const Optional<std::vector<int64_t>>& dims,
+                                         int64_t& last_dim_size, std::vector<int64_t>& fft_len,
+                                         std::vector<int64_t>& wrapped_dims) const {
+    JUST(parse_input_n_and_dims(x, n, dims, fft_len, wrapped_dims));
+    // infer last_dim_size
+    last_dim_size = 0;
+    if (!n.has_value() || JUST(n)->back() == -1) {
+      int64_t last_dim = wrapped_dims.back();
+      last_dim_size = 2 * (x->dim(last_dim) - 1);
+    } else {
+      last_dim_size = JUST(n)->back();
+    }
+    CHECK_OR_RETURN(last_dim_size >= 1)
+        << "RuntimeError: Invalid last_dim_size (" << last_dim_size << ") specified";
+    fft_len.back() = last_dim_size / 2 + 1;
+
+    return Maybe<void>::Ok();
+  }
+
+ protected:
+  std::shared_ptr<OpExpr> op_;
+};
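// Editor's sketch (illustration, not part of the PR): the effect of resize_fft_input on a
// 1-D tensor x of length 4, transforming over dim 0:
//   resize_fft_input(x, /*dims=*/{0}, /*sizes=*/{6})  -> right-pads with zeros to length 6
//   resize_fft_input(x, /*dims=*/{0}, /*sizes=*/{3})  -> slices down to length 3
//   resize_fft_input(x, /*dims=*/{0}, /*sizes=*/{-1}) -> size -1 means "keep this dim as is"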
+class FftC2CFunctor : public FftBaseFunctor {
+ public:
+  FftC2CFunctor() : FftBaseFunctor("fft_c2c") {}
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const Optional<std::vector<int64_t>>& n,
+                           const Optional<std::vector<int64_t>>& dims, int32_t norm_mode,
+                           bool forward, bool normalized) const {
+    // NOTE: `normalized` indicates whether the FFT result still needs to be normalized with
+    // `ScalarMul`. It only takes effect on CUDA devices; on CPU it is ignored, because the CPU
+    // fft kernel normalizes internally according to `forward` and the type of FFT transform.
+
+    CHECK_OR_RETURN(x->dtype()->is_complex())
+        << "RuntimeError: expected the dtype of the input tensor to be complex, but got "
+        << x->dtype()->name();
+    std::vector<int64_t> fft_len(x->ndim(), 0);
+    std::vector<int64_t> wrapped_dims(x->ndim(), 0);
+
+    JUST(parse_input_n_and_dims(x, n, dims, fft_len, wrapped_dims));
+    auto resized_tensor =
+        n.has_value() == true ? JUST(resize_fft_input(x, wrapped_dims, fft_len)) : x;
+
+    DeviceType input_device{};
+    if (x->is_global()) {
+      input_device = JUST(x->parallel_desc())->device_type();
+    } else {
+      input_device = JUST(x->device())->enum_type();
+    }
+
+    double norm_fct = fft_compute_fct<double>(*(resized_tensor->shape()), wrapped_dims,
+                                              static_cast<fft_norm_mode>(norm_mode));
+
+    if (input_device == DeviceType::kCPU) {
+      auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "forward", "norm_mode", "norm_fct");
+      attrs.SetAllAttrs(wrapped_dims, forward, norm_mode, norm_fct);
+      return OpInterpUtil::Dispatch<Tensor>(*op_, {resized_tensor}, attrs);
+    } else if (input_device == DeviceType::kCUDA) {
+      if (wrapped_dims.empty()) { return resized_tensor; }
+      std::vector<int64_t> out_sizes(resized_tensor->shape()->dim_vec().begin(),
+                                     resized_tensor->shape()->dim_vec().end());
+      std::vector<int64_t> sorted_dims(wrapped_dims.begin(), wrapped_dims.end());
+      auto working_tensor = resized_tensor;
+      std::vector<int64_t> out_strides;
+      std::shared_ptr<Tensor> output;
+      while (true) {
+        // Sort dimensions every iteration
+        auto strides = *JUST(working_tensor->stride());
+        std::sort(sorted_dims.begin(), sorted_dims.end(),
+                  [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });
+
+        const auto max_dims = std::min(static_cast<size_t>(cufft_max_ndim), sorted_dims.size());
+        auto first_dims_end = sorted_dims.end();
+        auto first_dims_begin = first_dims_end - max_dims;
+        std::vector<int64_t> first_dims(first_dims_begin, first_dims_end);
+
+        auto input = JUST(permute_and_reshape(working_tensor, out_sizes, first_dims, out_strides));
+
+        std::vector<int64_t> fft_dims(input->ndim() - 1);  // must be >= 1
+        std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1));
+        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "forward", "norm_mode", "norm_fct");
+        attrs.SetAllAttrs(fft_dims, forward, norm_mode, norm_fct);
+        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));
+        output = JUST(
+            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));
+
+        sorted_dims.resize(sorted_dims.size() - max_dims);
+
+        if (sorted_dims.empty()) { break; }
+        working_tensor = std::move(output);
+      }
+
+      if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), true)); }
+
+      return output;
+    } else {
+      CHECK_OR_RETURN(false) << "RuntimeError: FftC2C only supports cpu and cuda devices.";
+    }
+  }
+};
+
+class FftR2CFunctor : public FftBaseFunctor {
+ public:
+  FftR2CFunctor() : FftBaseFunctor("fft_r2c") {}
+
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const Optional<std::vector<int64_t>>& n,
+                           const Optional<std::vector<int64_t>>& dims, int32_t norm_mode,
+                           bool onesided, bool forward, bool normalized) const {
+    // NOTE: `normalized` indicates whether the FFT result still needs to be normalized with
+    // `ScalarMul`. It only takes effect on CUDA devices; on CPU it is ignored, because the CPU
+    // fft kernel normalizes internally according to `forward` and the type of FFT transform.
+
+    CHECK_OR_RETURN(!(x->dtype()->is_complex()))
+        << "RuntimeError: expected the dtype of the input tensor to be real, but got "
+        << x->dtype()->name();
+
+    auto input_tensor = JUST(promote_tensor_fft(x));
+
+    if (n.has_value() && dims.has_value()) {
+      CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size())
+          << "RuntimeError: When dim and shape were both given, they must have the same length";
+    }
+
+    std::vector<int64_t> fft_len(input_tensor->ndim(), 0);
+    std::vector<int64_t> wrapped_dims(input_tensor->ndim(), 0);
+    JUST(parse_input_n_and_dims(input_tensor, n, dims, fft_len, wrapped_dims));
+    auto resized_tensor = n.has_value() == true
+                              ? JUST(resize_fft_input(input_tensor, wrapped_dims, fft_len))
+                              : input_tensor;
+    DeviceType input_device{};
+    if (x->is_global()) {
+      input_device = JUST(x->parallel_desc())->device_type();
+    } else {
+      input_device = JUST(x->device())->enum_type();
+    }
+
+    double norm_fct = fft_compute_fct<double>(*(resized_tensor->shape()), wrapped_dims,
+                                              static_cast<fft_norm_mode>(norm_mode));
+
+    std::shared_ptr<Tensor> output;
+    if (input_device == DeviceType::kCPU) {
+      auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "onesided");
+      attrs.SetAllAttrs(wrapped_dims, norm_mode, norm_fct, onesided);
+      output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {resized_tensor}, attrs));
+    } else if (input_device == DeviceType::kCUDA) {
+      std::vector<int64_t> input_sizes(resized_tensor->shape()->begin(),
+                                       resized_tensor->shape()->end());
+      std::vector<int64_t> onesided_sizes = input_sizes;
+      int64_t last_dim = wrapped_dims.back();
+      int64_t last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
+      onesided_sizes[last_dim] = last_dim_halfsize;
+      std::vector<int64_t> out_sizes = onesided ? onesided_sizes : input_sizes;
+
+      if (use_optimized_cufft_path(wrapped_dims)) {
+        std::vector<int64_t> out_strides;
+        auto input =
+            JUST(permute_and_reshape(resized_tensor, out_sizes, wrapped_dims, out_strides));
+
+        std::vector<int64_t> fft_dims(input->ndim() - 1);  // must be >= 1
+        std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1));
+
+        auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "onesided");
+        attrs.SetAllAttrs(fft_dims, norm_mode, norm_fct, onesided);
+        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));
+        output = JUST(
+            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));
+      } else {
+        // First do the **onesided** R2C transform on the last dimension
+        const std::shared_ptr<one::Tensor>& working_tensor = resized_tensor;
+        {
+          std::vector<int64_t> out_strides;
+          auto input = JUST(
+              permute_and_reshape(/*self=*/working_tensor, /*out_sizes=*/onesided_sizes,
+                                  /*fft_dims=*/{wrapped_dims.back()}, /*out_strides=*/out_strides));
+          auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "onesided");
+          int64_t last_dim = input->shape()->size() - 1;
+          std::vector<int64_t> fft_last_dim_vec = {last_dim};
+          attrs.SetAllAttrs(fft_last_dim_vec, norm_mode, norm_fct, /*onesided=*/true);
+          output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));
+          output = JUST(functional::AsStrided(output, out_sizes, out_strides,
+                                              JUST(output->storage_offset())));
+        }
+
+        // Then do any remaining C2C transforms
+        std::vector<int64_t> sorted_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);
+        if (!sorted_dims.empty()) {
+          output = JUST(functional::FftC2C(output, NullOpt, sorted_dims, norm_mode,
+                                           /*forward=*/true, /*normalized=*/false));
+        }
+      }
+
+      if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), true)); }
+
+    } else {
+      CHECK_OR_RETURN(false) << "RuntimeError: FftR2C only supports cpu and cuda devices.";
+    }
+
+    if (!forward) {
+      return functional::ConjPhysical(output);
+    } else {
+      return output;
+    }
+  }
+};
+
+class FftC2RFunctor : public FftBaseFunctor {
+ public:
+  FftC2RFunctor() : FftBaseFunctor("fft_c2r") {}
+
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const Optional<std::vector<int64_t>>& n,
+                           const Optional<std::vector<int64_t>>& dims, int32_t norm_mode,
+                           bool forward, bool normalized) const {
+    // NOTE: `normalized` indicates whether the FFT result still needs to be normalized with
+    // `ScalarMul`. It only takes effect on CUDA devices; on CPU it is ignored, because the CPU
+    // fft kernel normalizes internally according to `forward` and the type of FFT transform.
+
+    CHECK_OR_RETURN(x->dtype()->is_complex())
+        << "RuntimeError: expected the dtype of the input tensor to be complex, but got "
+        << x->dtype()->name();
+
+    if (n.has_value() && dims.has_value()) {
+      CHECK_OR_RETURN((*JUST(n)).size() == (*JUST(dims)).size())
+          << "RuntimeError: When dim and shape were both given, they must have the same length";
+    }
+
+    std::vector<int64_t> wrapped_dims(x->ndim(), 0);
+    std::vector<int64_t> fft_len(x->ndim(), 0);
+    int64_t last_dim_size = 0;
+    JUST(parse_c2r_input_n_and_dims(x, n, dims, last_dim_size, fft_len, wrapped_dims));
+
+    auto resized_tensor =
+        n.has_value() == true ? JUST(resize_fft_input(x, wrapped_dims, fft_len)) : x;
+
+    Shape out_shape = *(resized_tensor->shape());
+    out_shape[wrapped_dims.back()] = last_dim_size;
+    double norm_fct =
+        fft_compute_fct<double>(out_shape, wrapped_dims, static_cast<fft_norm_mode>(norm_mode));
+
+    if (forward) { resized_tensor = JUST(functional::ConjPhysical(resized_tensor)); }
+
+    DeviceType input_device{};
+    if (x->is_global()) {
+      input_device = JUST(x->parallel_desc())->device_type();
+    } else {
+      input_device = JUST(x->device())->enum_type();
+    }
+
+    if (input_device == DeviceType::kCPU) {
+      auto& attrs =
+          THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "last_dim_size");
+      attrs.SetAllAttrs(wrapped_dims, norm_mode, norm_fct, last_dim_size);
+      return OpInterpUtil::Dispatch<Tensor>(*op_, {resized_tensor}, attrs);
+    } else if (input_device == DeviceType::kCUDA) {
+      std::shared_ptr<Tensor> output;
+      if (use_optimized_cufft_path(wrapped_dims)) {
+        auto input = JUST(functional::ToContiguous(resized_tensor));
+        std::vector<int64_t> out_sizes(out_shape.dim_vec().begin(), out_shape.dim_vec().end());
+        std::vector<int64_t> out_strides;
+        input = JUST(permute_and_reshape(input, out_sizes, wrapped_dims, out_strides));
+
+        std::vector<int64_t> fft_dims(input->ndim() - 1);  // must be >= 1
+        std::iota(fft_dims.begin(), fft_dims.end(), int64_t(1));
+
+        auto& attrs =
+            THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "last_dim_size");
+        attrs.SetAllAttrs(fft_dims, norm_mode, norm_fct, last_dim_size);
+        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));
+        output = JUST(
+            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));
+      } else {
+        // First complete any C2C transforms
+        std::shared_ptr<Tensor> temp;
+        if (wrapped_dims.size() > 1) {
+          std::vector<int64_t> any_c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);
+          temp = JUST(functional::FftC2C(resized_tensor, NullOpt, any_c2c_dims,
+                                         static_cast<int32_t>(fft_norm_mode::none),
+                                         /*forward=*/false, /*normalized=*/false));
+        } else {
+          temp = JUST(functional::ToContiguous(resized_tensor));
+        }
+
+        // Finally, do the 1D C2R transform on the last dim
+        std::vector<int64_t> out_strides;
+        std::vector<int64_t> out_sizes(out_shape.dim_vec().begin(), out_shape.dim_vec().end());
+        auto input = JUST(permute_and_reshape(/*self=*/temp, /*out_sizes=*/out_sizes,
+                                              /*fft_dims=*/{wrapped_dims.back()},
+                                              /*out_strides=*/out_strides));
+
+        auto& attrs =
+            THREAD_CACHED_MUTABLE_ATTR_MAP("dims", "norm_mode", "norm_fct", "last_dim_size");
+        int64_t last_dim = input->shape()->size() - 1;
+        std::vector<int64_t> fft_last_dim_vec = {last_dim};
+        attrs.SetAllAttrs(fft_last_dim_vec, norm_mode, norm_fct, /*last_dim_size=*/last_dim_size);
+
+        output = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs));
+        output = JUST(
+            functional::AsStrided(output, out_sizes, out_strides, JUST(output->storage_offset())));
+      }
+
+      if (normalized) { JUST(functional::ScalarMul(output, Scalar(norm_fct), /*inplace=*/true)); }
+      return output;
+    } else {
+      CHECK_OR_RETURN(false) << "RuntimeError: FftC2R only supports cpu and cuda devices.";
+    }
+  }
+};
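// Editor's note (illustration): cuFFT plans at most cufft_max_ndim (= 3) dims at once, so the
// CUDA branch of FftC2CFunctor above peels dims off in stride-sorted groups of up to 3; e.g. a
// transform over 5 dims runs as one 3-dim cufftXt plan followed by one 2-dim plan, with
// AsStrided restoring the logical layout between iterations.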
+class FftFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,
+                           const Optional<std::string>& norm) const {
+    std::string norm_str = norm.value_or("backward");
+    std::vector<int64_t> fft_dim{dim};
+
+    bool forward = true;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    std::vector<int64_t> len{n};
+    return input->dtype()->is_complex()
+               ? functional::FftC2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                                    /*forward=*/forward, /*normalized=*/true)
+               : functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                                    /*onesided=*/false, /*forward=*/forward, /*normalized=*/true);
+  }
+};
+
+class IFftFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,
+                           const Optional<std::string>& norm) const {
+    auto norm_str = norm.value_or("backward");
+    std::vector<int64_t> fft_dim{dim};
+
+    bool forward = false;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+    std::vector<int64_t> len{n};
+    return input->dtype()->is_complex()
+               ? functional::FftC2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                                    /*forward=*/forward, /*normalized=*/true)
+               : functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                                    /*onesided=*/false, /*forward=*/forward, /*normalized=*/true);
+  }
+};
+
+class Fft2Functor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,
+                           const Optional<std::string>& norm) const {
+    return functional::FftN(input, s, dim, norm);
+  }
+};
+
+class IFft2Functor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,
+                           const Optional<std::string>& norm) const {
+    return functional::IFftN(input, s, dim, norm);
+  }
+};
+
+class FftNFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s,
+                           const Optional<std::vector<int64_t>>& dim,
+                           const Optional<std::string>& norm) const {
+    std::string norm_str = norm.value_or("backward");
+    bool forward = true;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    if (!(input->dtype()->is_complex())) {
+      // cast to complex
+      TensorProcessor tensor_processor;
+      Symbol<DType> complex_dtype;
+      if (input->dtype() == DType::Double()) {
+        complex_dtype = DType::Complex128();
+      } else {
+        complex_dtype = DType::Complex64();
+      }
+      JUST(tensor_processor.AddInputs({input}, {complex_dtype}).Apply());
+      TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
+      return functional::FftC2C(JUST(oneflow::VectorAt(input_tuple, 0)), s, dim,
+                                static_cast<int32_t>(norm_mode), /*forward=*/forward,
+                                /*normalized=*/true);
+    } else {
+      return functional::FftC2C(input, s, dim, static_cast<int32_t>(norm_mode),
+                                /*forward=*/forward, /*normalized=*/true);
+    }
+  }
+};
+
+class IFftNFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s,
+                           const Optional<std::vector<int64_t>>& dim,
+                           const Optional<std::string>& norm) const {
+    std::string norm_str = norm.value_or("backward");
+    bool forward = false;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    if (!(input->dtype()->is_complex())) {
+      // cast to complex
+      TensorProcessor tensor_processor;
+      Symbol<DType> complex_dtype;
+      if (input->dtype() == DType::Double()) {
+        complex_dtype = DType::Complex128();
+      } else {
+        complex_dtype = DType::Complex64();
+      }
+      JUST(tensor_processor.AddInputs({input}, {complex_dtype}).Apply());
+      TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
+      return functional::FftC2C(JUST(oneflow::VectorAt(input_tuple, 0)), s, dim,
+                                static_cast<int32_t>(norm_mode), /*forward=*/forward,
+                                /*normalized=*/true);
+    } else {
+      return functional::FftC2C(input, s, dim, static_cast<int32_t>(norm_mode),
+                                /*forward=*/forward, /*normalized=*/true);
+    }
+  }
+};
+
+class RFftFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,
+                           const Optional<std::string>& norm) const {
+    CHECK_OR_RETURN(!(input->dtype()->is_complex()))
+        << "RuntimeError: expected the dtype of the input tensor to be real, but got "
+        << input->dtype()->name();
+
+    std::string norm_str = norm.value_or("backward");
+    std::vector<int64_t> fft_dim{dim};
+    bool forward = true;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    std::vector<int64_t> len{n};
+    return functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                              /*onesided=*/true, /*forward=*/forward, /*normalized=*/true);
+  }
+};
+
+class IRFftFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,
+                           const Optional<std::string>& norm) const {
+    std::string norm_str = norm.value_or("backward");
+    std::vector<int64_t> fft_dim{dim};
+
+    bool forward = false;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    std::vector<int64_t> len{n};
+    return functional::FftC2R(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                              /*forward=*/forward, /*normalized=*/true);
+  }
+};
+
+class RFft2Functor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,
+                           const Optional<std::string>& norm) const {
+    return functional::RFftN(input, s, dim, norm);
+  }
+};
+
+class IRFft2Functor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,
+                           const Optional<std::string>& norm) const {
+    return functional::IRFftN(input, s, dim, norm);
+  }
+};
+
+class RFftNFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s,
+                           const Optional<std::vector<int64_t>>& dim,
+                           const Optional<std::string>& norm) const {
+    CHECK_OR_RETURN(!(input->dtype()->is_complex()))
+        << "RuntimeError: expected the dtype of the input tensor to be real, but got "
+        << input->dtype()->name();
+
+    std::string norm_str = norm.value_or("backward");
+    bool forward = true;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    return functional::FftR2C(input, s, dim, static_cast<int32_t>(norm_mode), /*onesided=*/true,
+                              /*forward=*/forward, /*normalized=*/true);
+  }
+};
+
+class IRFftNFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s,
+                           const Optional<std::vector<int64_t>>& dim,
+                           const Optional<std::string>& norm) const {
+    CHECK_OR_RETURN(input->dtype()->is_complex())
+        << "RuntimeError: expected the dtype of the input tensor to be complex, but got "
+        << input->dtype()->name();
+
+    std::string norm_str = norm.value_or("backward");
+    bool forward = false;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    return functional::FftC2R(input, s, dim, static_cast<int32_t>(norm_mode), /*forward=*/false,
+                              /*normalized=*/true);
+  }
+};
+
+class HFftFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,
+                           const Optional<std::string>& norm) const {
+    CHECK_OR_RETURN(input->dtype()->is_complex())
+        << "RuntimeError: expected the dtype of the input tensor to be complex, but got "
+        << input->dtype()->name();
+
+    std::string norm_str = norm.value_or("backward");
+    std::vector<int64_t> fft_dim{dim};
+
+    bool forward = true;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    std::vector<int64_t> len{n};
+    return functional::FftC2R(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                              /*forward=*/forward, /*normalized=*/true);
+  }
+};
+
+class IHFftFunctor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input, int64_t n, int64_t dim,
+                           const Optional<std::string>& norm) const {
+    CHECK_OR_RETURN(!(input->dtype()->is_complex()))
+        << "RuntimeError: expected the dtype of the input tensor to be real, but got "
+        << input->dtype()->name();
+
+    std::string norm_str = norm.value_or("backward");
+    std::vector<int64_t> fft_dim{dim};
+
+    bool forward = false;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    std::vector<int64_t> len{n};
+    return functional::FftR2C(input, len, fft_dim, static_cast<int32_t>(norm_mode),
+                              /*onesided=*/true,
+                              /*forward=*/forward, /*normalized=*/true);
+  }
+};
+
+class HFft2Functor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,
+                           const Optional<std::string>& norm) const {
+    return functional::HFftN(input, s, dim, norm);
+  }
+};
+
+class IHFft2Functor {
+ public:
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s, const std::vector<int64_t>& dim,
+                           const Optional<std::string>& norm) const {
+    return functional::IHFftN(input, s, dim, norm);
+  }
+};
+
+class HFftNFunctor : FftBaseFunctor {
+ public:
+  HFftNFunctor() : FftBaseFunctor() {}
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s,
+                           const Optional<std::vector<int64_t>>& dim,
+                           const Optional<std::string>& norm) const {
+    CHECK_OR_RETURN(input->dtype()->is_complex())
+        << "RuntimeError: expected the dtype of the input tensor to be complex, but got "
+        << input->dtype()->name();
+
+    std::string norm_str = norm.value_or("backward");
+
+    bool forward = true;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    if (s.has_value() && dim.has_value()) {
+      CHECK_OR_RETURN((*JUST(s)).size() == (*JUST(dim)).size())
+          << "RuntimeError: When dim and shape were both given, they must have the same length";
+    }
+
+    std::vector<int64_t> wrapped_dims(input->ndim(), 0);
+    std::vector<int64_t> fft_len(input->ndim(), 0);
+    int64_t last_dim_size = 0;
+    JUST(parse_c2r_input_n_and_dims(input, s, dim, last_dim_size, fft_len, wrapped_dims));
+
+    auto resized_tensor =
+        s.has_value() == true ? JUST(resize_fft_input(input, wrapped_dims, fft_len)) : input;
+
+    std::shared_ptr<Tensor> temp;
+    if (wrapped_dims.size() > 1) {
+      // ND fast Fourier transform
+      std::vector<int64_t> c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);
+      temp = JUST(functional::FftC2C(resized_tensor, NullOpt, c2c_dims,
+                                     static_cast<int32_t>(norm_mode), /*forward=*/forward,
+                                     /*normalized=*/true));
+    } else {
+      temp = resized_tensor;
+    }
+
+    // Finally, do the 1D fft_c2r
+    int64_t last_dim = wrapped_dims.back();
+    std::vector<int64_t> last_dim_vec = {last_dim};
+    std::vector<int64_t> last_dim_size_vec = {last_dim_size};
+    return functional::FftC2R(temp, last_dim_size_vec, last_dim_vec,
+                              static_cast<int32_t>(norm_mode), /*forward=*/forward,
+                              /*normalized=*/true);
+  }
+};
+
+class IHFftNFunctor : FftBaseFunctor {
+ public:
+  IHFftNFunctor() : FftBaseFunctor() {}
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
+                           const Optional<std::vector<int64_t>>& s,
+                           const Optional<std::vector<int64_t>>& dim,
+                           const Optional<std::string>& norm) const {
+    CHECK_OR_RETURN(!(input->dtype()->is_complex()))
+        << "RuntimeError: expected the dtype of the input tensor to be real, but got "
+        << input->dtype()->name();
+
+    std::string norm_str = norm.value_or("backward");
+    bool forward = false;
+    fft_norm_mode norm_mode = fft_norm_mode::none;
+    norm_mode = fft_norm_from_string(norm_str, forward);
+
+    auto input_tensor = JUST(promote_tensor_fft(input, false));
+
+    if (s.has_value() && dim.has_value()) {
+      CHECK_OR_RETURN((*JUST(s)).size() == (*JUST(dim)).size())
+          << "RuntimeError: When dim and shape were both given, they must have the same length";
+    }
+
+    std::vector<int64_t> fft_len(input_tensor->ndim(), 0);
+    std::vector<int64_t> wrapped_dims(input_tensor->ndim(), 0);
+    JUST(parse_input_n_and_dims(input_tensor, s, dim, fft_len, wrapped_dims));
+    auto resized_tensor = s.has_value() == true
+                              ? JUST(resize_fft_input(input_tensor, wrapped_dims, fft_len))
+                              : input_tensor;
+
+    // First do the 1D R2C transform on the last dim
+    const auto last_dim_len = fft_len.back();
+    const auto last_dim = wrapped_dims.back();
+    std::vector<int64_t> r2c_fft_len = {last_dim_len};
+    std::vector<int64_t> r2c_fft_dim = {last_dim};
+    auto temp = JUST(functional::FftR2C(resized_tensor, r2c_fft_len, r2c_fft_dim,
+                                        static_cast<int32_t>(norm_mode), /*onesided=*/true,
+                                        /*forward=*/forward, /*normalized=*/true));
+    // NOTE: `temp` is already conjugated in `functional::FftR2C`
+    if (wrapped_dims.size() == 1) { return temp; }
+
+    // Finally, do C2C transforms on the remaining dims
+    std::vector<int64_t> c2c_dims(wrapped_dims.begin(), wrapped_dims.end() - 1);
+    return functional::FftC2C(temp, NullOpt, c2c_dims, static_cast<int32_t>(norm_mode),
+                              /*forward=*/forward, /*normalized=*/true);
+  }
+};
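// Editor's note (illustration): the Hermitian wrappers above reuse the two real-transform
// primitives rather than a dedicated kernel:
//   hfft(x, n)  -> fft_c2r(x, last_dim_size=n) with forward=true, which conjugates the input
//                  inside FftC2RFunctor before the inverse real transform
//   ihfft(x, n) -> fft_r2c(x, onesided=true) with forward=false, which conjugates the output
//                  inside FftR2CFunctor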
 class StftFunctor {
  public:
   StftFunctor() {
@@ -3911,6 +4899,7 @@ class StftFunctor {
                            const Optional<one::Tensor>& window, const bool center,
                            const std::string& mode, const bool normalized, const bool onesided,
                            const bool return_complex) const {
+    CHECK_OR_RETURN(n_fft > 0) << Error::RuntimeError() << "Expected 0 < n_fft, but got " << n_fft;
     int64_t new_hop_length = hop_length.has_value() == true ? JUST(hop_length) : n_fft / 4;
     int64_t new_win_length = win_length.has_value() == true ? JUST(win_length) : n_fft;
     auto input_tensor = input;
@@ -4012,6 +5001,7 @@ class StftFunctor {
  private:
   std::shared_ptr<OpExpr> op_;
 };
+
 class FusedWeightedSumFunctor {
  public:
   FusedWeightedSumFunctor() {
@@ -4689,7 +5679,30 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::DetFunctor>("Det");
   m.add_functor<impl::GeluWithApproximateFunctor>("GeluWithApproximate");
   m.add_functor<impl::TruncFunctor>("Trunc");
+  m.add_functor<impl::StftFunctor>("Stft");
+  m.add_functor<impl::FftC2CFunctor>("FftC2C");
+  m.add_functor<impl::FftR2CFunctor>("FftR2C");
+  m.add_functor<impl::FftC2RFunctor>("FftC2R");
+  m.add_functor<impl::FftFunctor>("Fft");
+  m.add_functor<impl::IFftFunctor>("IFft");
+  m.add_functor<impl::Fft2Functor>("Fft2");
+  m.add_functor<impl::IFft2Functor>("IFft2");
+  m.add_functor<impl::FftNFunctor>("FftN");
+  m.add_functor<impl::IFftNFunctor>("IFftN");
+  m.add_functor<impl::RFftFunctor>("RFft");
+  m.add_functor<impl::IRFftFunctor>("IRFft");
+  m.add_functor<impl::RFft2Functor>("RFft2");
+  m.add_functor<impl::IRFft2Functor>("IRFft2");
+  m.add_functor<impl::RFftNFunctor>("RFftN");
+  m.add_functor<impl::IRFftNFunctor>("IRFftN");
+  m.add_functor<impl::HFftFunctor>("HFft");
+  m.add_functor<impl::IHFftFunctor>("IHFft");
+  m.add_functor<impl::HFft2Functor>("HFft2");
+  m.add_functor<impl::IHFft2Functor>("IHFft2");
+  m.add_functor<impl::HFftNFunctor>("HFftN");
+  m.add_functor<impl::IHFftNFunctor>("IHFftN");
+  m.add_functor<impl::FusedWeightedSumFunctor>("FusedWeightedSum");
   m.add_functor<impl::FusedCenterFunctor>("FusedCenter");
   m.add_functor<impl::FusedCenterGradFunctor>("FusedCenterGrad");
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index a5ec94f5289..c78138dfb7c 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -5087,7 +5087,70 @@ def OneFlow_ErfInvOp : OneFlow_BaseOp<"erfinv", [NoMemoryEffect, DeclareOpInterf
   let has_data_type_infer_fn = 1;
 }
 
-def OneFlow_StftOp : OneFlow_BaseOp<"stft", [SupportNonContiguous,NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+def OneFlow_FftC2COp : OneFlow_BaseOp<"fft_c2c", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$input
+  );
+  let output = (outs
+    OneFlow_Tensor:$out
+  );
+
+  let attrs = (ins
+    SI64ArrayAttr:$dims,
+    BoolAttr:$forward,
+    SI32Attr:$norm_mode,
+    F64Attr:$norm_fct
+  );
+
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FftR2COp : OneFlow_BaseOp<"fft_r2c", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$input
+  );
+  let output = (outs
+    OneFlow_Tensor:$out
+  );
+
+  let attrs = (ins
+    SI64ArrayAttr:$dims,
+    SI32Attr:$norm_mode,
+    F64Attr:$norm_fct,
+    BoolAttr:$onesided
+  );
+
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FftC2ROp : OneFlow_BaseOp<"fft_c2r", [SupportNonContiguous, NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$input
+  );
+  let output = (outs
+    OneFlow_Tensor:$out
+  );
+
+  let attrs = (ins
+    SI64ArrayAttr:$dims,
+    SI32Attr:$norm_mode,
+    F64Attr:$norm_fct,
+    SI64Attr:$last_dim_size
+  );
+
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_StftOp : OneFlow_BaseOp<"stft", [SupportNonContiguous, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
   let input = (ins
     OneFlow_Tensor:$input,
     Optional<OneFlow_Tensor>:$window
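Editor's note: the attr lists above are the schema consumed by the functors in math_functor.cpp; for example, FftC2COp's tuple maps one-to-one onto the attrs set on the C++ side:

    // TableGen attrs                      // functor side (FftC2CFunctor)
    // SI64ArrayAttr:$dims            <->  THREAD_CACHED_MUTABLE_ATTR_MAP("dims", ...)
    // BoolAttr:$forward, SI32Attr:$norm_mode, F64Attr:$norm_fct likewise

fft_c2r swaps `forward` for `last_dim_size` because an inverse real transform must be told the original signal length: n/2 + 1 complex bins can come from either n or n + 1 real samples.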
#include #include +#include +#include +#include +#include +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/common/shape_vec.h" +#include "oneflow/core/common/throw.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" +#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/kernel/kernel.h" namespace oneflow { @@ -30,64 +39,228 @@ namespace { constexpr int max_rank = 3; +enum class CUFFT_EXCUTETYPE { R2C, C2C, C2R }; + +struct CuFFTDataTypeDesc { + cudaDataType inputtype; + cudaDataType outputtype; + cudaDataType executiontype; +}; + +} // namespace + +class CuFFTHandle { + cufftHandle handle; + + public: + CuFFTHandle() { OF_CUFFT_CHECK(cufftCreate(&handle)); } + + cufftHandle& get() { return handle; } + const cufftHandle& get() const { return handle; } + + ~CuFFTHandle() { cufftDestroy(handle); } +}; + +// NOTE: The implementation of `CuFFTDataLayout`, `cufft_simple_embed` and `as_cufft_embed` are +// mostly taken from pytorch. For more details pls refer to `CuFFTPlanCache.h` in PyTorch. +typedef long long cufft_size_type; +typedef small_vector cufft_dim_vector; +struct CuFFTDataLayout { + small_vector embed; + cufft_size_type stride, dist; + bool must_clone, simple; +}; + +// Returns a cufft embedding for a contiguous signal of the given size. +// e.g. if the input is cloned, this will be the resulting data layout +inline CuFFTDataLayout cufft_simple_embed(const cufft_dim_vector& sizes, bool onesided) { + CuFFTDataLayout layout; + layout.simple = true; + layout.must_clone = false; + layout.embed.assign(sizes.cbegin() + 1, sizes.cend()); + if (onesided) { layout.embed.back() = sizes.back() / 2 + 1; } + layout.stride = 1; + layout.dist = 1; + for (const auto& len : layout.embed) { layout.dist *= len; } + return layout; +} + +// Convert strides to a CuFFT embedded representation. +// If strides cannot be embedded, returns a simple layout and sets must_clone flag +inline CuFFTDataLayout as_cufft_embed(const cufft_dim_vector& strides, + const cufft_dim_vector& sizes, bool onesided) { + const auto signal_ndim = strides.size() - 1; + CuFFTDataLayout layout; + auto last_stride = strides[signal_ndim]; + layout.must_clone = (last_stride <= 0); + + const auto last_dim_size = onesided ? sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim]; + + const auto signal_numel = std::accumulate(sizes.begin() + 1, sizes.end() - 1, (cufft_size_type)1, + std::multiplies()) + * last_dim_size; + + // Zero stides are not allowed, even if the batch size is one. 
+ // If that happens just set a dummy case + if (sizes[0] == 1) { + layout.dist = signal_numel; + } else if (strides[0] == 0) { + layout.must_clone = true; + } else { + layout.dist = strides[0]; + } + + // Calculate the embedding shape, or set must_clone if the strides cannot be embedded + layout.embed.resize(signal_ndim); + for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) { + auto stride = strides[i]; + if (sizes[i] == 1) { + layout.embed[i] = 1; + } else if (stride > 0 && stride % last_stride == 0) { + layout.embed[i] = stride / last_stride; + last_stride = stride; + } else { + layout.must_clone = true; + } + } + // must_clone == false + if (layout.must_clone) { + // If the input needs to be cloned, assume it will be contiguous + layout = cufft_simple_embed(sizes, onesided); + layout.must_clone = true; + } else { + layout.embed[0] = sizes[1]; + layout.stride = strides[signal_ndim]; + + // Determine if layout represents a simple embedding (contiguous data) + layout.simple = [&] { + FOR_RANGE(int, i, 1, signal_ndim - 1) { + if (layout.embed[i] != sizes[i + 1]) { return false; } + } + return (layout.stride == 1 && layout.dist == signal_numel + && layout.embed.back() == last_dim_size); + }(); + } + return layout; } -struct CuFFtParams { - int32_t ndim; - int32_t output_shape[max_rank + 1]; - int32_t input_shape[max_rank + 1]; - int32_t input_strides[max_rank + 1]; - int32_t output_strides[max_rank + 1]; - int32_t* rank; - int32_t batch; - CuFFtParams(int32_t dims, int32_t* r, const Stride& in_strides, // NOLINT - const Stride& out_strides, const Shape& in_shape, const Shape& out_shape, int32_t b) - : ndim(dims), rank(r), batch(b) { - std::copy(in_strides.begin(), in_strides.end(), input_strides); - std::copy(out_strides.begin(), out_strides.end(), output_strides); - std::copy(in_shape.begin(), in_shape.end(), input_shape); - std::copy(out_shape.begin(), out_shape.end(), output_shape); +struct CuFFTParams { + int64_t ndim; + cufft_dim_vector input_shape; + cufft_dim_vector input_strides; + cufft_dim_vector output_shape; + cufft_dim_vector output_strides; + cufft_dim_vector data_shape; + CUFFT_EXCUTETYPE excute_type; + DataType real_data_type; + + CuFFTParams() = default; + CuFFTParams(const Shape& in_shape, const Shape& out_shape, const Stride& in_strides, + const Stride& out_strides, int64_t dims, CUFFT_EXCUTETYPE type, DataType real) + : ndim(dims), excute_type(type), real_data_type(real) { + CHECK_OR_THROW(ndim >= 1 && ndim <= max_rank); + CHECK_OR_THROW(in_shape.size() == ndim + 1); + CHECK_OR_THROW(out_shape.size() == ndim + 1); + CHECK_OR_THROW(in_shape.size() == in_strides.size()); + CHECK_OR_THROW(out_shape.size() == out_strides.size()); + data_shape.resize(ndim + 1); + input_shape.resize(in_shape.size()); + input_strides.resize(in_strides.size()); + output_shape.resize(out_shape.size()); + output_strides.resize(out_strides.size()); + + std::copy(in_strides.begin(), in_strides.end(), input_strides.begin()); + std::copy(out_strides.begin(), out_strides.end(), output_strides.begin()); + std::copy(in_shape.begin(), in_shape.end(), input_shape.begin()); + std::copy(out_shape.begin(), out_shape.end(), output_shape.begin()); + + data_shape[0] = input_shape[0]; // batch size + FOR_RANGE(int64_t, i, 0, ndim) { + auto in_size = input_shape[i + 1]; + auto out_size = output_shape[i + 1]; + data_shape[i + 1] = std::max(in_size, out_size); + CHECK_OR_THROW(in_size == data_shape[i + 1] || in_size == (data_shape[i + 1] / 2) + 1); + CHECK_OR_THROW(out_size == data_shape[i + 1] || 
-template -class CuFFtConfig { +class CuFFTConfig { public: - CuFFtConfig(const CuFFtConfig&) = delete; - CuFFtConfig& operator=(CuFFtConfig const&) = delete; - ~CuFFtConfig() = default; - - explicit CuFFtConfig(CuFFtParams& params) { // NOLINT - infer_cufft_type_(); - cufftPlanMany(&plan_handle_, params.ndim, params.rank, params.input_shape, - params.input_strides[0], params.input_strides[1], params.output_shape, - params.output_strides[0], params.output_strides[1], exectype_, params.batch); - } + CuFFTConfig(const CuFFTConfig&) = delete; + CuFFTConfig& operator=(CuFFTConfig const&) = delete; + ~CuFFTConfig() = default; - void excute_plan(const T* in, C* out) { - switch (exectype_) { - case CUFFT_R2C: cufftExecR2C(plan_handle_, (cufftReal*)in, (cufftComplex*)out); break; + explicit CuFFTConfig(CuFFTParams& params) { // NOLINT - case CUFFT_D2Z: - cufftExecD2Z(plan_handle_, (cufftDoubleReal*)in, (cufftDoubleComplex*)out); - break; - default: break; + if (params.real_data_type == kBFloat16 || params.real_data_type == kFloat16) { + // cuFFT supports half-precision data types, but with some limitations: + // https://docs.nvidia.com/cuda/cufft/#half-precision-cufft-transforms + CHECK_OR_THROW(false) << "Unsupported data types kBFloat16 and kFloat16."; } + + CuFFTDataLayout input_layout = as_cufft_embed(params.input_strides, params.data_shape, + params.excute_type == CUFFT_EXCUTETYPE::C2R); + CuFFTDataLayout output_layout = as_cufft_embed(params.output_strides, params.data_shape, + params.excute_type == CUFFT_EXCUTETYPE::R2C); + + bool clone_input = input_layout.must_clone; // i.e. the input must be made contiguous + // because its layout cannot be embedded + const bool is_layout_simple = input_layout.simple && output_layout.simple; + + // disable cuFFT's default behavior of allocating the work area at plan-generation time + OF_CUFFT_CHECK(cufftSetAutoAllocation(plan_handle_.get(), 0)); + infer_cufft_type_(params.excute_type, params.real_data_type); + + // exclude input_shape[0], which is the batch dim + cufft_dim_vector fft_shape(params.data_shape.begin() + 1, params.data_shape.end()); + cufft_size_type batch = params.data_shape[0]; + if (is_layout_simple) { + OF_CUFFT_CHECK(cufftXtMakePlanMany(plan_handle_.get(), params.ndim, fft_shape.data(), + /*inembed=*/nullptr, /*istride=*/1, /*idist=*/1, + /*inputtype=*/data_type_desc_.inputtype, + /*onembed=*/nullptr, /*ostride=*/1, /*odist=*/1, + /*outputtype=*/data_type_desc_.outputtype, + /*batch=*/batch, /*workSize=*/&work_size_, + /*executiontype=*/data_type_desc_.executiontype)); + } else { + OF_CUFFT_CHECK(cufftXtMakePlanMany( + plan_handle_.get(), params.ndim, fft_shape.data(), + /*inembed=*/input_layout.embed.data(), /*istride=*/input_layout.stride, + /*idist=*/input_layout.dist, /*inputtype=*/data_type_desc_.inputtype, + /*onembed=*/output_layout.embed.data(), /*ostride=*/output_layout.stride, + /*odist=*/output_layout.dist, /*outputtype=*/data_type_desc_.outputtype, + /*batch=*/batch, /*workSize=*/&work_size_, + /*executiontype=*/data_type_desc_.executiontype)); + } + } + + size_t workspace_size() const { return work_size_; } + const cufftHandle& plan() const { return plan_handle_.get(); } + + void excute(void* input, void* output, bool forward) { + OF_CUFFT_CHECK( + cufftXtExec(plan_handle_.get(), input, output, forward ?
CUFFT_FORWARD : CUFFT_INVERSE)); } private: - // infer the FFT type (currently only R2C and D2Z are supported) - void infer_cufft_type_() { - bool isDouble = std::is_same::value; - if (isDouble) { - exectype_ = CUFFT_D2Z; + void infer_cufft_type_(CUFFT_EXCUTETYPE excute_type, DataType real_data_type) { + if (real_data_type == kFloat) { + data_type_desc_.executiontype = CUDA_C_32F; + data_type_desc_.inputtype = excute_type == CUFFT_EXCUTETYPE::R2C ? CUDA_R_32F : CUDA_C_32F; + data_type_desc_.outputtype = excute_type == CUFFT_EXCUTETYPE::C2R ? CUDA_R_32F : CUDA_C_32F; + } else if (real_data_type == kDouble) { + data_type_desc_.executiontype = CUDA_C_64F; + data_type_desc_.inputtype = excute_type == CUFFT_EXCUTETYPE::R2C ? CUDA_R_64F : CUDA_C_64F; + data_type_desc_.outputtype = excute_type == CUFFT_EXCUTETYPE::C2R ? CUDA_R_64F : CUDA_C_64F; + } else { - exectype_ = CUFFT_R2C; + CHECK_OR_THROW(false) << "cuFFT doesn't support type " << real_data_type; } } - cufftHandle plan_handle_; - cufftType exectype_; + CuFFTHandle plan_handle_; + CuFFTDataTypeDesc data_type_desc_; + size_t work_size_; }; } // namespace oneflow diff --git a/oneflow/user/kernels/fft_kernel_util.cpp b/oneflow/user/kernels/fft_kernel_util.cpp new file mode 100644 index 00000000000..99b1ebe5055 --- /dev/null +++ b/oneflow/user/kernels/fft_kernel_util.cpp @@ -0,0 +1,172 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ +#include "oneflow/user/kernels/fft_kernel_util.h" +#include +#include "pocketfftplan.h" +#include "oneflow/core/common/device_type.pb.h" +#include "oneflow/core/common/preprocessor.h" +#include "oneflow/core/framework/user_op_tensor.h" + +namespace oneflow { + +template +static void _conj_symmetry_cpu(T* data_out, const Shape& shape, const std::vector& strides, + const int64_t last_dim, int64_t elem_count) { + const oneflow::NdIndexStrideOffsetHelper helper(strides.data(), + shape.size()); + // NOTE: dims must be sorted + int64_t last_dim_size = shape[last_dim]; + int64_t last_dim_half = last_dim_size / 2; + + int64_t ndim = shape.size(); + std::vector indices(ndim); + for (int offset = 0; offset < elem_count; offset++) { + helper.OffsetToNdIndex(offset, indices.data(), ndim); + if (indices[last_dim] <= last_dim_half) { continue; } + + int64_t cur_last_dim_index = indices[last_dim]; + // get symmetric + indices[last_dim] = last_dim_size - cur_last_dim_index; + int64_t symmetric_offset = helper.NdIndexToOffset(indices.data(), ndim); + + // conj + data_out[offset] = std::conj(data_out[symmetric_offset]); + } +} + +template +struct FillConjSymmetryUtil { + static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape, + const Stride& strides, const int64_t last_dim, + int64_t elem_count) { + std::vector strides_vec(strides.begin(), strides.end()); + _conj_symmetry_cpu(/*data_out*/ data_out, /*shape*/ shape, /*strides*/ strides_vec, + /*last_dim*/ last_dim, /*elem_count*/ elem_count); + } +}; + +template +struct ComplexConvertUtil { + static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst, + size_t len, size_t n) { + size_t fact_len = 2 * len - 2; // input_shape.back() + for (int i = 0; i < n; i++) { + int index_x = i / fact_len; + int index_y = i % fact_len; + if (index_y == 0) { + dst[i] = in[index_x * len]; + } else if (index_y == len - 1) { + dst[i] = in[(index_x + 1) * len - 1]; + } else if (index_y < len - 1 && index_y > 0) { + dst[i] = in[index_x * len + index_y]; + } else { + auto index = (index_x + 2) * len - index_y - 2; + auto realvalue = in[index].real(); + dst[i].real(realvalue); + auto imagvalue = -in[index].imag(); + dst[i].imag(imagvalue); + } + } + } + static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out, + size_t n) { + for (int i = 0; i < n; i++) { + out[2 * i] = in[i].real(); + out[2 * i + 1] = in[i].imag(); + } + } +}; + +template +struct FftC2CKernelUtil { + static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& dims, FCT_TYPE norm_fct, + DataType real_type) { + PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, dims, + forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::C2C); + PocketFFtConfig config(params); + config.excute(data_in, data_out); + } +}; + +template +struct FftR2CKernelUtil { + static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& dims, IN norm_fct, DataType real_type) { + PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, dims, + forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::R2C); + PocketFFtConfig config(params); + config.excute(data_in, data_out); + } +}; + 
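The CPU specializations above all funnel into pocketfft through `PocketFFtParams`/`PocketFFtConfig` (see the `pocketfftplan.h` changes further down); one detail worth keeping in mind is that pocketfft expects strides in bytes, which is why those params multiply every stride by the element size. A minimal standalone sketch of the underlying call, assuming only `pocketfft_hdronly.h`:

```cpp
// Standalone sketch: onesided R2C transform of length 4 with pocketfft.
#include <complex>
#include <iostream>
#include <vector>
#include "pocketfft_hdronly.h"

int main() {
  const size_t n = 4;
  std::vector<double> in = {0.0, 1.0, 2.0, 3.0};
  std::vector<std::complex<double>> out(n / 2 + 1);  // onesided spectrum

  pocketfft::shape_t shape = {n};
  pocketfft::stride_t stride_in = {sizeof(double)};                 // bytes!
  pocketfft::stride_t stride_out = {sizeof(std::complex<double>)};  // bytes!
  pocketfft::shape_t axes = {0};

  // fct = 1.0 means no normalization ("backward" norm in oneflow.fft terms).
  pocketfft::r2c(shape, stride_in, stride_out, axes, /*forward=*/true, in.data(),
                 out.data(), 1.0);

  // Prints (6,0) (-2,2) (-2,0): the onesided half of fft([0, 1, 2, 3]).
  for (const auto& c : out) { std::cout << c << "\n"; }
  return 0;
}
```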
+template +struct FftC2RKernelUtil { + static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + int64_t last_dim_size, const std::vector& dims, OUT norm_fct, + DataType real_type) { + PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, dims, + /*is_forward=*/false, norm_fct /*1.f*/, FFT_EXCUTETYPE::C2R); + PocketFFtConfig config(params); + config.excute(data_in, data_out); + } +}; + +template +struct FftStftKernelUtil { + static void FftStftForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& axes, IN norm_fct, int64_t len, + int64_t dims, int64_t batch) { + PocketFFtParams params(input_shape, output_shape, input_stride, output_stride, axes, + forward, norm_fct /*1.f*/, FFT_EXCUTETYPE::R2C); + PocketFFtConfig config(params); + int64_t in_offset = len; + int64_t out_offset = len / 2 + 1; + for (int j = 0; j < dims; j++) { + for (int i = 0; i < batch; i++) { + const IN* in = data_in + j * batch * in_offset + i * in_offset; + OUT* out = data_out + j * batch * out_offset + i * out_offset; + config.excute(in, out); + } + } + } +}; +template struct FillConjSymmetryUtil>; +template struct FillConjSymmetryUtil>; + +template struct ComplexConvertUtil>; +template struct ComplexConvertUtil>; + +template struct FftC2CKernelUtil, float>; +template struct FftC2CKernelUtil, double>; + +template struct FftR2CKernelUtil>; +template struct FftR2CKernelUtil>; + +template struct FftC2RKernelUtil, float>; +template struct FftC2RKernelUtil, double>; + +template struct FftStftKernelUtil>; +template struct FftStftKernelUtil>; +} // namespace oneflow \ No newline at end of file diff --git a/oneflow/user/kernels/fft_kernel_util.cu b/oneflow/user/kernels/fft_kernel_util.cu new file mode 100644 index 00000000000..2fa47b02b68 --- /dev/null +++ b/oneflow/user/kernels/fft_kernel_util.cu @@ -0,0 +1,313 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include "oneflow/core/device/cuda_util.h" +#include "oneflow/core/framework/user_op_tensor.h" +#include "oneflow/user/kernels/to_contiguous_kernel.h" + +#if CUDA_VERSION >= 11000 +#include "oneflow/user/kernels/fft_kernel_util.h" +#include "cufft_plan_cache.h" + +namespace oneflow { + +namespace { +template +__global__ void fft_apply_normalization(FFTTYPE* dst, const double normalization_scale, size_t n, + bool IsNormalized) { + if (!IsNormalized) { return; } + CUDA_1D_KERNEL_LOOP(i, n) { + dst[i].x *= normalization_scale; + dst[i].y *= normalization_scale; + }; +} + +struct FillConjSymmetricParams { + int64_t last_dim; + int64_t elem_count; + int64_t ndim; + oneflow::NdIndexStrideOffsetHelper helper; + int64_t last_dim_size; + int64_t last_dim_half; + + FillConjSymmetricParams() = default; + FillConjSymmetricParams(const Shape& shape, const Stride& strides, int64_t last_dim_, + int64_t elemcnt) + : last_dim(last_dim_), + elem_count(elemcnt), + ndim(strides.size()), + helper(strides.data(), ndim) { + CHECK_OR_THROW(strides.size() == shape.size()); + last_dim_size = shape[last_dim]; + last_dim_half = last_dim_size / 2; + } +}; + +} // namespace + +template +__global__ void _conj_symmetry_cuda(T* data_out, FillConjSymmetricParams param) { + CUDA_1D_KERNEL_LOOP_T(int64_t, offset, param.elem_count) { + int64_t ndim = param.ndim; + int64_t indices[SHAPE_MAX_AXIS_SIZE]; + param.helper.OffsetToNdIndex(offset, indices, ndim); + if (indices[param.last_dim] <= param.last_dim_half) { continue; } + int64_t cur_last_dim_index = indices[param.last_dim]; + // get symmetric + indices[param.last_dim] = param.last_dim_size - cur_last_dim_index; + int64_t symmetric_offset = param.helper.NdIndexToOffset(indices, ndim); + + // conj + data_out[offset] = T{data_out[symmetric_offset].x, -data_out[symmetric_offset].y}; + } +} + +template +struct FillConjSymmetryUtil { + static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape, + const Stride& strides, const int64_t last_dim, + int64_t elem_count) { + FillConjSymmetricParams param(shape, strides, last_dim, elem_count); + _conj_symmetry_cuda<<As()->cuda_stream()>>>(data_out, param); + } +}; + +template +__global__ void _convert_to_double_sized(const IN* in, OUT* dst, size_t len, size_t n) { + size_t fact_len = 2 * len - 2; + CUDA_1D_KERNEL_LOOP(i, n) { + int index_x = i / fact_len; + int index_y = i % fact_len; + if (index_y == 0) { + dst[i] = in[index_x * len]; + } else if (index_y == len - 1) { + dst[i] = in[(index_x + 1) * len - 1]; + } else if (index_y < len - 1 && index_y > 0) { + dst[i] = in[index_x * len + index_y]; + } else { + auto index = (index_x + 2) * len - index_y - 2; + dst[i].x = in[index].x; + dst[i].y = -in[index].y; + } + } +} + +template +__global__ void _convert_complex_to_real(const IN* in, OUT* out, size_t n) { + CUDA_1D_KERNEL_LOOP(i, n) { + out[2 * i] = in[i].x; + out[2 * i + 1] = in[i].y; + }; +} + +template +struct ComplexConvertUtil { + static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst, + size_t len, size_t n) { + _convert_to_double_sized<<As()->cuda_stream()>>>(in, dst, len, n); + } + static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out, + size_t n) { + _convert_complex_to_real<<As()->cuda_stream()>>>(in, out, n); + } +}; + +template +class StftGpuKernel final : public user_op::OpKernel { + public: + StftGpuKernel() = default; + ~StftGpuKernel() = default; + + private: + using user_op::OpKernel::Compute; + 
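The `Compute` body that follows uses cuFFT's manual work-area protocol: `CuFFTConfig` disables auto-allocation when it builds the plan, so the caller queries `workspace_size()`, allocates that many bytes, attaches them with `cufftSetWorkArea`, and frees them after execution. A minimal standalone sketch of the same protocol; `run_c2c_1d` is a hypothetical helper and error checks are elided:

```cpp
// Standalone sketch of the manual cuFFT work-area pattern used in this file.
#include <cufft.h>
#include <cufftXt.h>
#include <cuda_runtime.h>

void run_c2c_1d(cufftComplex* buf, long long n, cudaStream_t stream) {
  cufftHandle plan;
  cufftCreate(&plan);
  cufftSetAutoAllocation(plan, 0);  // we own the work area

  size_t work_size = 0;
  long long dims[1] = {n};
  cufftXtMakePlanMany(plan, 1, dims, /*inembed=*/nullptr, 1, 1, CUDA_C_32F,
                      /*onembed=*/nullptr, 1, 1, CUDA_C_32F, /*batch=*/1,
                      &work_size, CUDA_C_32F);

  void* work_area = nullptr;
  cudaMalloc(&work_area, work_size);
  cufftSetWorkArea(plan, work_area);
  cufftSetStream(plan, stream);

  cufftXtExec(plan, buf, buf, CUFFT_FORWARD);  // in-place C2C

  cudaFree(work_area);
  cufftDestroy(plan);
}
```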
void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + const bool normalized = ctx->Attr("normalized"); + const bool onesided = ctx->Attr("onesided"); + const bool return_complex = ctx->Attr("return_complex"); + + const ShapeView& input_shape = input->shape_view(); + const ShapeView& output_shape = output->shape_view(); + + const Stride& input_stride = input->stride(); + const int out_elem_cnt = + return_complex ? output->shape_view().elem_cnt() : output->shape_view().elem_cnt() / 2; + + const dtype_in* data_in = input->dptr(); + dtype_in* data_out = output->mut_dptr(); + dtype_out* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()); + + int64_t ndim = 1; + int64_t batch = static_cast(input_shape.At(1)); + int64_t fft_size = static_cast(input_shape.At(2)); + int64_t rank[1] = {fft_size}; + const Stride& in_stride = {input_stride.at(1), input_stride.at(2)}; + const Shape& in_shape = {batch, fft_size}; + const Shape& out_shape = {batch, fft_size / 2 + 1}; + Stride out_stride = Stride(out_shape); + CuFFTParams params(in_shape, out_shape, in_stride, out_stride, ndim, CUFFT_EXCUTETYPE::R2C, + input->data_type()); + CuFFTConfig config(params); + auto& plan = config.plan(); + OF_CUFFT_CHECK(cufftSetStream(plan, ctx->stream()->As()->cuda_stream())); + void* workspace{}; + OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); + OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); + + int64_t in_offset = input_stride.at(0); + int64_t out_offset = + std::accumulate(out_shape.begin(), out_shape.end(), 0, std::multiplies()); + int64_t signal_groups_count = static_cast(input_shape.At(0)); + for (int64_t i = 0; i < signal_groups_count; i++) { + config.excute((void*)(data_in + i * in_offset), (void*)(out_tmp_buffer + i * out_offset), + /*forward=*/true); + } + OF_CUDA_CHECK(cudaFree(workspace)); + + if (!onesided) { + size_t last_dim_length = fft_size / 2 + 1; + dtype_out* doublesided_tmp_buffer = + reinterpret_cast(tmp_buffer->mut_dptr()) + out_elem_cnt; + ComplexConvertUtil::ConvertToDoubleSized( + ctx->stream(), out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, out_elem_cnt); + out_tmp_buffer = doublesided_tmp_buffer; + } + + const double normalization_scale = + _fft_normalization_scale(input_shape.back(), normalized); + fft_apply_normalization<<stream()->As()->cuda_stream()>>>( + out_tmp_buffer, normalization_scale, out_elem_cnt, normalized); + + if (!return_complex) { + ComplexConvertUtil::ConvertComplexToReal( + ctx->stream(), out_tmp_buffer, data_out, out_elem_cnt); + } else { + // TODO(yzm):support return_complex after oneflow supports complex numbers + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_STFT_GPU_KERNEL(intype, outtype) \ + REGISTER_USER_KERNEL("stft") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("input", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const Shape& output_shape = ctx->InputShape("output", 0); \ + const bool return_complex = ctx->Attr("return_complex"); \ + const bool onesided = ctx->Attr("onesided"); \ + int64_t output_elem_cnt = \ + return_complex ? 
output_shape.elem_cnt() : output_shape.elem_cnt() / 2; \ + const int64_t output_bytes = GetCudaAlignedSize(output_elem_cnt * sizeof(outtype)); \ + return onesided ? output_bytes : 2 * output_bytes; \ + }); + +REGISTER_STFT_GPU_KERNEL(float, cufftComplex) +REGISTER_STFT_GPU_KERNEL(double, cufftDoubleComplex) + +template +struct FftC2CKernelUtil { + static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& dims, FCT_TYPE normalization, + DataType real_type) { + // NOTE: before calling `FftC2CKernelUtil`, the input must already have been + // collapsed into (batch, signal_dims...) form + CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(), + CUFFT_EXCUTETYPE::C2C, real_type); + CuFFTConfig config(params); + auto& plan = config.plan(); + OF_CUFFT_CHECK(cufftSetStream(plan, stream->As()->cuda_stream())); + void* workspace{}; + OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); + OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); + + config.excute((void*)data_in, (void*)data_out, forward); + OF_CUDA_CHECK(cudaFree(workspace)); + } +}; + +template +struct FftR2CKernelUtil { + static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& dims, IN normalization, + DataType real_type) { + // NOTE: before calling `FftR2CKernelUtil`, the input must already have been + // collapsed into (batch, signal_dims...) form + CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(), + CUFFT_EXCUTETYPE::R2C, real_type); + CuFFTConfig config(params); + auto& plan = config.plan(); + OF_CUFFT_CHECK(cufftSetStream(plan, stream->As()->cuda_stream())); + void* workspace{}; + OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); + OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); + + config.excute((void*)data_in, (void*)data_out, forward); + OF_CUDA_CHECK(cudaFree(workspace)); + } +}; + +template +struct FftC2RKernelUtil { + static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + int64_t last_dim_size, const std::vector& dims, + OUT normalization, DataType real_type) { + // NOTE: before calling `FftC2RKernelUtil`, the input must already have been + // collapsed into (batch, signal_dims...) form + CuFFTParams params(input_shape, output_shape, input_stride, output_stride, dims.size(), + CUFFT_EXCUTETYPE::C2R, real_type); + CuFFTConfig config(params); + auto& plan = config.plan(); + OF_CUFFT_CHECK(cufftSetStream(plan, stream->As()->cuda_stream())); + void* workspace{}; + OF_CUDA_CHECK(cudaMalloc(&workspace, config.workspace_size())); + OF_CUFFT_CHECK(cufftSetWorkArea(plan, workspace)); + + config.excute((void*)data_in, (void*)data_out, forward); + OF_CUDA_CHECK(cudaFree(workspace)); + } +}; + +template struct FillConjSymmetryUtil; +template struct FillConjSymmetryUtil; + +template struct ComplexConvertUtil; +template struct ComplexConvertUtil; + +template struct FftC2CKernelUtil; +template struct FftC2CKernelUtil; + +template struct FftR2CKernelUtil; +template struct FftR2CKernelUtil; + +template struct FftC2RKernelUtil; +template struct FftC2RKernelUtil; +} // namespace oneflow + +#endif // CUDA_VERSION >= 11000 diff --git a/oneflow/user/kernels/fft_kernel_util.h
b/oneflow/user/kernels/fft_kernel_util.h new file mode 100644 index 00000000000..6ae2783613c --- /dev/null +++ b/oneflow/user/kernels/fft_kernel_util.h @@ -0,0 +1,87 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_ +#define ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_ + +#include +#include +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/common/shape_view.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/common/nd_index_offset_helper.h" + +namespace oneflow { + +template +inline T _fft_normalization_scale(const int32_t frame_length, bool normalized) { + if (!normalized) { return static_cast(1.0); } + return static_cast(1.0 / std::sqrt(frame_length)); +} + +template +struct FillConjSymmetryUtil { + static void FillConjSymmetryForward(ep::Stream* stream, T* data_out, const Shape& shape, + const Stride& strides, const int64_t last_dim, + int64_t elem_count); +}; + +template +struct ComplexConvertUtil { + static void ConvertToDoubleSized(ep::Stream* stream, const complex_type* in, complex_type* dst, + size_t len, size_t n); + static void ConvertComplexToReal(ep::Stream* stream, const complex_type* in, real_type* out, + size_t n); +}; + +template +struct FftC2CKernelUtil { + static void FftC2CForward(ep::Stream* stream, const T* data_in, T* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& dims, FCT_TYPE norm_fct, + DataType real_type); +}; + +template +struct FftR2CKernelUtil { + static void FftR2CForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& dims, IN norm_fct, DataType real_type); +}; + +template +struct FftC2RKernelUtil { + static void FftC2RForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + int64_t last_dim_size, const std::vector& dims, OUT norm_fct, + DataType real_type); +}; + +template +struct FftStftKernelUtil { + static void FftStftForward(ep::Stream* stream, const IN* data_in, OUT* data_out, + const Shape& input_shape, const Shape& output_shape, + const Stride& input_stride, const Stride& output_stride, bool forward, + const std::vector& axes, IN norm_fct, int64_t len, + int64_t dims, int64_t batch); +}; + +} // namespace oneflow +#endif // ONEFLOW_USER_KERNELS_FFT_KERNEL_UTIL_H_ \ No newline at end of file diff --git a/oneflow/user/kernels/fft_kernels.cpp b/oneflow/user/kernels/fft_kernels.cpp new file mode 100644 index 00000000000..f38dfe0ae9b --- /dev/null +++ b/oneflow/user/kernels/fft_kernels.cpp @@ -0,0 +1,245 @@ +/* +Copyright 2020 The OneFlow 
Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include +#include "pocketfftplan.h" +#include "oneflow/core/common/stride.h" +#include "oneflow/user/kernels/fft_kernel_util.h" + +using namespace pocketfft; +namespace oneflow { + +template +class FftC2CKernel final : public user_op::OpKernel { + public: + FftC2CKernel() = default; + ~FftC2CKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + bool forward = ctx->Attr("forward"); + double norm_fct = ctx->Attr("norm_fct"); + + const std::vector& dims = ctx->Attr>("dims"); + + const T* input_ptr = input->dptr(); + T* out_ptr = out->mut_dptr(); + + Shape input_shape(input->shape_view()); + Shape out_shape(out->shape_view()); + + if (input->data_type() == kComplex64) { + FftC2CKernelUtil::FftC2CForward( + ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(), + forward, dims, static_cast(norm_fct), DataType::kFloat); + } else if (input->data_type() == kComplex128) { + FftC2CKernelUtil::FftC2CForward( + ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(), + forward, dims, static_cast(norm_fct), DataType::kDouble); + } else { + CHECK_OR_THROW(false) << "expects kComplex64 or kComplex128, but got " << input->data_type(); + } + } +}; + +template +class FftR2CKernel final : public user_op::OpKernel { + public: + FftR2CKernel() = default; + ~FftR2CKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + bool onesided = ctx->Attr("onesided"); + double norm_fct = ctx->Attr("norm_fct"); + const std::vector& dims = ctx->Attr>("dims"); + const dtype_in* input_ptr = input->dptr(); + dtype_out* out_ptr = out->mut_dptr(); + + Shape input_shape(input->shape_view()); + Shape out_shape(out->shape_view()); + + if (input->data_type() == kFloat || input->data_type() == kDouble) { + FftR2CKernelUtil::FftR2CForward( + ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(), + /*forward=*/true, dims, norm_fct, /*real_type=*/input->data_type()); + } else { + CHECK_OR_THROW(false) << "expects kFloat or kDouble, but gets " << input->data_type(); + } + + if (!onesided) { + FillConjSymmetryUtil::FillConjSymmetryForward( + ctx->stream(), out_ptr, out_shape, out->stride(), dims.back(), out_shape.elem_cnt()); + } + } +}; + +template +class FftC2RKernel final : public user_op::OpKernel { + public: + FftC2RKernel() = default; + ~FftC2RKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; 
} + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + int64_t last_dim_size = ctx->Attr("last_dim_size"); + double norm_fct = ctx->Attr("norm_fct"); + const std::vector& dims = ctx->Attr>("dims"); + + const dtype_in* input_ptr = input->dptr(); + dtype_out* out_ptr = out->mut_dptr(); + + Shape input_shape(input->shape_view()); + Shape out_shape(out->shape_view()); + + out_shape[dims.back()] = last_dim_size; + + if (input->data_type() == kComplex64 || input->data_type() == kComplex128) { + FftC2RKernelUtil::FftC2RForward( + ctx->stream(), input_ptr, out_ptr, input_shape, out_shape, input->stride(), out->stride(), + /*forward=*/false, + /*last_dim_size=*/last_dim_size, dims, norm_fct, /*real_type=*/out->data_type()); + } else { + CHECK_OR_THROW(false) << "expects kComplex64 or kComplex128, but got " << input->data_type(); + } + } +}; + +template +class StftCpuKernel final : public user_op::OpKernel { + public: + StftCpuKernel() = default; + ~StftCpuKernel() = default; + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); + user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + const auto normalized = ctx->Attr("normalized"); + const auto return_complex = ctx->Attr("return_complex"); + const bool onesided = ctx->Attr("onesided"); + + const ShapeView input_shape = input->shape_view(); + const ShapeView output_shape = output->shape_view(); + const auto output_elem_cnt = output_shape.elem_cnt() / 2; + + int64_t dims = input_shape.At(0); + int64_t batch = input_shape.At(1); + int64_t len = input_shape.back(); + const dtype_in* data_in = input->dptr(); + dtype_in* data_out = output->mut_dptr(); + + dtype_out* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()); + Shape out_tmp_shape = Shape{len}; + Stride out_tmp_stride = Stride(out_tmp_shape); + std::vector axes(out_tmp_shape.size()); + std::iota(axes.begin(), axes.end(), 0); + auto norm_fct = _fft_normalization_scale(len, normalized); + FftStftKernelUtil::FftStftForward( + ctx->stream(), data_in, out_tmp_buffer, out_tmp_shape, out_tmp_shape, out_tmp_stride, + out_tmp_stride, true, /*axes=*/axes, /*norm_fct=*/norm_fct, + /*len=*/len, /*dims=*/dims, /*batch=*/batch); + + if (!onesided) { + dtype_out* doublesided_tmp_buffer = + reinterpret_cast(tmp_buffer->mut_dptr()) + output_elem_cnt; + size_t last_dim_length = len / 2 + 1; + size_t elem_count = output_elem_cnt; + ComplexConvertUtil::ConvertToDoubleSized( + ctx->stream(), out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, elem_count); + out_tmp_buffer = doublesided_tmp_buffer; + } + + if (!return_complex) { + ComplexConvertUtil::ConvertComplexToReal( + ctx->stream(), out_tmp_buffer, data_out, output_elem_cnt); + } + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; +
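`_fft_normalization_scale` above collapses the `normalized` flag to a single factor; more generally, the `norm_fct` attribute consumed by the C2C/R2C/C2R kernels encodes the three conventions documented in the `oneflow.fft` docstrings below ("backward", "ortho", "forward"). A small sketch of how such a factor is derived; `norm_fct` here is a local illustration, not the op attribute itself:

```cpp
// Standalone sketch of the three FFT normalization conventions.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>

double norm_fct(const std::string& norm, int64_t n, bool forward) {
  // "backward": forward pass unscaled, inverse scaled by 1/n.
  // "ortho":    both directions scaled by 1/sqrt(n).
  // "forward":  forward pass scaled by 1/n, inverse unscaled.
  if (norm == "ortho") { return 1.0 / std::sqrt(static_cast<double>(n)); }
  const bool scaled = (norm == "forward") ? forward : !forward;
  return scaled ? 1.0 / static_cast<double>(n) : 1.0;
}

int main() {
  std::cout << norm_fct("backward", 8, /*forward=*/true) << "\n";  // 1
  std::cout << norm_fct("ortho", 8, /*forward=*/true) << "\n";     // 0.353553
  std::cout << norm_fct("forward", 8, /*forward=*/true) << "\n";   // 0.125
  return 0;
}
```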
+#define REGISTER_STFT_CPU_KERNEL(dtype_in, dtype_out) \ + REGISTER_USER_KERNEL("stft") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == kCPU) \ + && (user_op::HobDataType("input", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const Shape& output_shape = ctx->InputShape("output", 0); \ + const bool return_complex = ctx->Attr("return_complex"); \ + const bool onesided = ctx->Attr("onesided"); \ + int64_t output_elem_cnt = \ + return_complex ? output_shape.elem_cnt() : output_shape.elem_cnt() / 2; \ + const int64_t output_bytes = (output_elem_cnt * sizeof(dtype_out)); \ + return onesided ? output_bytes : 2 * output_bytes; \ + }); + +REGISTER_STFT_CPU_KERNEL(double, std::complex) +REGISTER_STFT_CPU_KERNEL(float, std::complex) + +#define REGISTER_FFTC2C_KERNELS(device_type, dtype, fct_type) \ + REGISTER_USER_KERNEL("fft_c2c") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ + && (user_op::HobDataType("input", 0) == GetDataType::value) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)) + +REGISTER_FFTC2C_KERNELS(DeviceType::kCPU, std::complex, float); +REGISTER_FFTC2C_KERNELS(DeviceType::kCPU, std::complex, double); +#ifdef WITH_CUDA +REGISTER_FFTC2C_KERNELS(DeviceType::kCUDA, cuComplex, float); +REGISTER_FFTC2C_KERNELS(DeviceType::kCUDA, cuDoubleComplex, double); +#endif + +#define REGISTER_FFTR2C_KERNELS(device_type, dtype_in, dtype_out) \ + REGISTER_USER_KERNEL("fft_r2c") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ + && (user_op::HobDataType("input", 0) == GetDataType::value) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)) + +REGISTER_FFTR2C_KERNELS(DeviceType::kCPU, float, std::complex); +REGISTER_FFTR2C_KERNELS(DeviceType::kCPU, double, std::complex); +#ifdef WITH_CUDA +REGISTER_FFTR2C_KERNELS(DeviceType::kCUDA, float, cuComplex); +REGISTER_FFTR2C_KERNELS(DeviceType::kCUDA, double, cuDoubleComplex); +#endif + +#define REGISTER_FFTC2R_KERNELS(device_type, dtype_in, dtype_out) \ + REGISTER_USER_KERNEL("fft_c2r") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == device_type) \ + && (user_op::HobDataType("input", 0) == GetDataType::value) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)) + +REGISTER_FFTC2R_KERNELS(DeviceType::kCPU, std::complex, float); +REGISTER_FFTC2R_KERNELS(DeviceType::kCPU, std::complex, double); +#ifdef WITH_CUDA +REGISTER_FFTC2R_KERNELS(DeviceType::kCUDA, cuComplex, float); +REGISTER_FFTC2R_KERNELS(DeviceType::kCUDA, cuDoubleComplex, double); +#endif +} // namespace oneflow \ No newline at end of file diff --git a/oneflow/user/kernels/pocketfftplan.h index 89a5a5ecf10..cbb386c3118 100644 --- a/oneflow/user/kernels/pocketfftplan.h +++ b/oneflow/user/kernels/pocketfftplan.h @@ -14,87 +14,80 @@ See the License for the specific language governing permissions and limitations under the License.
*/ +#include "pocketfft_hdronly.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/core/ep/cuda/cuda_stream.h" -#include "pocketfft_hdronly.h" #include "oneflow/core/kernel/kernel.h" -using namespace pocketfft; +#include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { namespace { -enum class FFT_EXCUTETYPE { R2C, C2C }; +enum class FFT_EXCUTETYPE { R2C, C2C, C2R }; -template +template struct PocketFFtParams { - shape_t input_shape; - shape_t output_shape; - stride_t in_stridef; - stride_t out_stridef; - shape_t axes; bool IsForward; FFT_EXCUTETYPE excute_type; - IN fct; + dtype fct; + pocketfft::shape_t axes; + pocketfft::stride_t in_stridef; + pocketfft::stride_t out_stridef; + pocketfft::shape_t input_shape; + pocketfft::shape_t output_shape; PocketFFtParams() = default; - PocketFFtParams(const Shape& in_shape, const Shape& out_shape, const bool is_froward, const IN f, - FFT_EXCUTETYPE type) - : IsForward(is_froward), excute_type(type), fct(f) { + PocketFFtParams(const Shape& in_shape, const Shape& out_shape, const Stride& in_stride, + const Stride& out_stride, const std::vector& dims, const bool is_forward, + const dtype f, FFT_EXCUTETYPE type) + : IsForward(is_forward), + excute_type(type), + fct(f), + axes(dims.begin(), dims.end()), + in_stridef(in_stride.begin(), in_stride.end()), + out_stridef(out_stride.begin(), out_stride.end()) { input_shape.resize(in_shape.size()); output_shape.resize(out_shape.size()); - in_stridef.resize(input_shape.size()); - out_stridef.resize(output_shape.size()); - axes.resize(input_shape.size()); std::copy(in_shape.begin(), in_shape.end(), input_shape.begin()); std::copy(out_shape.begin(), out_shape.end(), output_shape.begin()); - std::iota(axes.begin(), axes.end(), 0); - - size_t out_tmpf = sizeof(OUT); - size_t in_tmpf = sizeof(IN); - for (int i = input_shape.size() - 1; i >= 0; --i) { - in_stridef[i] = in_tmpf; - in_tmpf *= input_shape[i]; - out_stridef[i] = out_tmpf; - out_tmpf *= output_shape[i]; - } + + // calc element size + size_t in_elemsize = type == FFT_EXCUTETYPE::C2C || type == FFT_EXCUTETYPE::C2R + ? sizeof(std::complex) + : sizeof(dtype); + size_t out_elemsize = type == FFT_EXCUTETYPE::R2C || type == FFT_EXCUTETYPE::C2C + ? 
sizeof(std::complex) + : sizeof(dtype); + for (auto& s : in_stridef) { s *= in_elemsize; } + for (auto& s : out_stridef) { s *= out_elemsize; } } }; -template +template class PocketFFtConfig { public: PocketFFtConfig(const PocketFFtConfig&) = delete; PocketFFtConfig& operator=(PocketFFtConfig const&) = delete; - explicit PocketFFtConfig(const PocketFFtParams& params) : fftparams(params) {} - - void excute(const IN* in, OUT* out, int64_t dims, int64_t batch, int64_t len) { - int64_t in_offset = len; - int64_t out_offset = len / 2 + 1; - for (int j = 0; j < dims; j++) { - for (int i = 0; i < batch; i++) { - const IN* data_in = in + j * batch * in_offset + i * in_offset; - OUT* data_out = out + j * batch * out_offset + i * out_offset; - switch (fftparams.excute_type) { - case FFT_EXCUTETYPE::R2C: - r2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef, fftparams.axes, - fftparams.IsForward, data_in, data_out, fftparams.fct); - break; - - case FFT_EXCUTETYPE::C2C: - // c2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef, - // fftparams.axes, fftparams.IsForward, in, - // out, fftparams.fct); - break; - default: break; - } - } - } + explicit PocketFFtConfig(const PocketFFtParams& params) : fftparams(params) {} + + void excute(const std::complex* in, std::complex* out) { + pocketfft::c2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef, + fftparams.axes, fftparams.IsForward, in, out, fftparams.fct); + } + + void excute(const dtype* in, std::complex* out) { + pocketfft::r2c(fftparams.input_shape, fftparams.in_stridef, fftparams.out_stridef, + fftparams.axes, fftparams.IsForward, in, out, fftparams.fct); + } + + void excute(const std::complex* in, dtype* out) { + pocketfft::c2r(fftparams.output_shape, fftparams.in_stridef, fftparams.out_stridef, + fftparams.axes, fftparams.IsForward, in, out, fftparams.fct); } private: - PocketFFtParams fftparams; + PocketFFtParams fftparams; }; } // namespace diff --git a/oneflow/user/kernels/stft_kernel.cpp b/oneflow/user/kernels/stft_kernel.cpp deleted file mode 100644 index 6ac06238457..00000000000 --- a/oneflow/user/kernels/stft_kernel.cpp +++ /dev/null @@ -1,138 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include "oneflow/core/framework/framework.h" -#include "pocketfftplan.h" -using namespace pocketfft; -namespace oneflow { - -namespace { - -enum class fft_norm_mode { - none, // No normalization - by_root_n, // Divide by sqrt(signal_size) - by_n, // Divide by signal_size -}; - -template -T compute_fct(int64_t size, fft_norm_mode normalization) { - constexpr auto one = static_cast(1); - switch (normalization) { - case fft_norm_mode::none: return one; - case fft_norm_mode::by_n: return one / static_cast(size); - case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast(size)); - } - return static_cast(0); -} -template -void convert_to_doublesized(const std::complex* in, std::complex* dst, size_t len, size_t n) { - size_t fact_len = 2 * len - 2; - for (int i = 0; i < n; i++) { - int index_x = i / fact_len; - int index_y = i % fact_len; - if (index_y == 0) { - dst[i] = in[index_x * len]; - } else if (index_y == len - 1) { - dst[i] = in[(index_x + 1) * len - 1]; - } else if (index_y < len - 1 && index_y > 0) { - dst[i] = in[index_x * len + index_y]; - } else { - auto index = (index_x + 2) * len - index_y - 2; - auto realvalue = in[index].real(); - dst[i].real(realvalue); - auto imagvalue = -in[index].imag(); - dst[i].imag(imagvalue); - } - } -} - -template -void comvert_to_real(const std::complex* in, T* out, size_t n) { - for (int i = 0; i < n; i++) { - out[2 * i] = in[i].real(); - out[2 * i + 1] = in[i].imag(); - } -} - -template -class StftCpuKernel final : public user_op::OpKernel { - public: - StftCpuKernel() = default; - ~StftCpuKernel() = default; - - private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); - user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); - user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - const auto normalized = ctx->Attr("normalized"); - const auto return_complex = ctx->Attr("return_complex"); - const bool onesized = ctx->Attr("onesided"); - - const ShapeView& input_shape = input->shape_view(); - const ShapeView& output_shape = output->shape_view(); - const auto output_elem_cnt = output_shape.elem_cnt() / 2; - - int64_t dims = input_shape.At(0); - int64_t batch = input_shape.At(1); - int64_t len = input_shape.back(); - const IN* data_in = input->dptr(); - IN* data_out = output->mut_dptr(); - auto normalization = normalized ? 
fft_norm_mode::by_root_n : fft_norm_mode::none; - PocketFFtParams params(Shape{len}, Shape{len}, true, - compute_fct(len, normalization) /*1.f*/, - FFT_EXCUTETYPE::R2C); - PocketFFtConfig config(params); - - OUT* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()); - config.excute(data_in, out_tmp_buffer, dims, batch, len); - - if (!onesized) { - OUT* doublesided_tmp_buffer = - reinterpret_cast(tmp_buffer->mut_dptr()) + output_elem_cnt; - size_t last_dim_length = len / 2 + 1; - size_t elem_conut = output_elem_cnt; - convert_to_doublesized(out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, - elem_conut); - out_tmp_buffer = doublesided_tmp_buffer; - } - - if (!return_complex) { comvert_to_real(out_tmp_buffer, data_out, output_elem_cnt); } - } - - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -#define REGISTER_STFT_CPU_KERNEL(intype, outtype) \ - REGISTER_USER_KERNEL("stft") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == kCPU) \ - && (user_op::HobDataType("input", 0) == GetDataType::value)) \ - .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - const Shape& output_shape = ctx->InputShape("output", 0); \ - const bool return_complex = ctx->Attr("return_complex"); \ - const bool onesided = ctx->Attr("onesided"); \ - int64_t output_elem_cnt = \ - return_complex ? output_shape.elem_cnt() : output_shape.elem_cnt() / 2; \ - const int64_t output_bytes = (output_elem_cnt * sizeof(outtype)); \ - return onesided ? output_bytes : 2 * output_bytes; \ - }); - -REGISTER_STFT_CPU_KERNEL(double, std::complex) -REGISTER_STFT_CPU_KERNEL(float, std::complex) - -} // namespace -} // namespace oneflow \ No newline at end of file diff --git a/oneflow/user/kernels/stft_kernel.cu b/oneflow/user/kernels/stft_kernel.cu deleted file mode 100644 index 626209a3ed7..00000000000 --- a/oneflow/user/kernels/stft_kernel.cu +++ /dev/null @@ -1,163 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -#include - -#if CUDA_VERSION >= 11000 - -#include "cufft_plan_cache.h" - -namespace oneflow { - -namespace { - -template -__global__ void convert_complex_to_real(IN* dst, const OUT* src, size_t n) { - CUDA_1D_KERNEL_LOOP(i, n) { - dst[2 * i] = src[i].x; - dst[2 * i + 1] = src[i].y; - }; -} - -double _fft_normalization_scale(const int32_t frame_length) { - return static_cast(1.0 / std::sqrt(frame_length)); -} - -template -__global__ void fft_apply_normalization(FFTTYPE* dst, const double normalization_scale, size_t n, - bool IsNormalized) { - if (!IsNormalized) { return; } - CUDA_1D_KERNEL_LOOP(i, n) { - dst[i].x *= normalization_scale; - dst[i].y *= normalization_scale; - }; -} - -// TODO(yzm):support doublesided -template -__global__ void convert_doublesided(const FFTTYPE* src, FFTTYPE* dst, size_t len, size_t n) { - size_t fact_len = 2 * len - 2; - CUDA_1D_KERNEL_LOOP(i, n) { - int index_x = i / fact_len; - int index_y = i % fact_len; - if (index_y == 0) { - dst[i] = src[index_x * len]; - } else if (index_y == len - 1) { - dst[i] = src[(index_x + 1) * len - 1]; - } else if (index_y < len - 1 && index_y > 0) { - dst[i] = src[index_x * len + index_y]; - } else { - auto index = (index_x + 2) * len - index_y - 2; - dst[i].x = src[index].x; - dst[i].y = -src[index].y; - } - } -} - -} // namespace - -template -class StftGpuKernel final : public user_op::OpKernel { - public: - StftGpuKernel() = default; - ~StftGpuKernel() = default; - - private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* input = ctx->Tensor4ArgNameAndIndex("input", 0); - user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0); - user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - const bool normalized = ctx->Attr("normalized"); - const bool onesided = ctx->Attr("onesided"); - const bool return_complex = ctx->Attr("return_complex"); - - const ShapeView& input_shape = input->shape_view(); - const ShapeView& output_shape = output->shape_view(); - - const Stride& input_stride = input->stride(); - const int out_elem_cnt = - return_complex ? 
output->shape_view().elem_cnt() : output->shape_view().elem_cnt() / 2; - - const IN* data_in = input->dptr(); - IN* data_out = output->mut_dptr(); - OUT* out_tmp_buffer = reinterpret_cast(tmp_buffer->mut_dptr()); - - int32_t ndim = 1; - int32_t n_frames = static_cast(input_shape.At(1)); - int32_t fft_size = static_cast(input_shape.At(2)); - const Stride& in_stride = {input_stride.at(2), input_stride.at(1)}; - const Stride& out_stride = {1, fft_size / 2 + 1}; - const Shape& in_shape = {fft_size, n_frames}; - const Shape& out_shape = in_shape; - int32_t batch = n_frames; - int32_t rank[1] = {fft_size}; - CuFFtParams params(ndim, rank, in_stride, out_stride, in_shape, out_shape, batch); - CuFFtConfig config(params); - - int32_t in_offset = input_stride.at(0); - int32_t out_offset = n_frames * (fft_size / 2 + 1); - int32_t signal_groups_count = static_cast(input_shape.At(0)); - for (int32_t i = 0; i < signal_groups_count; i++) { - config.excute_plan(data_in + i * in_offset, out_tmp_buffer + i * out_offset); - } - - if (!onesided) { - size_t last_dim_length = fft_size / 2 + 1; - OUT* doublesided_tmp_buffer = - reinterpret_cast(tmp_buffer->mut_dptr()) + out_elem_cnt; - convert_doublesided<<stream()->As()->cuda_stream()>>>( - out_tmp_buffer, doublesided_tmp_buffer, last_dim_length, out_elem_cnt); - out_tmp_buffer = doublesided_tmp_buffer; - } - - const double normalization_scale = _fft_normalization_scale(input_shape.back()); - fft_apply_normalization<<stream()->As()->cuda_stream()>>>( - out_tmp_buffer, normalization_scale, out_elem_cnt, normalized); - - if (!return_complex) { - convert_complex_to_real<<stream()->As()->cuda_stream()>>>( - data_out, out_tmp_buffer, out_elem_cnt); - } else { - // TODO(yzm):support return_complex after oneflow supports complex numbers - } - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; -#define REGISTER_STFT_GPU_KERNEL(intype, outtype) \ - REGISTER_USER_KERNEL("stft") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ - && (user_op::HobDataType("input", 0) == GetDataType::value)) \ - .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - const Shape& output_shape = ctx->InputShape("output", 0); \ - const bool return_complex = ctx->Attr("return_complex"); \ - const bool onesided = ctx->Attr("onesided"); \ - int64_t output_elem_cnt = \ - return_complex ? output_shape.elem_cnt() : output_shape.elem_cnt() / 2; \ - const int64_t output_bytes = GetCudaAlignedSize(output_elem_cnt * sizeof(outtype)); \ - return onesided ? 
output_bytes : 2 * output_bytes; \ - }); - -REGISTER_STFT_GPU_KERNEL(float, cufftComplex) -REGISTER_STFT_GPU_KERNEL(double, cufftDoubleComplex) - -} // namespace oneflow - -#endif diff --git a/oneflow/user/kernels/to_contiguous_kernel.h b/oneflow/user/kernels/to_contiguous_kernel.h index c924ce4451e..f1a24a46233 100644 --- a/oneflow/user/kernels/to_contiguous_kernel.h +++ b/oneflow/user/kernels/to_contiguous_kernel.h @@ -95,17 +95,22 @@ struct ToContiguousUtil : ToContiguousUtilBase { OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat) \ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) -#define TO_CONTIGUOUS_CPU_TYPES \ - TO_CONTIGUOUS_COMMON_TYPES OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) \ - OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16) +#define TO_CONTIGUOUS_CPU_TYPES \ + TO_CONTIGUOUS_COMMON_TYPES COMPLEX_DATA_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ( \ + float16, DataType::kFloat16) OF_PP_MAKE_TUPLE_SEQ(bfloat16, DataType::kBFloat16) #ifdef WITH_CUDA #if CUDA_VERSION >= 11000 -#define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE \ - OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) \ - OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) +#define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE \ + OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) \ + OF_PP_MAKE_TUPLE_SEQ(nv_bfloat16, DataType::kBFloat16) \ + OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64) \ + OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128) #else -#define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) +#define TO_CONTIGUOUS_CUDA_SPECIAL_TYPE \ + OF_PP_MAKE_TUPLE_SEQ(half, DataType::kFloat16) \ + OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64) \ + OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128) #endif // CUDA_VERSION >= 11000 #endif // WITH_CUDA #endif // ONEFLOW_USER_KERNELS_TO_CONTIGUOUS_KERNEL_H_ diff --git a/oneflow/user/ops/fft_ops.cpp b/oneflow/user/ops/fft_ops.cpp new file mode 100644 index 00000000000..3d4f2c32ba3 --- /dev/null +++ b/oneflow/user/ops/fft_ops.cpp @@ -0,0 +1,121 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" +namespace oneflow { + +/* static */ Maybe FftC2COp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("input", 0); + Stride out_stride = Stride(in_shape); // contiguous + ctx->SetOutputShape("out", 0, in_shape); + ctx->SetOutputStride("out", 0, out_stride); + ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("input", 0)); + return Maybe::Ok(); +} + +/*static*/ Maybe FftC2COp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe FftC2COp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("input", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe FftC2COp::InferDataType(user_op::InferContext* ctx) { + ctx->SetOutputDType("out", 0, ctx->InputDType("input", 0)); + return Maybe::Ok(); +} + +/* static */ Maybe FftR2COp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("input", 0); + const auto& dims = ctx->Attr>("dims"); + bool onesided = ctx->Attr("onesided"); + + Shape out_shape = in_shape; + auto last_dim = dims.back(); + if (onesided) { out_shape[last_dim] = out_shape[last_dim] / 2 + 1; } + Stride out_stride = Stride(out_shape); + ctx->SetOutputShape("out", 0, out_shape); + ctx->SetOutputStride("out", 0, out_stride); + ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("input", 0)); + return Maybe::Ok(); +} + +/*static*/ Maybe FftR2COp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe FftR2COp::GetSbp(user_op::SbpContext* ctx) { + // TO-DO : Validate sbp + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe FftR2COp::InferDataType(user_op::InferContext* ctx) { + const DataType& input_type = ctx->InputDType("input", 0); + switch (input_type) { + case (kFloat): ctx->SetOutputDType("out", 0, kComplex64); break; + case (kDouble): ctx->SetOutputDType("out", 0, kComplex128); break; + default: CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be handled"; + } + + return Maybe::Ok(); +} + +/* static */ Maybe FftC2ROp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("input", 0); + + const auto& dims = ctx->Attr>("dims"); + int64_t last_dim_size = ctx->Attr("last_dim_size"); + + Shape out_shape = in_shape; + out_shape[dims.back()] = last_dim_size; + Stride out_stride = Stride(out_shape); + ctx->SetOutputShape("out", 0, out_shape); + ctx->SetOutputStride("out", 0, out_stride); + ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("input", 0)); + return Maybe::Ok(); +} + +/*static*/ Maybe FftC2ROp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe FftC2ROp::GetSbp(user_op::SbpContext* ctx) { + // TO-DO : Validate sbp + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe FftC2ROp::InferDataType(user_op::InferContext* ctx) { + const DataType& input_type = ctx->InputDType("input", 0); + switch (input_type) { + case (kComplex64): ctx->SetOutputDType("out", 0, kFloat); break; + case (kComplex128): ctx->SetOutputDType("out", 0, kDouble); 
break; + default: CHECK_OR_RETURN(false) << "RuntimeError: dtype can't be handled"; + } + + return Maybe::Ok(); +} + +} // namespace oneflow \ No newline at end of file diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 126829469da..85faa262056 100644 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -215,6 +215,7 @@ def use_deterministic_algorithms(mode, *, warn_only=False): from oneflow._C import argmax from oneflow._C import argmin from oneflow._C import std + from oneflow._C import stft from oneflow._C import var from oneflow._C import stack, hstack, vstack, dstack, column_stack, row_stack @@ -489,6 +490,7 @@ def atexit_hook(hook): amp, hub, fx, + fft, special, ) import oneflow.utils.data diff --git a/python/oneflow/fft/__init__.py b/python/oneflow/fft/__init__.py new file mode 100644 index 00000000000..0b97bd98706 --- /dev/null +++ b/python/oneflow/fft/__init__.py @@ -0,0 +1,877 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from oneflow.framework.tensor import Tensor +import oneflow as flow + + +def fft(input, n=None, dim=-1, norm=None) -> Tensor: + r""" + + Computes the one dimensional discrete Fourier transform of :attr:`input`. + + Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: `X[i] = conj(X[-i])`. This function always returns both + the positive and negative frequency terms even though, for real inputs, the + negative frequencies are redundant. :func:`oneflow.fft.rfft` returns the + more compact one-sided representation where only the positive frequencies + are returned. + + Args: + input (Tensor): the input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the FFT. + dim (int, optional): The dimension along which to take the one dimensional FFT. + norm (str, optional): Normalization mode. For the forward transform + (:func:`oneflow.fft.fft`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Calling the backward transform (:func:`oneflow.fft.ifft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.ifft` + the exact inverse. + + Default is ``"backward"`` (no normalization). + + Example: + + >>> t = oneflow.arange(4) + >>> t + tensor([0, 1, 2, 3]) + >>> oneflow.fft.fft(t) + tensor([ 6+0j, -2+2j, -2+0j, -2-2j], dtype=oneflow.complex64) + + >>> t = oneflow.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j]) + >>> oneflow.fft.fft(t) + tensor([12+16j, -8+0j, -4-4j, -8j], dtype=oneflow.complex128) + """ + if n is None: + n = -1 + return flow._C.fft(input, n, dim, norm) + + +def ifft(input, n=None, dim=-1, norm=None) -> Tensor: + r""" + + Computes the one dimensional inverse discrete Fourier transform of :attr:`input`. 
+ + Args: + input (Tensor): the input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the IFFT. + dim (int, optional): The dimension along which to take the one dimensional IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`oneflow.fft.ifft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Calling the forward transform (:func:`~oneflow.fft.fft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.ifft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + + Example: + + >>> t = oneflow.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + >>> oneflow.fft.ifft(t) + tensor([0j, (1+0j), (2+0j), (3+0j)], dtype=oneflow.complex128) + """ + if n is None: + n = -1 + return flow._C.ifft(input, n, dim, norm) + + +def fft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: + r""" + + Computes the 2 dimensional discrete Fourier transform of :attr:`input`. + Equivalent to :func:`~oneflow.fft.fftn` but FFTs only the last two dimensions by default. + + Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This + function always returns all positive and negative frequency terms even + though, for real inputs, half of these values are redundant. + :func:`~oneflow.fft.rfft2` returns the more compact one-sided representation + where only the positive frequencies of the last dimension are returned. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`oneflow.fft.fft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`oneflow.fft.ifft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` + between the two transforms. This is required to make + :func:`~oneflow.fft.ifft2` the exact inverse. + + Default is ``"backward"`` (no normalization). + + """ + return flow._C.fft2(input, s, dim, norm) + + +def ifft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: + r""" + + Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`. + Equivalent to :func:`oneflow.fft.ifftn` but IFFTs only the last two dimensions by default. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. 
+ Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`oneflow.fft.ifft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`oneflow.fft.fft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.ifft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + + + """ + return flow._C.ifft2(input, s, dim, norm) + + +def fftn(input, s=None, dim=None, norm=None) -> Tensor: + r""" + + Computes the N dimensional discrete Fourier transform of :attr:`input`. + + Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This + function always returns all positive and negative frequency terms even + though, for real inputs, half of these values are redundant. + :func:`oneflow.fft.rfftn` returns the more compact one-sided representation + where only the positive frequencies of the last dimension are returned. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`oneflow.fft.fftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`oneflow.fft.ifftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` + between the two transforms. This is required to make + :func:`oneflow.fft.ifftn` the exact inverse. + + Default is ``"backward"`` (no normalization). + + """ + return flow._C.fftn(input, s, dim, norm) + + +def ifftn(input, s=None, dim=None, norm=None) -> Tensor: + r""" + + Computes the N dimensional inverse discrete Fourier transform of :attr:`input`. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. 
For the backward transform
+            (:func:`oneflow.fft.ifftn`), these correspond to:
+
+            * ``"forward"`` - no normalization
+            * ``"backward"`` - normalize by ``1/n``
+            * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
+
+            Where ``n = prod(s)`` is the logical IFFT size.
+            Calling the forward transform (:func:`oneflow.fft.fftn`) with the same
+            normalization mode will apply an overall normalization of ``1/n`` between
+            the two transforms. This is required to make :func:`oneflow.fft.ifftn`
+            the exact inverse.
+
+            Default is ``"backward"`` (normalize by ``1/n``).
+
+    """
+    return flow._C.ifftn(input, s, dim, norm)
+
+
+def rfft(input, n=None, dim=-1, norm=None) -> Tensor:
+    r"""
+
+    Computes the one dimensional Fourier transform of real-valued :attr:`input`.
+
+    The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``, so
+    the output contains only the positive frequencies below the Nyquist frequency.
+    To compute the full output, use :func:`oneflow.fft.fft`.
+
+    Args:
+        input (Tensor): the real input tensor
+        n (int, optional): Signal length. If given, the input will either be zero-padded
+            or trimmed to this length before computing the real FFT.
+        dim (int, optional): The dimension along which to take the one dimensional real FFT.
+        norm (str, optional): Normalization mode. For the forward transform
+            (:func:`oneflow.fft.rfft`), these correspond to:
+
+            * ``"forward"`` - normalize by ``1/n``
+            * ``"backward"`` - no normalization
+            * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
+
+            Calling the backward transform (:func:`oneflow.fft.irfft`) with the same
+            normalization mode will apply an overall normalization of ``1/n`` between
+            the two transforms. This is required to make :func:`oneflow.fft.irfft`
+            the exact inverse.
+
+            Default is ``"backward"`` (no normalization).
+
+    Example:
+
+        >>> t = oneflow.arange(4)
+        >>> t
+        tensor([0, 1, 2, 3], dtype=oneflow.int64)
+        >>> oneflow.fft.rfft(t)
+        tensor([ (6+0j), (-2+2j), (-2+0j)], dtype=oneflow.complex64)
+
+        Compare against the full output from :func:`oneflow.fft.fft`:
+
+        >>> oneflow.fft.fft(t)
+        tensor([ (6+0j), (-2+2j), (-2+0j), (-2-2j)], dtype=oneflow.complex64)
+
+        Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted.
+        At the Nyquist frequency ``T[-2] == T[2]`` is its own symmetric pair,
+        and therefore must always be real-valued.
+    """
+
+    if n is None:
+        n = -1
+    return flow._C.rfft(input, n, dim, norm)
+
+
+def irfft(input, n=None, dim=-1, norm=None) -> Tensor:
+    r"""
+
+    Computes the inverse of :func:`oneflow.fft.rfft`.
+
+    :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
+    domain, as produced by :func:`oneflow.fft.rfft`. By the Hermitian property, the
+    output will be real-valued.
+
+    Note:
+        Some input frequencies must be real-valued to satisfy the Hermitian
+        property. In these cases the imaginary component will be ignored.
+        For example, any imaginary component in the zero-frequency term cannot
+        be represented in a real output and so will always be ignored.
+
+    Note:
+        The correct interpretation of the Hermitian input depends on the length of
+        the original data, as given by :attr:`n`. This is because each input shape
+        could correspond to either an odd or even length signal. By default, the
+        signal is assumed to be even length and odd signals will not round-trip
+        properly. So, it is recommended to always pass the signal length :attr:`n`.
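+
+    A shape-only sketch of that pitfall (the sizes follow from the default
+    ``n = 2*(input.size(dim) - 1)`` rule described below):
+
+    >>> t = oneflow.rand(5)
+    >>> T = oneflow.fft.rfft(t)
+    >>> T.size()
+    oneflow.Size([3])
+    >>> oneflow.fft.irfft(T).size()   # assumes an even-length original
+    oneflow.Size([4])
+    >>> oneflow.fft.irfft(T, n=5).size()
+    oneflow.Size([5])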
+ + Args: + input (Tensor): the input tensor representing a half-Hermitian signal + n (int, optional): Output signal length. This determines the length of the + output signal. If given, the input will either be zero-padded or trimmed to this + length before computing the real IFFT. + Defaults to even output: ``n=2*(input.size(dim) - 1)``. + dim (int, optional): The dimension along which to take the one dimensional real IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`oneflow.fft.irfft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Calling the forward transform (:func:`oneflow.fft.rfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.irfft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + + + """ + + if n is None: + n = -1 + return flow._C.irfft(input, n, dim, norm) + + +def rfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: + r""" + + Computes the 2-dimensional discrete Fourier transform of real :attr:`input`. + Equivalent to :func:`oneflow.fft.rfftn` but FFTs only the last two dimensions by default. + + The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``, + so the full :func:`oneflow.fft.fft2` output contains redundant information. + :func:`oneflow.fft.rfft2` instead omits the negative frequencies in the last + dimension. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`oneflow.fft.rfft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`oneflow.fft.irfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (no normalization). + + """ + + return flow._C.rfft2(input, s, dim, norm) + + +def irfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor: + r""" + + Computes the inverse of :func:`oneflow.fft.rfft2`. + Equivalent to :func:`oneflow.fft.irfftn` but IFFTs only the last two dimensions by default. + + :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier + domain, as produced by :func:`oneflow.fft.rfft2`. By the Hermitian property, the + output will be real-valued. + + Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. 
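+
+    As with :func:`oneflow.fft.irfft`, shapes only round-trip when :attr:`s` is
+    given; a shape-only sketch (the default last-dimension length is
+    ``2*(input.size(dim[-1]) - 1)``, as noted below):
+
+    >>> T = oneflow.fft.rfft2(oneflow.rand(4, 5))
+    >>> T.size()
+    oneflow.Size([4, 3])
+    >>> oneflow.fft.irfft2(T).size()   # assumes an even-length last dimension
+    oneflow.Size([4, 4])
+    >>> oneflow.fft.irfft2(T, s=(4, 5)).size()
+    oneflow.Size([4, 5])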
+ + Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal shape :attr:`s`. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`oneflow.fft.irfft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`oneflow.fft.rfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + + + """ + return flow._C.irfft2(input, s, dim, norm) + + +def rfftn(input, s=None, dim=None, norm=None) -> Tensor: + r""" + + Computes the N-dimensional discrete Fourier transform of real :attr:`input`. + + The FFT of a real signal is Hermitian-symmetric, + ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])`` so the full + :func:`oneflow.fft.fftn` output contains redundant information. + :func:`oneflow.fft.rfftn` instead omits the negative frequencies in the + last dimension. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`oneflow.fft.rfftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`oneflow.fft.irfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.irfftn` + the exact inverse. + + Default is ``"backward"`` (no normalization). + + """ + + return flow._C.rfftn(input, s, dim, norm) + + +def irfftn(input, s=None, dim=None, norm=None) -> Tensor: + r""" + + Computes the inverse of :func:`oneflow.fft.rfftn`. + + :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier + domain, as produced by :func:`oneflow.fft.rfftn`. 
By the Hermitian property, the
+    output will be real-valued.
+
+    Note:
+        Some input frequencies must be real-valued to satisfy the Hermitian
+        property. In these cases the imaginary component will be ignored.
+        For example, any imaginary component in the zero-frequency term cannot
+        be represented in a real output and so will always be ignored.
+
+    Note:
+        The correct interpretation of the Hermitian input depends on the length of
+        the original data, as given by :attr:`s`. This is because each input shape
+        could correspond to either an odd or even length signal. By default, the
+        signal is assumed to be even length and odd signals will not round-trip
+        properly. So, it is recommended to always pass the signal shape :attr:`s`.
+
+    Args:
+        input (Tensor): the input tensor
+        s (Tuple[int], optional): Signal size in the transformed dimensions.
+            If given, each dimension ``dim[i]`` will either be zero-padded or
+            trimmed to the length ``s[i]`` before computing the real FFT.
+            If a length ``-1`` is specified, no padding is done in that dimension.
+            Defaults to even output in the last dimension:
+            ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
+        dim (Tuple[int], optional): Dimensions to be transformed.
+            The last dimension must be the half-Hermitian compressed dimension.
+            Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
+        norm (str, optional): Normalization mode. For the backward transform
+            (:func:`oneflow.fft.irfftn`), these correspond to:
+
+            * ``"forward"`` - no normalization
+            * ``"backward"`` - normalize by ``1/n``
+            * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
+
+            Where ``n = prod(s)`` is the logical IFFT size.
+            Calling the forward transform (:func:`oneflow.fft.rfftn`) with the same
+            normalization mode will apply an overall normalization of ``1/n`` between
+            the two transforms. This is required to make :func:`oneflow.fft.irfftn`
+            the exact inverse.
+
+            Default is ``"backward"`` (normalize by ``1/n``).
+
+    """
+    return flow._C.irfftn(input, s, dim, norm)
+
+
+def hfft(input, n=None, dim=-1, norm=None) -> Tensor:
+    r"""
+    hfft(input, n=None, dim=-1, norm=None) -> Tensor
+
+    Computes the one dimensional discrete Fourier transform of a Hermitian
+    symmetric :attr:`input` signal.
+
+    Note:
+        :func:`oneflow.fft.hfft`/:func:`oneflow.fft.ihfft` are analogous to
+        :func:`oneflow.fft.rfft`/:func:`oneflow.fft.irfft`. The real FFT expects
+        a real signal in the time-domain and gives Hermitian symmetry in the
+        frequency-domain. The Hermitian FFT is the opposite: Hermitian symmetric in
+        the time-domain and real-valued in the frequency-domain. For this reason,
+        special care needs to be taken with the length argument :attr:`n`, in the
+        same way as with :func:`oneflow.fft.irfft`.
+
+    Note:
+        Because the signal is Hermitian in the time-domain, the result will be
+        real in the frequency domain. Note that some input frequencies must be
+        real-valued to satisfy the Hermitian property. In these cases the imaginary
+        component will be ignored. For example, any imaginary component in
+        ``input[0]`` would result in one or more complex frequency terms which
+        cannot be represented in a real output and so will always be ignored.
+
+    Note:
+        The correct interpretation of the Hermitian input depends on the length of
+        the original data, as given by :attr:`n`. This is because each input shape
+        could correspond to either an odd or even length signal.
By default, the
+        signal is assumed to be even length and odd signals will not round-trip
+        properly. So, it is recommended to always pass the signal length :attr:`n`.
+
+    Args:
+        input (Tensor): the input tensor representing a half-Hermitian signal
+        n (int, optional): Output signal length. This determines the length of the
+            real output. If given, the input will either be zero-padded or trimmed to this
+            length before computing the Hermitian FFT.
+            Defaults to even output: ``n=2*(input.size(dim) - 1)``.
+        dim (int, optional): The dimension along which to take the one dimensional Hermitian FFT.
+        norm (str, optional): Normalization mode. For the forward transform
+            (:func:`oneflow.fft.hfft`), these correspond to:
+
+            * ``"forward"`` - normalize by ``1/n``
+            * ``"backward"`` - no normalization
+            * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
+
+            Calling the backward transform (:func:`oneflow.fft.ihfft`) with the same
+            normalization mode will apply an overall normalization of ``1/n`` between
+            the two transforms. This is required to make :func:`oneflow.fft.ihfft`
+            the exact inverse.
+
+            Default is ``"backward"`` (no normalization).
+
+    Example:
+
+        Taking a real-valued frequency signal and bringing it into the time domain
+        gives Hermitian symmetric output:
+
+        >>> t = oneflow.linspace(0, 1, 5)
+        >>> t
+        tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000], dtype=oneflow.float32)
+        >>> T = oneflow.fft.ifft(t)
+        >>> T
+        tensor([ (0.5000-0.0000j), (-0.1250-0.1720j), (-0.1250-0.0406j), (-0.1250+0.0406j),
+                (-0.1250+0.1720j)], dtype=oneflow.complex64)
+
+        Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` are
+        redundant. We can thus compute the forward transform without considering
+        negative frequencies:
+
+        >>> oneflow.fft.hfft(T[:3], n=5)
+        tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000], dtype=oneflow.float32)
+
+        Like with :func:`oneflow.fft.irfft`, if the output length is not given, an
+        even length is assumed, so the odd length signal above does not round-trip:
+
+        >>> oneflow.fft.hfft(T[:3])
+        tensor([0.1250, 0.2809, 0.6250, 0.9691], dtype=oneflow.float32)
+    """
+
+    if n is None:
+        n = -1
+    return flow._C.hfft(input, n, dim, norm)
+
+
+def ihfft(input, n=None, dim=-1, norm=None) -> Tensor:
+    r"""
+
+    Computes the inverse of :func:`oneflow.fft.hfft`.
+
+    :attr:`input` must be a real-valued signal, interpreted in the Fourier domain.
+    The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.
+    :func:`oneflow.fft.ihfft` represents this in the one-sided form where only the
+    positive frequencies below the Nyquist frequency are included. To compute the
+    full output, use :func:`oneflow.fft.ifft`.
+
+    Args:
+        input (Tensor): the real input tensor
+        n (int, optional): Signal length. If given, the input will either be zero-padded
+            or trimmed to this length before computing the Hermitian IFFT.
+        dim (int, optional): The dimension along which to take the one dimensional Hermitian IFFT.
+        norm (str, optional): Normalization mode. For the backward transform
+            (:func:`oneflow.fft.ihfft`), these correspond to:
+
+            * ``"forward"`` - no normalization
+            * ``"backward"`` - normalize by ``1/n``
+            * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
+
+            Calling the forward transform (:func:`oneflow.fft.hfft`) with the same
+            normalization mode will apply an overall normalization of ``1/n`` between
+            the two transforms. This is required to make :func:`oneflow.fft.ihfft`
+            the exact inverse.
+
+            Default is ``"backward"`` (normalize by ``1/n``).
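+
+    The one-sided output keeps only the ``n // 2 + 1`` lowest frequencies; a
+    shape-only sketch:
+
+    >>> oneflow.fft.ihfft(oneflow.rand(10)).size()
+    oneflow.Size([6])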
+
+    Example:
+
+        >>> t = oneflow.arange(5)
+        >>> t
+        tensor([0, 1, 2, 3, 4], dtype=oneflow.int64)
+        >>> oneflow.fft.ihfft(t)
+        tensor([ (2.0000-0.0000j), (-0.5000-0.6882j), (-0.5000-0.1625j)], dtype=oneflow.complex64)
+
+        Compare against the full output from :func:`oneflow.fft.ifft`:
+
+        >>> oneflow.fft.ifft(t)
+        tensor([ (2.0000-0.0000j), (-0.5000-0.6882j), (-0.5000-0.1625j), (-0.5000+0.1625j),
+                (-0.5000+0.6882j)], dtype=oneflow.complex64)
+    """
+    if n is None:
+        n = -1
+    return flow._C.ihfft(input, n, dim, norm)
+
+
+def hfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:
+    r"""
+
+    Computes the 2-dimensional discrete Fourier transform of a Hermitian symmetric
+    :attr:`input` signal. Equivalent to :func:`oneflow.fft.hfftn` but only
+    transforms the last two dimensions by default.
+
+    :attr:`input` is interpreted as a one-sided Hermitian signal in the time
+    domain. By the Hermitian property, the Fourier transform will be real-valued.
+
+    Args:
+        input (Tensor): the input tensor
+        s (Tuple[int], optional): Signal size in the transformed dimensions.
+            If given, each dimension ``dim[i]`` will either be zero-padded or
+            trimmed to the length ``s[i]`` before computing the Hermitian FFT.
+            If a length ``-1`` is specified, no padding is done in that dimension.
+            Defaults to even output in the last dimension:
+            ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
+        dim (Tuple[int], optional): Dimensions to be transformed.
+            The last dimension must be the half-Hermitian compressed dimension.
+            Default: last two dimensions.
+        norm (str, optional): Normalization mode. For the forward transform
+            (:func:`oneflow.fft.hfft2`), these correspond to:
+
+            * ``"forward"`` - normalize by ``1/n``
+            * ``"backward"`` - no normalization
+            * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
+
+            Where ``n = prod(s)`` is the logical FFT size.
+            Calling the backward transform (:func:`oneflow.fft.ihfft2`) with the same
+            normalization mode will apply an overall normalization of ``1/n`` between
+            the two transforms. This is required to make :func:`oneflow.fft.ihfft2`
+            the exact inverse.
+
+            Default is ``"backward"`` (no normalization).
+
+    Example:
+
+        Starting from a real frequency-space signal, we can generate a
+        Hermitian-symmetric time-domain signal:
+
+        >>> T = oneflow.rand(10, 9)
+        >>> t = oneflow.fft.ihfft2(T)
+
+        Without specifying the output length to :func:`oneflow.fft.hfft2`, the
+        output will not round-trip properly because the input is odd-length in the
+        last dimension:
+
+        >>> oneflow.fft.hfft2(t).size()
+        oneflow.Size([10, 8])
+
+        So, it is recommended to always pass the signal shape :attr:`s`.
+
+        >>> roundtrip = oneflow.fft.hfft2(t, T.size())
+        >>> roundtrip.size()
+        oneflow.Size([10, 9])
+        >>> oneflow.allclose(roundtrip, T)
+        True
+
+    """
+    return flow._C.hfft2(input, s, dim, norm)
+
+
+def ihfft2(input, s=None, dim=(-2, -1), norm=None) -> Tensor:
+    r"""
+
+    Computes the 2-dimensional inverse discrete Fourier transform of real
+    :attr:`input`. Equivalent to :func:`oneflow.fft.ihfftn` but transforms only the
+    last two dimensions by default.
+
+    Args:
+        input (Tensor): the input tensor
+        s (Tuple[int], optional): Signal size in the transformed dimensions.
+            If given, each dimension ``dim[i]`` will either be zero-padded or
+            trimmed to the length ``s[i]`` before computing the Hermitian IFFT.
+            If a length ``-1`` is specified, no padding is done in that dimension.
+ Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`oneflow.fft.ihfft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`oneflow.fft.hfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.ihfft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + + """ + return flow._C.ihfft2(input, s, dim, norm) + + +def hfftn(input, s=None, dim=None, norm=None) -> Tensor: + r""" + + Computes the n-dimensional discrete Fourier transform of a Hermitian symmetric + :attr:`input` signal. + + :attr:`input` is interpreted as a one-sided Hermitian signal in the time + domain. By the Hermitian property, the Fourier transform will be real-valued. + + Note: + :func:`oneflow.fft.hfftn`/:func:`oneflow.fft.ihfftn` are analogous to + :func:`oneflow.fft.rfftn`/:func:`oneflow.fft.irfftn`. The real FFT expects + a real signal in the time-domain and gives Hermitian symmetry in the + frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in + the time-domain and real-valued in the frequency-domain. For this reason, + special care needs to be taken with the shape argument :attr:`s`, in the + same way as with :func:`oneflow.fft.irfftn`. + + Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + + Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. It is recommended to always pass the signal shape :attr:`s`. + + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`oneflow.fft.hfftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`oneflow.fft.ihfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. 
This is required to make :func:`oneflow.fft.ihfftn` + the exact inverse. + + Default is ``"backward"`` (no normalization). + + """ + return flow._C.hfftn(input, s, dim, norm) + + +def ihfftn(input, s=None, dim=None, norm=None) -> Tensor: + r""" + + Computes the N-dimensional inverse discrete Fourier transform of real :attr:`input`. + + :attr:`input` must be a real-valued signal, interpreted in the Fourier domain. + The n-dimensional IFFT of a real signal is Hermitian-symmetric, + ``X[i, j, ...] = conj(X[-i, -j, ...])``. :func:`oneflow.fft.ihfftn` represents + this in the one-sided form where only the positive frequencies below the + Nyquist frequency are included in the last signal dimension. To compute the + full output, use :func:`oneflow.fft.ifftn`. + + Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the Hermitian IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the backward transform + (:func:`oneflow.fft.ihfftn`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`oneflow.fft.hfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`oneflow.fft.ihfftn` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + + """ + return flow._C.ihfftn(input, s, dim, norm) diff --git a/python/oneflow/test/exceptions/test_stft_op.py b/python/oneflow/test/exceptions/test_stft_op.py index c013f78045d..38d1cb10193 100644 --- a/python/oneflow/test/exceptions/test_stft_op.py +++ b/python/oneflow/test/exceptions/test_stft_op.py @@ -53,7 +53,7 @@ def test_stft_illegal_nfft(test_case): return_complex=False, normalized=False, ) - test_case.assertTrue("Expected 0 < n_fft <" in str(ctx.exception)) + test_case.assertTrue("Expected 0 < n_fft" in str(ctx.exception)) def test_stft_illegal_hop_length(test_case): np_tensor = np.arange(1, 13, dtype=float).reshape(4, 3) diff --git a/python/oneflow/test/modules/test_fft.py b/python/oneflow/test/modules/test_fft.py new file mode 100644 index 00000000000..de25eeb7e1f --- /dev/null +++ b/python/oneflow/test/modules/test_fft.py @@ -0,0 +1,898 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest +from collections import OrderedDict + +import numpy as np +import torch as torch_original +from packaging import version + +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.test_util import GenArgList + +from oneflow.test_utils.automated_test_util import * + + +def is_cufft_available(): + if flow.cuda.is_available(): + (major, _minor) = flow.cuda.get_device_capability() + return major >= 7 + else: + return False + + +def is_complex_dtype(dtype): + if hasattr(dtype, "pytorch") and hasattr(dtype, "oneflow"): + # is DualObject + return dtype.pytorch.is_complex + else: + return dtype in [ + flow.complex64, + flow.complex128, + torch_original.complex64, + torch_original.complex128, + torch.pytorch.complex64, + torch.pytorch.complex128, + ] + + +def gen_params_1d_fft(lower_n_dims=1, upper_n_dims=5): + num_dims = np.random.randint(lower_n_dims, upper_n_dims) + shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)] + + if np.random.randint(2) == 1: + dim = np.random.randint(low=-num_dims, high=num_dims - 1) + else: + dim = -1 + + norm = np.random.choice(["backward", "forward", "ortho", None]) + + if np.random.randint(2) == 1: + n = None + else: + n = np.random.randint(low=1, high=shape[dim] * 2) + + params = { + "num_dims": num_dims, + "shape": shape, + "n": n, + "dim": dim, + "norm": norm, + } + return params + + +def gen_params_2d_fft(lower_n_dims=2, upper_n_dims=5): + num_dims = np.random.randint(lower_n_dims, upper_n_dims) + shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)] + len_fft_dim = np.random.randint(low=1, high=3) + + total_dims_range = np.arange(num_dims) + if np.random.randint(2) == 1: + dims = np.random.choice( + total_dims_range, size=len_fft_dim, replace=False + ).tolist() + else: + dims = (-2, -1) + + norm = np.random.choice(["backward", "forward", "ortho", None]) + len_fft_dim = len(dims) + if np.random.randint(2) == 1 and dims is not None: + n = [] + for i in range(len_fft_dim): + n_ = ( + np.random.randint(low=1, high=2 * shape[i]) + if np.random.randint(2) == 1 + else -1 + ) + n.append(n_) + else: + n = None + + params = { + "num_dims": num_dims, + "shape": shape, + "n": n, + "dim": dims, + "norm": norm, + } + return params + + +def gen_params_nd_fft(lower_n_dims=2, upper_n_dims=5): + num_dims = np.random.randint(lower_n_dims, upper_n_dims) + shape = [np.random.randint(1, 5) * 2 for _ in range(num_dims)] + len_fft_dim = np.random.randint(low=1, high=num_dims + 1) + + total_dims_range = np.arange(num_dims) + if np.random.randint(2) == 1: + dims = np.random.choice( + total_dims_range, size=len_fft_dim, replace=False + ).tolist() + else: + dims = None + + norm = np.random.choice(["backward", "forward", "ortho", None]) + + if np.random.randint(2) == 1: + n = None + else: + n = [] + len_fft_dim = ( + len(dims) + if dims is not None + else np.random.randint(low=1, high=num_dims + 1) + ) + for i in range(len_fft_dim): + n_ = ( + np.random.randint(low=1, high=2 * shape[i]) + if np.random.randint(2) == 1 + else -1 + ) + n.append(n_) + + params = { + "num_dims": num_dims, + "shape": shape, + "n": n, + "dim": dims, + "norm": norm, + } + return params + + +def _test_fft(test_case): + + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] + params = gen_params_1d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] 
+ dim = params["dim"] + norm = params["norm"] + + x = random_tensor(num_dims, dtype=float, *shape) + if is_complex_dtype(x.dtype): + # test fft_c2c + dtype = test_case.dtype_dict["complex"] + x = x.to(device=device, dtype=dtype) + else: + # test fft_r2c + dtype = test_case.dtype_dict["real"] + x = x.to(device=device, dtype=dtype) + y = torch.fft.fft(x, n, dim, norm) + return y + + +def _test_ifft(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] + params = gen_params_1d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + x = random_tensor(num_dims, dtype=float, *shape) + if is_complex_dtype(x.dtype): + # test fft_c2c + dtype = test_case.dtype_dict["complex"] + x = x.to(device=device, dtype=dtype) + else: + # test fft_r2c + dtype = test_case.dtype_dict["real"] + x = x.to(device=device, dtype=dtype) + + y = torch.fft.ifft(x, n, dim, norm) + + return y + + +def _test_rfft(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] + params = gen_params_1d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + dtype = test_case.dtype_dict["real"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.rfft(x, n, dim, norm) + + return y + + +def _test_irfft(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] + params = gen_params_1d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["complex"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.irfft(x, n, dim, norm) + + return y + + +def _test_hfft(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] + params = gen_params_1d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["complex"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.hfft(x, n, dim, norm) + + return y + + +def _test_ihfft(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["1d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["1d"]["upper_n_dims"] + params = gen_params_1d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["real"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.ihfft(x, n, dim, norm) + + return y + + +def 
_test_fft2(test_case): + + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] + params = gen_params_2d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + x = random_tensor(num_dims, dtype=float, *shape) + if is_complex_dtype(x.dtype): + # test fft_c2c + dtype = test_case.dtype_dict["complex"] + x = x.to(device=device, dtype=dtype) + else: + # test fft_r2c + dtype = test_case.dtype_dict["real"] + x = x.to(device=device, dtype=dtype) + y = torch.fft.fft2(x, n, dim, norm) + + return y + + +def _test_ifft2(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] + params = gen_params_2d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + x = random_tensor(num_dims, dtype=float, *shape) + if is_complex_dtype(x.dtype): + # test fft_c2c + dtype = test_case.dtype_dict["complex"] + x = x.to(device=device, dtype=dtype) + else: + # test fft_r2c + dtype = test_case.dtype_dict["real"] + x = x.to(device=device, dtype=dtype) + + y = torch.fft.ifft2(x, n, dim, norm) + + return y + + +def _test_rfft2(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] + params = gen_params_2d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + dtype = test_case.dtype_dict["real"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.rfft2(x, n, dim, norm) + + return y + + +def _test_irfft2(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] + params = gen_params_2d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["complex"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.irfft2(x, n, dim, norm) + + return y + + +def _test_hfft2(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["2d"]["upper_n_dims"] + params = gen_params_2d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["complex"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.hfft2(x, n, dim, norm) + + return y + + +def _test_ihfft2(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["2d"]["lower_n_dims"] + upper_n_dims = 
test_case.ndims_dict["2d"]["upper_n_dims"] + params = gen_params_2d_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["real"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.ihfft2(x, n, dim, norm) + + return y + + +def _test_fftn(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] + params = gen_params_nd_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + x = random_tensor(num_dims, dtype=float, *shape) + if is_complex_dtype(x.dtype): + # test fft_c2c + dtype = test_case.dtype_dict["complex"] + x = x.to(device=device, dtype=dtype) + else: + # test fft_r2c + dtype = test_case.dtype_dict["real"] + x = x.to(device=device, dtype=dtype) + y = torch.fft.fftn(x, n, dim, norm) + + return y + + +def _test_ifftn(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] + params = gen_params_nd_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + x = random_tensor(num_dims, dtype=float, *shape) + if is_complex_dtype(x.dtype): + # test fft_c2c + dtype = test_case.dtype_dict["complex"] + x = x.to(device=device, dtype=dtype) + else: + # test fft_r2c + dtype = test_case.dtype_dict["real"] + x = x.to(device=device, dtype=dtype) + + y = torch.fft.ifftn(x, n, dim, norm) + + return y + + +def _test_rfftn(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] + params = gen_params_nd_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + + dtype = test_case.dtype_dict["real"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.rfftn(x, n, dim, norm) + + return y + + +def _test_irfftn(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] + params = gen_params_nd_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["complex"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.irfftn(x, n, dim, norm) + + return y + + +def _test_hfftn(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] + params = gen_params_nd_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = 
test_case.dtype_dict["complex"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.hfftn(x, n, dim, norm) + + return y + + +def _test_ihfftn(test_case): + if is_cufft_available(): + device = random_device() + else: + device = cpu_device() + + lower_n_dims = test_case.ndims_dict["nd"]["lower_n_dims"] + upper_n_dims = test_case.ndims_dict["nd"]["upper_n_dims"] + params = gen_params_nd_fft(lower_n_dims, upper_n_dims) + + num_dims = params["num_dims"] + shape = params["shape"] + n = params["n"] + dim = params["dim"] + norm = params["norm"] + dtype = test_case.dtype_dict["real"] + + x = random_tensor(num_dims, dtype=float, *shape).to(device=device, dtype=dtype) + y = torch.fft.ihfftn(x, n, dim, norm) + + return y + + +# NOTE: skip for multi-nodes and multi-devices now, because it failed in ci randomly +@flow.unittest.skip_unless_1n1d() +class TestComplex64Fft(flow.unittest.TestCase): + def setUp(test_case): + # should override by other data type of complex + test_case.ndims_dict = { + "1d": {"lower_n_dims": 1, "upper_n_dims": 5}, + "2d": {"lower_n_dims": 2, "upper_n_dims": 5}, + "nd": {"lower_n_dims": 1, "upper_n_dims": 5}, + } + + test_case.dtype_dict = {"real": torch.float32, "complex": torch.complex64} + + test_case.rtol = 1e-5 + test_case.atol = 1e-5 + test_case.initTestFft() + + def initTestFft(test_case): + test_case.test_fft = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_fft) + + test_case.test_ifft = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_ifft) + + test_case.test_rfft = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=False, + )(_test_rfft) + + test_case.test_irfft = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_irfft) + + test_case.test_hfft = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_hfft) + + test_case.test_ihfft = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=False, + )(_test_ihfft) + + test_case.test_fft2 = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_fft2) + + test_case.test_ifft2 = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_ifft2) + + test_case.test_rfft2 = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=False, + )(_test_rfft2) + + test_case.test_irfft2 = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol + * 100, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_irfft2) + + 
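+        # NOTE: each `_test_*` helper above builds a random dual (PyTorch/OneFlow)
+        # tensor and returns the transformed result; the `autotest` wrapper replays
+        # the computation on both frameworks and compares forward outputs and input
+        # gradients within (rtol, atol). See torch_flow_dual_object.py below.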
test_case.test_hfft2 = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol + * 100, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_hfft2) + + test_case.test_ihfft2 = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol, + check_graph=False, + check_grad_use_random_data=True, + include_complex=False, + )(_test_ihfft2) + + test_case.test_fftn = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol * 1e2, # NOTE: + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_fftn) + + test_case.test_ifftn = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol * 1e2, + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_ifftn) + + test_case.test_rfftn = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol * 1e2, + check_graph=False, + check_grad_use_random_data=True, + include_complex=False, + )(_test_rfftn) + + test_case.test_irfftn = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol + * 1e2, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_irfftn) + + test_case.test_hfftn = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol + * 1e2, # NOTE: ND-dimension of fft_c2r expands the numerical accuracy error + check_graph=False, + check_grad_use_random_data=True, + include_complex=True, + )(_test_hfftn) + + test_case.test_ihfftn = autotest( + n=5, + auto_backward=True, + rtol=test_case.rtol, + atol=test_case.atol * 1e2, + check_graph=False, + check_grad_use_random_data=True, + include_complex=False, + )(_test_ihfftn) + + def test_1d_fft(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + test_case.test_fft, + test_case.test_ifft, + test_case.test_rfft, + test_case.test_irfft, + test_case.test_hfft, + test_case.test_ihfft, + ] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + def test_2d_fft_except_hfft2(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + test_case.test_fft2, + test_case.test_ifft2, + test_case.test_rfft2, + test_case.test_irfft2, + ] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + @unittest.skipIf( + version.parse(torch_original.__version__) < version.parse("1.11.0"), + "module 'torch.fft' has no attribute 'hfft2' or 'ihfft2' before '1.11.0'", + ) + def test_2d_fft_hfft2(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [test_case.test_hfft2, test_case.test_ihfft2] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + def test_nd_fft_except_hfftn(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + test_case.test_fftn, + test_case.test_ifftn, + test_case.test_rfftn, + test_case.test_irfftn, + ] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + @unittest.skipIf( + version.parse(torch_original.__version__) < version.parse("1.11.0"), + "module 'torch.fft' has no attribute 'hfftn' or 'ihfftn' before '1.11.0'", + ) + def test_nd_fft_hfftn(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [test_case.test_hfftn, test_case.test_ihfftn] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + +# NOTE: skip for multi-nodes and 
multi-devices now, because it failed in ci randomly +@flow.unittest.skip_unless_1n1d() +class TestComplex128Fft(TestComplex64Fft): + def setUp(test_case): + # should override by other data type of complex + test_case.ndims_dict = { + "1d": {"lower_n_dims": 1, "upper_n_dims": 5}, + "2d": {"lower_n_dims": 2, "upper_n_dims": 5}, + "nd": {"lower_n_dims": 1, "upper_n_dims": 5}, + } + + test_case.dtype_dict = {"real": torch.float64, "complex": torch.complex128} + + test_case.rtol = 1e-7 + test_case.atol = 1e-7 + test_case.initTestFft() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test_utils/automated_test_util/generators.py b/python/oneflow/test_utils/automated_test_util/generators.py index 07160a22590..24c22159a6e 100644 --- a/python/oneflow/test_utils/automated_test_util/generators.py +++ b/python/oneflow/test_utils/automated_test_util/generators.py @@ -39,7 +39,7 @@ annotation2default_generator = {} annotation2torch_to_flow_converter = {} NoneType = type(None) -random_value_default_range = {int: (-10, 11), float: (-1, 1)} +random_value_default_range = {int: (-10, 11), float: (-1, 1), complex: (-10, 10)} def data_generator(annotation): @@ -374,6 +374,14 @@ def _calc_value(self): if pin_memory: res = res.pin_memory() return res + elif dtype == complex: + np_arr = rng.uniform(low=low, high=high, size=shape) + 1.0j * rng.uniform( + low=low, high=high, size=shape + ) + res = torch.tensor(np_arr, dtype=torch.complex64) + if pin_memory: + res = res.pin_memory() + return res else: raise NotImplementedError(f"Not implemented dtype {dtype} in random") diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 86f8ed83caa..78a5c50b229 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -55,6 +55,7 @@ testing = False testing_graph = False +testing_complex = False global_check_allclose = True global_atol = 1e-5 global_rtol = 1e-5 @@ -1137,7 +1138,11 @@ def check_tensor_equality( assert ( flow_tensor.grad is not None ), f"OneFlow tensor doesn't have grad while PyTorch tensor has one, PyTorch tensor is\n {torch_tensor}\n, OneFlow tensor is\n{flow_tensor} " - torch_grad = torch_tensor.grad.detach().cpu().numpy() + torch_grad = ( + torch_tensor.grad.detach().cpu().numpy() + if not torch_original.is_conj(torch_tensor.grad) + else torch_original.resolve_conj(torch_tensor.grad.detach()).cpu().numpy() + ) flow_grad = flow_tensor.grad.numpy() if not np.allclose( torch_grad, flow_grad, rtol=rtol, atol=atol, equal_nan=True, @@ -1150,7 +1155,11 @@ def check_tensor_equality( f"Grads are not equal. 
PyTorch grad: \n{torch_grad}\n, OneFlow grad: \n{flow_grad}" ) return False - torch_numpy = torch_tensor.detach().cpu().numpy() + torch_numpy = ( + torch_tensor.detach().cpu().numpy() + if not torch_original.is_conj(torch_tensor) + else torch_original.resolve_conj(torch_tensor.detach()).cpu().numpy() + ) oneflow_numpy = flow_tensor.numpy() equality_res = np.allclose( torch_numpy, oneflow_numpy, rtol=rtol, atol=atol, equal_nan=True, @@ -1219,6 +1228,7 @@ def autotest( check_allclose=True, check_dtype=False, check_grad_use_random_data=True, + include_complex=False, ): verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None @@ -1253,9 +1263,16 @@ def new_f(test_case, *args, **kwargs): testing = True if check_graph: testing_graph = True + + global testing_complex + if include_complex: + testing_complex = True + res = f(test_case, *args, **kwargs) + testing = False testing_graph = False + testing_complex = False except (PyTorchDoesNotSupportError, BothDoNotSupportError) as e: if verbose: print(f"{f.__name__}") @@ -1387,6 +1404,10 @@ def random_tensor( ): if isinstance(requires_grad, generator): requires_grad = requires_grad.value() + if dtype == float and testing_complex: + # Generate complex with the probability of 0.5 + dtype = complex if rng.integers(0, 2) == 1 else float + pytorch_tensor = ( random_pytorch_tensor( ndim, dim0, dim1, dim2, dim3, dim4, low, high, dtype, pin_memory