Add op hardshrink (#7887)
* add op hardshrink

* format code

* refine docs

* fix typo in unittest

* add inplace kernel registration

* add inplace kernel registration

* format code

* replace unsafe ->at with VectorAt

* format code

* fix wrong init of nn.Hardshrink module

* replace unsafe inputs->at with VectorAt

* add error message for CHECK macro

* add error message for CHECK macro

* fix bug in docstring

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
marigoold and mergify[bot] authored Apr 15, 2022
1 parent ae13b04 commit e82f520
Showing 15 changed files with 400 additions and 3 deletions.
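
For reference, hardshrink keeps elements whose magnitude exceeds `lambd` and zeroes the rest; that is the rule the kernels in this commit implement. A minimal NumPy sketch of the forward semantics (illustrative only, not part of the diff):

```python
import numpy as np

def hardshrink_ref(x: np.ndarray, lambd: float = 0.5) -> np.ndarray:
    # Elements with |x| <= lambd are shrunk to zero; all others pass through.
    return np.where(np.abs(x) > lambd, x, np.zeros_like(x))

print(hardshrink_ref(np.array([-1.0, -0.3, 0.0, 0.4, 2.0])))
# [-1.  0.  0.  0.  2.]
```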
1 change: 1 addition & 0 deletions docs/source/functional.rst
@@ -11,6 +11,7 @@ Functional operations for neural networks
.. autofunction:: adaptive_avg_pool3d
.. autofunction:: relu
.. autofunction:: hardsigmoid
.. autofunction:: hardshrink
.. autofunction:: hardswish
.. autofunction:: hardtanh
.. autofunction:: normalize
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -39,6 +39,7 @@ Operators for neural networks
GLU,
GroupNorm,
Hardsigmoid,
Hardshrink,
Hardswish,
Hardtanh,
Identity,
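
The `Hardshrink` entry added above exposes the op as a module. A hedged end-to-end sketch, assuming `flow.nn.Hardshrink` takes `lambd` as its constructor argument and follows the usual module/autograd interface (not part of the diff):

```python
import oneflow as flow

m = flow.nn.Hardshrink(lambd=0.5)
x = flow.tensor([-1.0, -0.3, 0.4, 2.0], requires_grad=True)
y = m(x)            # expected: tensor([-1., 0., 0., 2.])
y.sum().backward()
print(x.grad)       # expected: tensor([1., 0., 0., 1.]) -- zero where the output was shrunk
```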
43 changes: 43 additions & 0 deletions oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -124,6 +124,48 @@ class HardSigmoid : public BaseActivation {
}
};

struct HardShrinkCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
double lambd = 0.5;
};

class HardShrink : public OpExprGradFunction<HardShrinkCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "Forward op must be not null";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> Capture(HardShrinkCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1) << "Input grad size must be equal 1";
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->lambd = JUST(composed_attrs.GetAttr<double>("lambd"));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
return Maybe<void>::Ok();
}

Maybe<void> Apply(const HardShrinkCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Output grad size must be equal 1";
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
*JUST(oneflow::VectorAt(in_grads, 0)) =
JUST(functional::HardShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->lambd));
}
return Maybe<void>::Ok();
}

private:
AttrMap base_attrs_;
};

class HardSwish : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
@@ -508,6 +550,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("softsign", Softsign);
REGISTER_OP_EXPR_GRAD_FUNCTION("relu", ReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("gelu", GeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardsigmoid", HardSigmoid);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink", HardShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardswish", HardSwish);
REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu", LeakyRelu);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh", HardTanh);
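
The `HardShrink` gradient class above saves the forward output rather than the input: wherever `y` is zero the element was shrunk and the function is flat, so the upstream gradient is masked there. A NumPy sketch of the rule that `HardShrinkGrad` computes (illustrative only, not part of the diff):

```python
import numpy as np

def hardshrink_grad_ref(y: np.ndarray, dy: np.ndarray) -> np.ndarray:
    # dy passes through where the forward output is non-zero;
    # it is zeroed where the element was shrunk.
    return np.where(y == 0, np.zeros_like(dy), dy)
```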
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -542,6 +542,14 @@
signature: "Tensor (Tensor dy, Tensor x) => HardSigmoidGrad"
bind_python: False

- name: "hardshrink"
signature: "Tensor (Tensor x, *, Double lambd=0.5, Bool inplace=False) => HardShrink"
bind_python: True

- name: "hardshrink_grad"
signature: "Tensor (Tensor y, Tensor dy, Double lambd=0.5) => HardShrinkGrad"
bind_python: False

- name: "softmax"
signature: "Tensor (Tensor x, Int64 dim=None) => Softmax"
bind_python: True
47 changes: 47 additions & 0 deletions oneflow/core/functional/impl/activation_functor.cpp
@@ -268,6 +268,51 @@ class HardSigmoidGradFunctor : public BinaryFunctor {
CHECK_JUST(one::OpBuilder("hardsigmoid_grad").Input("dy").Input("x").Output("dx").Build());
}
};

class HardShrinkFunctor {
public:
HardShrinkFunctor() {
op_ = CHECK_JUST(one::OpBuilder("hardshrink").Input("in").Output("out").Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const double& lambd,
bool inplace) const {
MutableAttrMap attrs;
CHECK_GT_OR_RETURN(lambd, 0) << "lambd must be greater than 0";
JUST(attrs.SetAttr<double>("lambd", lambd));
if (inplace) {
JUST(CheckInplaceValid(x));
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
// outputs->at(0) = x;
*JUST(oneflow::VectorAt(outputs.get(), 0)) = x;
JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));
// return outputs->at(0);
return *JUST(oneflow::VectorAt(outputs.get(), 0));
} else {
return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);
}
}

private:
std::shared_ptr<OpExpr> op_;
};

class HardShrinkGradFunctor {
public:
HardShrinkGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("hardshrink_grad").Input("dy").Input("y").Output("dx").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& y, const std::shared_ptr<Tensor>& dy,
const double& lambd) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("lambd", lambd));
return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dy, y}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class SoftmaxFunctorBase {
public:
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
@@ -568,6 +613,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::GluFunctor>("Glu");
m.add_functor<impl::HardSigmoidFunctor>("HardSigmoid");
m.add_functor<impl::HardSigmoidGradFunctor>("HardSigmoidGrad");
m.add_functor<impl::HardShrinkFunctor>("HardShrink");
m.add_functor<impl::HardShrinkGradFunctor>("HardShrinkGrad");
m.add_functor<impl::SoftmaxFunctor>("Softmax");
m.add_functor<impl::SoftmaxGradFunctor>("SoftmaxGrad");
m.add_functor<impl::LogSoftmaxFunctor>("LogSoftmax");
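
The in-place branch of `HardShrinkFunctor` above validates the request with `CheckInplaceValid` and hands the input tensor back as the output slot, so the result is written into the same buffer. A hedged sketch of how this could look from Python, assuming the generated `hardshrink` wrapper forwards the `inplace` flag declared in the yaml signature (not confirmed by this diff):

```python
import oneflow as flow

x = flow.tensor([-1.0, -0.3, 0.4, 2.0])
y = flow.nn.functional.hardshrink(x, lambd=0.5, inplace=True)
# If inplace is honored, y aliases x's storage and x itself now holds
# tensor([-1., 0., 0., 2.])
```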
39 changes: 36 additions & 3 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -211,8 +211,8 @@ class OneFlow_NormalizationAddReluBaseOp : OneFlow_BaseOp<"normalization_add_rel
#endif // GET_ONEFLOW_BASE_OP_DEFINITIONS

// Group: BINARY
// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
// Total: 33
// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardshrink_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
// Total: 34

#ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS

@@ -438,6 +438,23 @@ def OneFlow_HardsigmoidGradOp : OneFlow_BaseOp<"hardsigmoid_grad", [NoSideEffect
let has_data_type_infer_fn = 1;
}

def OneFlow_HardShrinkGradOp : OneFlow_BaseOp<"hardshrink_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$y,
OneFlow_Tensor:$dy
);
let output = (outs
OneFlow_Tensor:$dx
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "0.">:$lambd
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_HardswishGradOp : OneFlow_BaseOp<"hardswish_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
@@ -8108,7 +8125,7 @@ def OneFlow_TanhGradOp : OneFlow_BaseOp<"tanh_grad", [NoSideEffect, DeclareOpInt
#endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS

// Group: UNARY
// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, flip_grad, fold, gelu, hardsigmoid, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, sigmoid, silu, softshrink, softsign, sort, square_sum, squeeze, threshold, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like, to_contiguous, isnan, isinf
// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, flip_grad, fold, gelu, hardsigmoid, hardshrink, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, sigmoid, silu, softshrink, softsign, sort, square_sum, squeeze, threshold, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like, to_contiguous, isnan, isinf
// Total: 51

#ifdef GET_ONEFLOW_UNARY_OP_DEFINITIONS
@@ -8451,6 +8468,22 @@ def OneFlow_HardsigmoidOp : OneFlow_BaseOp<"hardsigmoid", [NoSideEffect, Declare
let has_data_type_infer_fn = 1;
}

def OneFlow_HardShrinkOp : OneFlow_BaseOp<"hardshrink", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
);
let output = (outs
OneFlow_Tensor:$out
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "0.">:$lambd
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_HardswishOp : OneFlow_BaseOp<"hardswish", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
1 change: 1 addition & 0 deletions oneflow/user/kernels/activation_kernels.cpp
@@ -22,6 +22,7 @@ namespace oneflow {
REGISTER_CELU_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDSWISH_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDSIGMOID_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDSHRINK_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDTANH_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_MISH_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_SILU_KERNEL(DeviceType::kCPU, dtype); \
24 changes: 24 additions & 0 deletions oneflow/user/kernels/activation_kernels.cu
@@ -122,6 +122,29 @@ struct HardswishGradFunctor<half> {
}
};

template<>
struct HardShrinkFunctor<half> {
OF_DEVICE_FUNC explicit HardShrinkFunctor(float lambd)
: lambd(lambd), float_functor(HardShrinkFunctor<float>(lambd)) {}
OF_DEVICE_FUNC half operator()(half x) const {
return __float2half(float_functor(__half2float(x)));
}
const float lambd;
HardShrinkFunctor<float> float_functor;
};

template<>
struct HardShrinkGradFunctor<half> {
OF_DEVICE_FUNC explicit HardShrinkGradFunctor(float lambd)
: lambd(lambd), float_functor(HardShrinkGradFunctor<float>(lambd)) {}
OF_DEVICE_FUNC half operator()(half y, half dy) const {
return __float2half(float_functor(__half2float(y), __half2float(dy)));
}

const float lambd;
HardShrinkGradFunctor<float> float_functor;
};

template<>
struct MishFunctor<half> {
OF_DEVICE_FUNC explicit MishFunctor() : float_functor(MishFunctor<float>()) {}
@@ -261,6 +284,7 @@ struct SoftShrinkGradFunctor<half> {
REGISTER_CELU_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDSWISH_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDSIGMOID_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDSHRINK_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDTANH_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_MISH_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_SILU_KERNEL(DeviceType::kCUDA, dtype); \
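
The `half` specializations above avoid fp16 arithmetic by promoting to `float`, applying the float functor, and converting back. A NumPy sketch of the same promote/compute/demote pattern (illustrative only, not part of the diff):

```python
import numpy as np

def hardshrink_half(x: np.ndarray, lambd: float = 0.5) -> np.ndarray:
    # Promote float16 input to float32, apply the float rule, cast back.
    xf = x.astype(np.float32)
    return np.where(np.abs(xf) > lambd, xf, np.zeros_like(xf)).astype(np.float16)
```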
54 changes: 54 additions & 0 deletions oneflow/user/kernels/activation_kernels.h
@@ -118,6 +118,26 @@ struct HardsigmoidGradFunctor {
}
};

template<typename T>
struct HardShrinkFunctor {
OF_DEVICE_FUNC explicit HardShrinkFunctor(double lambd) : lambd(lambd) {}
OF_DEVICE_FUNC T operator()(T x) const {
return (x <= lambd && x >= -lambd) ? static_cast<T>(0) : x;
}

const T lambd;
};

template<typename T>
struct HardShrinkGradFunctor {
OF_DEVICE_FUNC explicit HardShrinkGradFunctor(double lambd) : lambd(lambd) {}
OF_DEVICE_FUNC T operator()(T y, T dy) const {
return y == static_cast<T>(0) ? static_cast<T>(0) : dy;
}

const T lambd;
};

template<typename T>
struct HardtanhFunctor {
OF_DEVICE_FUNC explicit HardtanhFunctor(float min_val, float max_val)
@@ -370,6 +390,40 @@ struct SoftShrinkGradFunctor {
[](user_op::KernelComputeContext* ctx) { return HardsigmoidGradFunctor<dtype>(); }, "dx", \
"x", "dy");

#define REGISTER_HARDSHRINK_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("hardshrink") \
.SetCreateFn([]() { \
return user_op::NewOpKernel< \
UnaryElemwiseXpuKernel<device, HardShrinkFunctor<dtype>, dtype, dtype>>( \
[](user_op::KernelComputeContext* ctx) { \
return HardShrinkFunctor<dtype>(ctx->Attr<double>("lambd")); \
}, \
"out", "in"); \
}) \
.SetIsMatchedHob((user_op::HobDeviceType() == device) \
&& (user_op::HobDataType("in", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn([](const user_op::InferContext&, \
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); \
return Maybe<void>::Ok(); \
}); \
REGISTER_USER_KERNEL("hardshrink_grad") \
.SetCreateFn([]() { \
return user_op::NewOpKernel< \
BinaryElemwiseXpuKernel<device, HardShrinkGradFunctor<dtype>, dtype, dtype, dtype>>( \
[](user_op::KernelComputeContext* ctx) { \
return HardShrinkGradFunctor<dtype>(ctx->Attr<double>("lambd")); \
}, \
"dx", "y", "dy"); \
}) \
.SetIsMatchedHob((user_op::HobDeviceType() == device) \
&& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn([](const user_op::InferContext&, \
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); \
return Maybe<void>::Ok(); \
});

#define REGISTER_HARDTANH_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("hardtanh") \
.SetCreateFn([]() { \