Add FFT operators #10027

Merged: 190 commits, May 17, 2023
Changes from all commits
0281680
init commit for add cfloat and cdouble dtype
Mar 7, 2023
9b915e7
code polish
Mar 7, 2023
faf2076
Merge branch 'master' into lml/complex
levi131 Mar 7, 2023
ca41336
Merge remote-tracking branch 'upstream/master' into lml/complex
Mar 8, 2023
c602d02
save work
Mar 9, 2023
661df1e
Merge remote-tracking branch 'upstream/master' into lml/complex
Mar 9, 2023
9d5b054
save work
Mar 9, 2023
9466b8e
merge upstream master
Mar 9, 2023
3896bd1
update work
Mar 10, 2023
797d483
add oneflow.complex64 and oneflow.complex128
Mar 10, 2023
e25f0a4
fix bug for complex128
Mar 10, 2023
d4295de
fix bug for get item for scalar complex tensor
Mar 11, 2023
34cb7ff
update format
Mar 11, 2023
d3287a8
update work state
Mar 12, 2023
6b6e884
update format
Mar 12, 2023
f66c2d4
rm some useless code
Mar 13, 2023
07df8de
Merge remote-tracking branch 'upstream/master' into lml/complex
Mar 13, 2023
ffbca31
fix format
Mar 13, 2023
46f7d3c
save work state
Mar 14, 2023
e17b942
add complex64 and complex128 for cpu primitive
Mar 14, 2023
ea74be6
refine format
Mar 14, 2023
cc59ba2
skip test tensor cuda on CPU only CI
Mar 14, 2023
d22f576
add import os
Mar 14, 2023
fbb4c3e
Add c2c, r2c, c2r Op.
MarioLulab Mar 14, 2023
2c6af0f
rm default value for ComplexDoubleAttr
Mar 14, 2023
95c418c
Merge remote-tracking branch 'upstream/master' into lml/complex_tenso…
Mar 15, 2023
02d0656
rm unused construct function for calss scalar
Mar 15, 2023
795ae7d
refine class Scalar and add transform for ComplexDoubleAttr
Mar 15, 2023
b00334d
refine format
Mar 15, 2023
fac2e1f
add set active_tag
Mar 15, 2023
2c0f6eb
use DataType_ARRAYSIZE macro and oneflow::Hash
Mar 15, 2023
417bd6d
fix bug in Importer.cpp
Mar 15, 2023
44fff90
fix for ci
Mar 15, 2023
7364d18
fix for ci
Mar 15, 2023
79f2708
fix bug, the second set real -> set imag
Mar 16, 2023
beb2df2
modify place of some code and remove clear just before set
Mar 16, 2023
0cff5c6
modify attr of fft op
MarioLulab Mar 16, 2023
3204ab0
add c2c, r2c, fft, ifft functor
MarioLulab Mar 16, 2023
74a825b
fix complie error
Mar 16, 2023
c714a6e
add c2c cpu kernels, to-do register.
MarioLulab Mar 16, 2023
ce10783
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
MarioLulab Mar 17, 2023
1d62f49
register fft_c2c keernel
MarioLulab Mar 17, 2023
5e11946
Merge remote-tracking branch 'remotes/lml_complex_tensor_and_complex_…
MarioLulab Mar 17, 2023
1b88afa
modify fft kernels.
MarioLulab Mar 17, 2023
870b515
Merge branch 'Oneflow-Inc:master' into luqi/dev_fft_based_complex
MarioLulab Mar 17, 2023
f300b6a
tmodify calling logic of pocketfftplan, in order to avoid compile er…
MarioLulab Mar 17, 2023
63a26b5
add conj_symmetric() and r2c kernel.
MarioLulab Mar 19, 2023
e8bb7d5
add backward interface.
MarioLulab Mar 19, 2023
2a73308
disable r2c functor and add helper function
MarioLulab Mar 20, 2023
077ac6a
add c2c baackward
MarioLulab Mar 20, 2023
62e1b4c
fix compile error
MarioLulab Mar 20, 2023
6ec4c62
foormat code using [files] 2754
MarioLulab Mar 20, 2023
fe69b40
Merge branch 'Oneflow-Inc:master' into luqi/dev_fft_based_complex
MarioLulab Mar 20, 2023
e81886d
modify include head of math_functor
MarioLulab Mar 20, 2023
65c32e2
fix compile error.
MarioLulab Mar 21, 2023
85328c2
fix undefined symbol error
MarioLulab Mar 21, 2023
592be25
explicitly instantiate
MarioLulab Mar 21, 2023
3d38db2
Merge branch 'Oneflow-Inc:master' into luqi/dev_fft_based_complex
MarioLulab Mar 21, 2023
bddc999
fix compile error.
MarioLulab Mar 23, 2023
7c0c413
decrease template parameters of fft_kernels and fft_kernel_util
MarioLulab Mar 23, 2023
d8d89b1
for debuug
MarioLulab Mar 23, 2023
018358c
add testcase
MarioLulab Mar 24, 2023
50d259b
fix runtime error for SupportContiguousTensor
MarioLulab Mar 24, 2023
396939c
save work status
Mar 24, 2023
afca267
rm conj op
Mar 24, 2023
5c140d8
remove imag_op.cpp
Mar 24, 2023
af420cd
support autograd
Mar 24, 2023
62d767d
register complex data type for reduce binary functors
MarioLulab Mar 24, 2023
8f99f46
success for fft_c2c forward but failed in backward.
MarioLulab Mar 24, 2023
7b6908b
add complex type seq into CPU_PRIMITIVE_ALL_TYPE_SEQ
MarioLulab Mar 27, 2023
6b5d006
pass compile
Mar 27, 2023
2310075
refine format
Mar 27, 2023
9508011
fix spell bug: IsFloating16 -> IsComplex
Mar 27, 2023
557d118
save status
Mar 27, 2023
8121602
Finish fft_c2c, Support fft, ifft, fftn, ifftn Now.
MarioLulab Mar 28, 2023
6c005c0
update test file.
MarioLulab Mar 28, 2023
c813a7f
add option is_grad_fn to be used in backward pass
MarioLulab Mar 28, 2023
61cd37b
add complex data type for binary operation
MarioLulab Mar 28, 2023
074084d
fix bug of data type promoting priority
MarioLulab Mar 28, 2023
75d20ea
of_format
MarioLulab Mar 28, 2023
2b7c4f8
support add for complex on cpu
Mar 28, 2023
7b06095
finish work except test
Mar 28, 2023
0c3a016
of_format
Mar 28, 2023
64c1085
support complex for fill cuda
Mar 29, 2023
5911f9f
of_format
Mar 29, 2023
999769a
fix test
Mar 29, 2023
d1c48d8
Merge branch 'master' into lml/conj
levi131 Mar 29, 2023
61ffcc3
fix for ci
Mar 29, 2023
dc19e4e
prepare for merge
MarioLulab Mar 30, 2023
a45880d
Merge remote-tracking branch 'lml_complex_tensor_and_complex_attr/lml…
MarioLulab Mar 30, 2023
cf4cbd5
add fft_r2c forward and backward pass, but demanding test.
MarioLulab Mar 30, 2023
a7a7c00
rfft test pass
MarioLulab Mar 30, 2023
183b778
fix bug
Mar 30, 2023
3a0c984
merge master
Mar 30, 2023
2572353
fix bug
Mar 30, 2023
a890bce
readd IsFloat16 Trait
Mar 30, 2023
caa7f62
fix format
Mar 30, 2023
6b8b0d3
modify some docstring
Mar 31, 2023
4b9c36b
Merge branch 'master' into lml/cuComplex_checkdatatype
levi131 Mar 31, 2023
7299dea
add fft_c2r, but not testing yet
MarioLulab Mar 31, 2023
cd171dc
Merge branch 'master' into lml/cuComplex_checkdatatype
levi131 Mar 31, 2023
f442e1b
add fft_c2r grad, but not test yet. find bug in rfft when shape[dim] …
MarioLulab Mar 31, 2023
23de2a4
fix fft_r2c backward segment fault
MarioLulab Mar 31, 2023
cc56e3f
Merge remote-tracking branch 'upstream/lml/cuComplex_checkdatatype' i…
Apr 2, 2023
754bff2
merge master
Apr 2, 2023
9b3506e
support fftn, rfftn, hfftn
MarioLulab Apr 3, 2023
ae55408
use Real functor in casting complex tensor to real tensor. Finish fft…
MarioLulab Apr 3, 2023
ff1d095
finish rfftn
MarioLulab Apr 3, 2023
98c103d
finish irfftn
MarioLulab Apr 3, 2023
ec64c86
finish hfftn
MarioLulab Apr 3, 2023
e402163
finish ihfftn
MarioLulab Apr 3, 2023
da84dcd
code polish
MarioLulab Apr 3, 2023
7a5890d
code polish and modify test files
MarioLulab Apr 3, 2023
50776d0
remote debug file
MarioLulab Apr 4, 2023
ad87f04
of_format
MarioLulab Apr 4, 2023
f84ef3d
restore stft and test pass.
MarioLulab Apr 4, 2023
acea5f7
remove optional qualifier of *Fft2Functor, and use TensorProcessor to…
MarioLulab Apr 4, 2023
dae45c0
modify stft exception to pass test_stft_op.py
MarioLulab Apr 4, 2023
e44d796
add python function interface for FFT
MarioLulab Apr 4, 2023
c4b8991
remove debug code and redundant include headers
MarioLulab Apr 4, 2023
190b0c9
modify cast op backward
MarioLulab Apr 4, 2023
f43794a
restore optional qualifiers of s, remove optional qualifiers of dim
MarioLulab Apr 4, 2023
8990c75
remove std::complex<T> from fft_kernels
MarioLulab Apr 6, 2023
daa17dc
make code clean
MarioLulab Apr 6, 2023
a3e4d4b
support autotest for complex tensor testing
MarioLulab Apr 6, 2023
97513c2
of_format
MarioLulab Apr 6, 2023
280b854
merge master
Apr 6, 2023
0b56f20
fix add.cpp, ALL_DATATYPE_SEQ contains COMPLEX_DATATYPE_SEQ
Apr 6, 2023
0f1214c
enable cuda version
MarioLulab Apr 6, 2023
7e7981b
Merge branch 'master' into luqi/dev_fft_based_complex
MarioLulab Apr 6, 2023
b0f94a6
Merge remote-tracking branch 'lml_complex_tensor_and_complex_attr/lml…
MarioLulab Apr 6, 2023
1e24826
fix complie error, and add cufft utils.
MarioLulab Apr 7, 2023
f033ce2
add data layout of cufft
MarioLulab Apr 7, 2023
9d4a685
Merge remote-tracking branch 'upstream/master' into luqi/dev_fft_base…
MarioLulab Apr 7, 2023
dd90a9c
refactor cufft_plan_cache
MarioLulab Apr 7, 2023
1a72500
add infer tmp_buffer fn
MarioLulab Apr 10, 2023
4a6da71
add op in ccuda
MarioLulab Apr 14, 2023
976329d
add binary add,sub,mul, add unary cast, constant pad. Fix Complile Er…
MarioLulab Apr 17, 2023
e5dc825
operator overload compat cub::DeviceReduce::Reduce
MarioLulab Apr 17, 2023
5835871
prepare for conflict solving
MarioLulab Apr 17, 2023
e1a9933
Merge remote-tracking branch 'remotes/origin/master' into luqi/dev_ff…
MarioLulab Apr 17, 2023
f9d91af
merged
MarioLulab Apr 17, 2023
3fc4a3e
conflict solving
MarioLulab Apr 17, 2023
84e83a0
add equal and not_equal
MarioLulab Apr 17, 2023
62cb70f
of_format
MarioLulab Apr 17, 2023
f861729
Merge remote-tracking branch 'remotes/origin/luqi/add_complex_cuda_op…
MarioLulab Apr 17, 2023
7671641
support complex for autotest
MarioLulab Apr 20, 2023
34e450c
cuda fft_c2c done.
MarioLulab Apr 21, 2023
3198a5d
Merge remote-tracking branch 'remotes/origin/master' into luqi/dev_ff…
MarioLulab Apr 21, 2023
1dcb4aa
add cuda fft_r2c
MarioLulab Apr 23, 2023
d4d287d
finish cuda r2c op, but not test yet
MarioLulab Apr 23, 2023
0944878
replace `norm_str` with `norm_mode` simplify logic
MarioLulab Apr 23, 2023
23f15c7
add FFTC2CWrapper to decouple logic
MarioLulab Apr 23, 2023
c9f5c3d
test pass fft_c2c and fft_r2c
MarioLulab Apr 25, 2023
336684a
modify math_functor and fft gradient func
MarioLulab Apr 27, 2023
2b3b170
test pass fft_c2c, fft_r2c, fft_c2r on cuda
MarioLulab Apr 27, 2023
0338013
fix scale mul
MarioLulab Apr 27, 2023
e227d92
delete redundant debug info
MarioLulab Apr 27, 2023
848fe8d
remove debug info
MarioLulab Apr 27, 2023
4dfca18
remove redundant files
MarioLulab Apr 28, 2023
39e0a4b
compat stft inot new fft module
MarioLulab May 5, 2023
4ca63a1
limit index helper max_ndim and code polish
MarioLulab May 5, 2023
197f829
remove default attr value of fft ops
MarioLulab May 8, 2023
d549427
use macro of OF_CUFFT_CHECK
MarioLulab May 8, 2023
1df71f3
delete redundant code
MarioLulab May 8, 2023
4b4e5f2
refactor Functor of FFT
MarioLulab May 8, 2023
bd9fbb1
remove debug info of op call
MarioLulab May 8, 2023
fa80d9b
code polish of fft kernel and ops
MarioLulab May 8, 2023
6ae39ff
add doc of fft module
MarioLulab May 8, 2023
328fa6d
Delete code of duplicate throws exception
MarioLulab May 10, 2023
5d5d34c
Modify autotest to support the “include_complex”
MarioLulab May 10, 2023
2530677
delete blankspace
MarioLulab May 10, 2023
164e266
of_format
MarioLulab May 10, 2023
c1fdb12
update test_fft.py
MarioLulab May 10, 2023
3861eb1
Merge branch 'master' into luqi/dev_fft_based_complex
levi131 May 10, 2023
2fe88c0
fix for ci
MarioLulab May 10, 2023
52bf0b5
of_format
MarioLulab May 10, 2023
afb9a0a
of_format
MarioLulab May 10, 2023
0b28829
refator test_fft.py and per-tensor gen complex
MarioLulab May 11, 2023
3905e57
fix for ci
MarioLulab May 11, 2023
aba72fe
fix for ci
MarioLulab May 11, 2023
c18f9ed
skip multi node test for ci
MarioLulab May 15, 2023
35c9f5a
Merge branch 'master' into luqi/dev_fft_based_complex
MarioLulab May 15, 2023
83c008e
of_format
MarioLulab May 15, 2023
855c123
Merge remote-tracking branch 'origin/master' into luqi/dev_fft_based_…
MarioLulab May 16, 2023
b0c27a5
remove redudant ewise binary op
MarioLulab May 16, 2023
f02d98d
Merge branch 'master' into luqi/dev_fft_based_complex
MarioLulab May 16, 2023
1fef9f7
revert broadcast_elementwise_binary.cpp
MarioLulab May 16, 2023
35feec6
Merge branch 'master' into luqi/dev_fft_based_complex
mergify[bot] May 17, 2023
d368edc
Update oneflow/user/kernels/fft_kernels.cpp
MarioLulab May 17, 2023
201 changes: 201 additions & 0 deletions oneflow/core/autograd/gradient_funcs/fft.cpp
@@ -0,0 +1,201 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <string>
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"

namespace oneflow {
namespace one {

struct FftR2CCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool onesided = false;
std::vector<int64_t> dims;
DimVector input_shape_vec;
int32_t norm_mode = 0;
};

class FftR2C : public OpExprGradFunction<FftR2CCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(FftR2CCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`";
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ctx->onesided = JUST(attrs.GetAttr<bool>("onesided"));
ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>("dims"));
ctx->norm_mode = JUST(attrs.GetAttr<int32_t>("norm_mode"));
ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec();

return Maybe<void>::Ok();
}

Maybe<void> Apply(const FftR2CCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: assert `out_grads.size() == 1`";
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

in_grads->resize(1);
if (!ctx->onesided) {
auto complex_grad = JUST(functional::FftC2C(JUST(oneflow::VectorAt(out_grads, 0)), NullOpt,
ctx->dims, ctx->norm_mode,
/*forward=*/false, /*normalized=*/false));
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_grad));
} else {
std::vector<int64_t> fft_dims = ctx->dims;
std::vector<int64_t> fft_shapes(fft_dims.size(), 0);
FOR_RANGE(size_t, i, 0, fft_dims.size()) {
fft_shapes[i] = ctx->input_shape_vec[fft_dims[i]];
}

// fill the last dim
bool must_copy = false;
auto x_sizes = JUST(oneflow::VectorAt(out_grads, 0))->shape()->dim_vec();
std::vector<int64_t> pad_amount(x_sizes.size() * 2, 0);
int64_t last_dim = ctx->dims.back();
if (x_sizes[last_dim] < ctx->input_shape_vec[last_dim]) {
must_copy = true;
auto pad_idx = pad_amount.size() - 2 * last_dim - 1;
pad_amount[pad_idx] = ctx->input_shape_vec[last_dim] - x_sizes[last_dim];
}
auto complex_full_grad =
must_copy
? JUST(functional::ConstantPad(JUST(oneflow::VectorAt(out_grads, 0)), pad_amount, 0))
: JUST(oneflow::VectorAt(out_grads, 0));
complex_full_grad =
JUST(functional::FftC2C(complex_full_grad, NullOpt, ctx->dims, ctx->norm_mode,
/*forward=*/false, /*normalized=*/false));

JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::Real(complex_full_grad));
}

return Maybe<void>::Ok();
}
};

struct FftC2CCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool forward = false;
std::vector<int64_t> dims;
int32_t norm_mode = 0;
};

class FftC2C : public OpExprGradFunction<FftC2CCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(FftC2CCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`";

ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ctx->forward = JUST(attrs.GetAttr<bool>("forward"));
ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>("dims"));
ctx->norm_mode = JUST(attrs.GetAttr<int32_t>("norm_mode"));

return Maybe<void>::Ok();
}

Maybe<void> Apply(const FftC2CCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: assert `out_grads.size() == 1`";
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::FftC2C(
JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode,
/*forward=*/!(ctx->forward), /*normalized=*/false));
return Maybe<void>::Ok();
}
};

struct FftC2RCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
std::vector<int64_t> dims;
int32_t norm_mode = 0;
int64_t last_dim_size = 1;
DimVector input_shape_vec;
};

class FftC2R : public OpExprGradFunction<FftC2RCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(FftC2RCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1) << "RuntimeError: assert `inputs.size() == 1`";
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ctx->dims = JUST(attrs.GetAttr<std::vector<int64_t>>("dims"));
ctx->norm_mode = JUST(attrs.GetAttr<int32_t>("norm_mode"));
ctx->last_dim_size = JUST(attrs.GetAttr<int64_t>("last_dim_size"));
ctx->input_shape_vec = JUST(oneflow::VectorAt(inputs, 0))->shape()->dim_vec();

return Maybe<void>::Ok();
}

Maybe<void> Apply(const FftC2RCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "RuntimeError: out_grads.size() == 1";
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

in_grads->resize(1);

// NOTE: set `forward` to true to prevent conjugating the result
auto complex_grad = JUST(functional::FftR2C(
JUST(oneflow::VectorAt(out_grads, 0)), NullOpt, ctx->dims, ctx->norm_mode,
/*onesided=*/true, /*forward=*/true, /*normalized=*/false)); // no need conj
Shape input_shape(ctx->input_shape_vec);
int64_t last_dim = ctx->dims.back();
auto double_length =
JUST(oneflow::VectorAt(out_grads, 0))->dim(last_dim) - complex_grad->dim(last_dim);
auto in_grad = complex_grad;

// Mul by 2, and slice
if (double_length > 0) {
in_grad = JUST(functional::Narrow(complex_grad, last_dim, 1,
double_length)); // will change shape of in_grad
in_grad = JUST(functional::ScalarMul(in_grad, 2, /*inplace=*/true));
}

std::vector<int64_t> slice_st(input_shape.size(), 0);
std::vector<int64_t> slice_end(input_shape.begin(), input_shape.end());
std::vector<int64_t> slice_step(input_shape.size(), 1);
auto sliced_tensor =
JUST(functional::Slice(complex_grad, slice_st, slice_end, slice_step, false));

JUST(oneflow::VectorAt(*in_grads, 0)) = sliced_tensor;
return Maybe<void>::Ok();
}
};

REGISTER_OP_EXPR_GRAD_FUNCTION("fft_r2c", FftR2C);
REGISTER_OP_EXPR_GRAD_FUNCTION("fft_c2c", FftC2C);
REGISTER_OP_EXPR_GRAD_FUNCTION("fft_c2r", FftC2R);

} // namespace one

} // namespace oneflow
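
A minimal NumPy sketch of the adjoint identities these gradient functions implement (np.fft conventions assumed; only the unnormalized case, `normalized=false`, is checked): the backward of `fft_c2c` is the opposite-direction, unnormalized transform, and the backward of a onesided `fft_r2c` zero-pads the cotangent back to the full spectrum, runs the same inverse-direction transform, and takes the real part.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8

# fft_c2c backward: opposite-direction, unnormalized transform.
# Adjoint identity under the real inner product: Re<fft(x), g> == Re<x, vjp>
# with vjp = n * ifft(g) (the unnormalized inverse-direction DFT).
x = rng.standard_normal(n) + 1j * rng.standard_normal(n)
g = rng.standard_normal(n) + 1j * rng.standard_normal(n)
lhs = np.vdot(np.fft.fft(x), g).real   # vdot conjugates its first argument
vjp = n * np.fft.ifft(g)               # matches FftC2C::Apply: forward flipped, no scaling
assert np.isclose(lhs, np.vdot(vjp, x).real)

# fft_r2c (onesided) backward: zero-pad the cotangent to the full length
# (the ConstantPad step above), run the c2c backward, take the real part.
xr = rng.standard_normal(n)
go = rng.standard_normal(n // 2 + 1) + 1j * rng.standard_normal(n // 2 + 1)
lhs = np.vdot(np.fft.rfft(xr), go).real
g_full = np.zeros(n, dtype=complex)
g_full[: n // 2 + 1] = go              # "fill the last dim", as in FftR2C::Apply
vjp_r = (n * np.fft.ifft(g_full)).real # functional::Real at the end
assert np.isclose(lhs, np.dot(xr, vjp_r))
```

`FftC2R::Apply` applies the mirrored rule: a onesided `fft_r2c` of the cotangent, with the bins whose conjugate partners were folded away doubled in place via `Narrow` plus `ScalarMul`.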
22 changes: 22 additions & 0 deletions oneflow/core/device/cuda_util.cpp
@@ -74,6 +74,28 @@ const char* CurandGetErrorString(curandStatus_t error) {
}
}

const char* CuFFTGetErrorString(cufftResult_t error) {
switch (error) {
case CUFFT_SUCCESS: return "CUFFT_SUCCESS";
case CUFFT_INVALID_PLAN: return "CUFFT_INVALID_PLAN";
case CUFFT_ALLOC_FAILED: return "CUFFT_ALLOC_FAILED";
case CUFFT_INVALID_TYPE: return "CUFFT_INVALID_TYPE";
case CUFFT_INVALID_VALUE: return "CUFFT_INVALID_VALUE";
case CUFFT_INTERNAL_ERROR: return "CUFFT_INTERNAL_ERROR";
case CUFFT_EXEC_FAILED: return "CUFFT_EXEC_FAILED";
case CUFFT_SETUP_FAILED: return "CUFFT_SETUP_FAILED";
case CUFFT_INVALID_SIZE: return "CUFFT_INVALID_SIZE";
case CUFFT_UNALIGNED_DATA: return "CUFFT_UNALIGNED_DATA";
case CUFFT_INCOMPLETE_PARAMETER_LIST: return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_INVALID_DEVICE: return "CUFFT_INVALID_DEVICE";
case CUFFT_PARSE_ERROR: return "CUFFT_PARSE_ERROR";
case CUFFT_NO_WORKSPACE: return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED: return "CUFFT_NOT_IMPLEMENTED";
case CUFFT_NOT_SUPPORTED: return "CUFFT_NOT_SUPPORTED";
default: return "Unknown cufft status";
}
}

#if CUDA_VERSION >= 11000
const char* CusovlerGetErrorString(cusolverStatus_t error) {
switch (error) {
9 changes: 9 additions & 0 deletions oneflow/core/device/cuda_util.h
@@ -31,6 +31,7 @@ limitations under the License.
#include <cuda_runtime.h>
#include <cudnn.h>
#include <curand.h>
#include <cufft.h>
#include <nccl.h>
#include <cuda_fp16.h>
#if CUDA_VERSION >= 11000
@@ -51,6 +52,8 @@ const char* CublasGetErrorString(cublasStatus_t error);

const char* CurandGetErrorString(curandStatus_t error);

const char* CuFFTGetErrorString(cufftResult_t error);

#if CUDA_VERSION >= 11000
const char* CusovlerGetErrorString(cusolverStatus_t error);
#endif
@@ -78,6 +81,12 @@ const char* NvjpegGetErrorString(nvjpegStatus_t error);
LOG(FATAL) << "Check failed: " #condition " : " << CublasGetErrorString(_of_cublas_check_status) \
<< " (" << _of_cublas_check_status << ") "

#define OF_CUFFT_CHECK(condition) \
for (cufftResult_t _of_cufft_check_status = (condition); \
_of_cufft_check_status != CUFFT_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << CuFFTGetErrorString(_of_cufft_check_status) \
<< " (" << _of_cufft_check_status << ") "

#if CUDA_VERSION >= 11000
#define OF_CUSOLVER_CHECK(condition) \
for (cusolverStatus_t _of_cusolver_check_status = (condition); \
105 changes: 105 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -3275,6 +3275,111 @@
'Tensor (Tensor input, Int64 n_fft, Int64 hop_length=None, Int64 win_length=None, Tensor window=None, Bool center=True, String pad_mode="reflect", Bool normalized=False, Bool onesided=True, Bool return_complex=False) => Stft'
bind_python: True

- name: "fft_c2c"
signature:
'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) => FftC2C'
bind_python: False

- name: "fft_r2c"
signature:
'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool onesided=False, Bool forward=True, Bool normalized=False) => FftR2C'
bind_python: False

- name: "fft_c2r"
signature:
'Tensor (Tensor input, Int64List n=None, Int64List dims=None, Int32 norm_mode=0, Bool forward=True, Bool normalized=False) => FftC2R'
bind_python: False

- name: "fft"
signature:
'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => Fft'
bind_python: True

- name: "ifft"
signature:
'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IFft'
bind_python: True

- name: "fft2"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => Fft2'
bind_python: True

- name: "ifft2"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IFft2'
bind_python: True

- name: "fftn"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => FftN'
bind_python: True

- name: "ifftn"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IFftN'
bind_python: True

- name: "rfft"
signature:
'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => RFft'
bind_python: True

- name: "irfft"
signature:
'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IRFft'
bind_python: True

- name: "rfft2"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => RFft2'
bind_python: True

- name: "irfft2"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IRFft2'
bind_python: True

- name: "rfftn"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => RFftN'
bind_python: True

- name: "irfftn"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IRFftN'
bind_python: True

- name: "hfft"
signature:
'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => HFft'
bind_python: True

- name: "ihfft"
signature:
'Tensor (Tensor input, Int64 n=-1, Int64 dim=-1, String norm=None) => IHFft'
bind_python: True

- name: "hfft2"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => HFft2'
bind_python: True

- name: "ihfft2"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim, String norm=None) => IHFft2'
bind_python: True

- name: "hfftn"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => HFftN'
bind_python: True

- name: "ihfftn"
signature:
'Tensor (Tensor input, Int64List s=None, Int64List dim=None, String norm=None) => IHFftN'
bind_python: True

- name: "isclose"
signature: "Tensor (Tensor input, Tensor other, Float atol=1e-08, Float rtol=1e-05, Bool equal_nan=False) => IsClose"
bind_python: True
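
The `fft`/`ifft`/`rfft`/`hfft` family above binds Python-facing functors whose signatures mirror `torch.fft`, dispatching to the three underlying ops (`fft_c2c`, `fft_r2c`, `fft_c2r`). A minimal usage sketch, assuming the bindings are exposed as an `oneflow.fft` module and accept torch-style norm strings (both are assumptions based on the signatures and the PR's `test_fft.py`, not confirmed by the diff):

```python
import oneflow as flow

x = flow.randn(4, 8)  # real input

# 1-D real-to-complex transform: onesided output of length n//2 + 1 = 5
y = flow.fft.rfft(x, dim=-1, norm="backward")  # complex64 output assumed

# inverse transform: pass n explicitly so an even last dim is recovered exactly
x_rec = flow.fft.irfft(y, n=8, dim=-1)

# complex-to-complex pair over the last dim
c = flow.fft.fft(y, dim=-1)
y_rec = flow.fft.ifft(c, dim=-1)
```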