Add op hardshrink #7887

Merged
55 commits merged on Apr 15, 2022
Commits (55), showing changes from all commits
a24aac5
add op hardshrink
marigoold Mar 24, 2022
3522363
format code
marigoold Mar 24, 2022
0041ba2
refine docs
marigoold Mar 24, 2022
a14580e
fix typo in unittest
marigoold Mar 24, 2022
fcd6cfc
Merge branch 'master' into add_op_hardshrink
marigoold Mar 24, 2022
f661302
Merge branch 'master' into add_op_hardshrink
marigoold Mar 24, 2022
a0f7155
add inplace kernel registration
marigoold Mar 27, 2022
478ca1e
merge master
marigoold Mar 28, 2022
ccb7ec6
add inplace kernel registration
marigoold Mar 28, 2022
5145630
format code
marigoold Mar 28, 2022
f635cde
Merge branch 'master' into add_op_hardshrink
marigoold Mar 30, 2022
21b16e2
Merge branch 'master' into add_op_hardshrink
marigoold Mar 31, 2022
b7a3ba6
Merge branch 'master' into add_op_hardshrink
marigoold Mar 31, 2022
1d609f8
replace unsave ->at with VectorAt
marigoold Mar 31, 2022
a0883ce
format code
marigoold Mar 31, 2022
f383ddd
Merge branch 'master' into add_op_hardshrink
marigoold Mar 31, 2022
ef10cef
Merge branch 'master' into add_op_hardshrink
marigoold Apr 1, 2022
6fa85f3
fix bug of wrong init of nn.Module.Hardshrink
marigoold Apr 1, 2022
3cf5403
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 1, 2022
86b9274
replace unsave inputs->at with VectorAt
marigoold Apr 2, 2022
cf35d8f
Merge branch 'master' into add_op_hardshrink
marigoold Apr 2, 2022
8e71e4e
add error message for CHECK macro
marigoold Apr 2, 2022
79a16e5
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 2, 2022
c3084bd
add error message for CHECK macro
marigoold Apr 2, 2022
2cffc4e
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 2, 2022
46fb06a
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 2, 2022
64b4b7a
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 2, 2022
b38dcde
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 2, 2022
d764a2e
Merge branch 'master' into add_op_hardshrink
marigoold Apr 3, 2022
f266956
Merge branch 'master' into add_op_hardshrink
marigoold Apr 5, 2022
9af4e57
Merge branch 'master' into add_op_hardshrink
marigoold Apr 7, 2022
5d79c94
Merge branch 'master' into add_op_hardshrink
marigoold Apr 7, 2022
c3775ca
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 7, 2022
a6273dc
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 7, 2022
4d2f06b
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 7, 2022
50244cf
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 7, 2022
3714451
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 8, 2022
63b9b15
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 8, 2022
e5b9c47
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 8, 2022
fef5494
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 8, 2022
3180370
fix bug of docstr
marigoold Apr 11, 2022
a835f68
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 11, 2022
b646ee2
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 11, 2022
5a9cec4
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 12, 2022
8951c3b
Merge branch 'master' into add_op_hardshrink
marigoold Apr 13, 2022
58e588f
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 13, 2022
96cec4c
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 13, 2022
61a4d92
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
869cfac
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
d5c4e96
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
a917794
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
94e453b
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
1def04d
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
5b9e585
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
c811f38
Merge branch 'master' into add_op_hardshrink
mergify[bot] Apr 14, 2022
1 change: 1 addition & 0 deletions docs/source/functional.rst
@@ -11,6 +11,7 @@ Functional operations for neural networks
.. autofunction:: adaptive_avg_pool3d
.. autofunction:: relu
.. autofunction:: hardsigmoid
.. autofunction:: hardshrink
.. autofunction:: hardswish
.. autofunction:: hardtanh
.. autofunction:: normalize
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -39,6 +39,7 @@ Operators for neural networks
GLU,
GroupNorm,
Hardsigmoid,
Hardshrink,
Hardswish,
Hardtanh,
Identity,
43 changes: 43 additions & 0 deletions oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -124,6 +124,48 @@ class HardSigmoid : public BaseActivation {
}
};

struct HardShrinkCaptureState : public AutoGradCaptureState {
bool requires_grad = true;
double lambd = 0.5;
};

class HardShrink : public OpExprGradFunction<HardShrinkCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "Forward op must not be null";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> Capture(HardShrinkCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1) << "Input size must be equal to 1";
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->lambd = JUST(composed_attrs.GetAttr<double>("lambd"));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
return Maybe<void>::Ok();
}

Maybe<void> Apply(const HardShrinkCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1) << "Output grad size must be equal to 1";
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& y = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
*JUST(oneflow::VectorAt(in_grads, 0)) =
JUST(functional::HardShrinkGrad(y, JUST(oneflow::VectorAt(out_grads, 0)), ctx->lambd));
}
return Maybe<void>::Ok();
}

private:
AttrMap base_attrs_;
};

class HardSwish : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
@@ -508,6 +550,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("softsign", Softsign);
REGISTER_OP_EXPR_GRAD_FUNCTION("relu", ReLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("gelu", GeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardsigmoid", HardSigmoid);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardshrink", HardShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardswish", HardSwish);
REGISTER_OP_EXPR_GRAD_FUNCTION("leaky_relu", LeakyRelu);
REGISTER_OP_EXPR_GRAD_FUNCTION("hardtanh", HardTanh);
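
Note on the backward rule above: Capture saves the forward output and the lambd attribute, because knowing which elements were clamped to zero is enough to mask the incoming gradient. A minimal end-to-end check from Python, a sketch assuming the op is exposed as oneflow.nn.functional.hardshrink as the functional.rst entry above suggests:

import oneflow as flow

x = flow.tensor([-1.0, -0.5, -0.3, 0.4, 2.0], requires_grad=True)
y = flow.nn.functional.hardshrink(x, lambd=0.5)
y.sum().backward()
# HardShrinkGrad passes dy through where the saved output is nonzero and zeroes
# it elsewhere; -0.5 sits on the closed boundary, so its gradient is 0.
print(x.grad.numpy())  # expected: [1. 0. 0. 0. 1.]
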
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -542,6 +542,14 @@
signature: "Tensor (Tensor dy, Tensor x) => HardSigmoidGrad"
bind_python: False

- name: "hardshrink"
signature: "Tensor (Tensor x, *, Double lambd=0.5, Bool inplace=False) => HardShrink"
bind_python: True

- name: "hardshrink_grad"
signature: "Tensor (Tensor y, Tensor dy, Double lambd=0.5) => HardShrinkGrad"
bind_python: False

- name: "softmax"
signature: "Tensor (Tensor x, Int64 dim=None) => Softmax"
bind_python: True
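
The "hardshrink" entry above is bound to Python while "hardshrink_grad" stays internal, so the user-facing surface is the functional call plus the module documented in nn.rst. A hedged usage sketch, assuming the documented names oneflow.nn.functional.hardshrink and oneflow.nn.Hardshrink:

import oneflow as flow

x = flow.tensor([-1.0, -0.3, 0.0, 0.4, 2.0])
# Functional form: values in [-lambd, lambd] are zeroed, everything else passes through.
y = flow.nn.functional.hardshrink(x, lambd=0.5)
# Module form with the same default lambd=0.5.
m = flow.nn.Hardshrink(lambd=0.5)
z = m(x)
print(y.numpy())  # expected: [-1.  0.  0.  0.  2.]
print(z.numpy())  # same values
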
47 changes: 47 additions & 0 deletions oneflow/core/functional/impl/activation_functor.cpp
@@ -268,6 +268,51 @@ class HardSigmoidGradFunctor : public BinaryFunctor {
CHECK_JUST(one::OpBuilder("hardsigmoid_grad").Input("dy").Input("x").Output("dx").Build());
}
};

class HardShrinkFunctor {
public:
HardShrinkFunctor() {
op_ = CHECK_JUST(one::OpBuilder("hardshrink").Input("in").Output("out").Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, const double& lambd,
bool inplace) const {
MutableAttrMap attrs;
CHECK_GT_OR_RETURN(lambd, 0) << "lambd must be greater than 0";
JUST(attrs.SetAttr<double>("lambd", lambd));
if (inplace) {
JUST(CheckInplaceValid(x));
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
// outputs->at(0) = x;
*JUST(oneflow::VectorAt(outputs.get(), 0)) = x;
JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));
// return outputs->at(0);
return *JUST(oneflow::VectorAt(outputs.get(), 0));
} else {
return OpInterpUtil::Dispatch<one::Tensor>(*op_, {x}, attrs);
}
}

private:
std::shared_ptr<OpExpr> op_;
};

class HardShrinkGradFunctor {
public:
HardShrinkGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("hardshrink_grad").Input("dy").Input("y").Output("dx").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& y, const std::shared_ptr<Tensor>& dy,
const double& lambd) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("lambd", lambd));
return OpInterpUtil::Dispatch<one::Tensor>(*op_, {dy, y}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class SoftmaxFunctorBase {
public:
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
@@ -568,6 +613,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::GluFunctor>("Glu");
m.add_functor<impl::HardSigmoidFunctor>("HardSigmoid");
m.add_functor<impl::HardSigmoidGradFunctor>("HardSigmoidGrad");
m.add_functor<impl::HardShrinkFunctor>("HardShrink");
m.add_functor<impl::HardShrinkGradFunctor>("HardShrinkGrad");
m.add_functor<impl::SoftmaxFunctor>("Softmax");
m.add_functor<impl::SoftmaxGradFunctor>("SoftmaxGrad");
m.add_functor<impl::LogSoftmaxFunctor>("LogSoftmax");
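
HardShrinkFunctor validates lambd > 0 and then either dispatches out of place or, when inplace is requested, writes the result back into the input after CheckInplaceValid. A small sketch comparing the two paths from Python, assuming the inplace flag from the functional signature is exposed as a keyword argument:

import numpy as np
import oneflow as flow

x = flow.tensor([-1.0, -0.3, 0.0, 0.4, 2.0])
reference = flow.nn.functional.hardshrink(x, lambd=0.5)        # out-of-place path
y = flow.nn.functional.hardshrink(x, lambd=0.5, inplace=True)  # writes into x
assert np.allclose(y.numpy(), reference.numpy())
assert np.allclose(x.numpy(), reference.numpy())  # x itself now holds the result
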
39 changes: 36 additions & 3 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -211,8 +211,8 @@ class OneFlow_NormalizationAddReluBaseOp : OneFlow_BaseOp<"normalization_add_rel
#endif // GET_ONEFLOW_BASE_OP_DEFINITIONS

// Group: BINARY
// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
// Total: 33
// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardshrink_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, softshrink_grad, threshold_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy
// Total: 34

#ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS

Expand Down Expand Up @@ -438,6 +438,23 @@ def OneFlow_HardsigmoidGradOp : OneFlow_BaseOp<"hardsigmoid_grad", [NoSideEffect
let has_data_type_infer_fn = 1;
}

def OneFlow_HardShrinkGradOp : OneFlow_BaseOp<"hardshrink_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$y,
OneFlow_Tensor:$dy
);
let output = (outs
OneFlow_Tensor:$dx
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "0.">:$lambd
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_HardswishGradOp : OneFlow_BaseOp<"hardswish_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
@@ -8108,7 +8125,7 @@ def OneFlow_TanhGradOp : OneFlow_BaseOp<"tanh_grad", [NoSideEffect, DeclareOpInt
#endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS

// Group: UNARY
// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, flip_grad, fold, gelu, hardsigmoid, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, sigmoid, silu, softshrink, softsign, sort, square_sum, squeeze, threshold, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like, to_contiguous, isnan, isinf
// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, flip_grad, fold, gelu, hardsigmoid, hardshrink, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, sigmoid, silu, softshrink, softsign, sort, square_sum, squeeze, threshold, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like, to_contiguous, isnan, isinf
// Total: 51

#ifdef GET_ONEFLOW_UNARY_OP_DEFINITIONS
Expand Down Expand Up @@ -8451,6 +8468,22 @@ def OneFlow_HardsigmoidOp : OneFlow_BaseOp<"hardsigmoid", [NoSideEffect, Declare
let has_data_type_infer_fn = 1;
}

def OneFlow_HardShrinkOp : OneFlow_BaseOp<"hardshrink", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
);
let output = (outs
OneFlow_Tensor:$out
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "0.">:$lambd
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_HardswishOp : OneFlow_BaseOp<"hardswish", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
1 change: 1 addition & 0 deletions oneflow/user/kernels/activation_kernels.cpp
@@ -22,6 +22,7 @@ namespace oneflow {
REGISTER_CELU_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDSWISH_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDSIGMOID_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDSHRINK_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_HARDTANH_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_MISH_KERNEL(DeviceType::kCPU, dtype); \
REGISTER_SILU_KERNEL(DeviceType::kCPU, dtype); \
24 changes: 24 additions & 0 deletions oneflow/user/kernels/activation_kernels.cu
@@ -122,6 +122,29 @@ struct HardswishGradFunctor<half> {
}
};

template<>
struct HardShrinkFunctor<half> {
OF_DEVICE_FUNC explicit HardShrinkFunctor(float lambd)
: lambd(lambd), float_functor(HardShrinkFunctor<float>(lambd)) {}
OF_DEVICE_FUNC half operator()(half x) const {
return __float2half(float_functor(__half2float(x)));
}
const float lambd;
HardShrinkFunctor<float> float_functor;
};

template<>
struct HardShrinkGradFunctor<half> {
OF_DEVICE_FUNC explicit HardShrinkGradFunctor(float lambd)
: lambd(lambd), float_functor(HardShrinkGradFunctor<float>(lambd)) {}
OF_DEVICE_FUNC half operator()(half y, half dy) const {
return __float2half(float_functor(__half2float(y), __half2float(dy)));
}

const float lambd;
HardShrinkGradFunctor<float> float_functor;
};

template<>
struct MishFunctor<half> {
OF_DEVICE_FUNC explicit MishFunctor() : float_functor(MishFunctor<float>()) {}
@@ -261,6 +284,7 @@ struct SoftShrinkGradFunctor<half> {
REGISTER_CELU_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDSWISH_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDSIGMOID_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDSHRINK_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_HARDTANH_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_MISH_KERNEL(DeviceType::kCUDA, dtype); \
REGISTER_SILU_KERNEL(DeviceType::kCUDA, dtype); \
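
The half specializations route float16 elements through the float functor via __half2float/__float2half, so half results should track the float32 ones except for inputs that float16 rounding pushes across the ±lambd boundary; the sketch below therefore uses fixed points away from the threshold. It assumes a CUDA build and that float16 tensors reach the same Python entry point:

import numpy as np
import oneflow as flow

x = flow.tensor([-1.0, -0.3, 0.0, 0.4, 2.0], device="cuda")
y_fp32 = flow.nn.functional.hardshrink(x, lambd=0.5)
y_fp16 = flow.nn.functional.hardshrink(x.to(flow.float16), lambd=0.5)
# The half kernel converts to float, applies the float functor, and converts back.
assert np.allclose(y_fp16.to(flow.float32).cpu().numpy(),
                   y_fp32.cpu().numpy(), atol=1e-3)
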
54 changes: 54 additions & 0 deletions oneflow/user/kernels/activation_kernels.h
@@ -118,6 +118,26 @@ struct HardsigmoidGradFunctor {
}
};

template<typename T>
struct HardShrinkFunctor {
OF_DEVICE_FUNC explicit HardShrinkFunctor(double lambd) : lambd(lambd) {}
OF_DEVICE_FUNC T operator()(T x) const {
return (x <= lambd && x >= -lambd) ? static_cast<T>(0) : x;
}

const T lambd;
};

template<typename T>
struct HardShrinkGradFunctor {
OF_DEVICE_FUNC explicit HardShrinkGradFunctor(double lambd) : lambd(lambd) {}
OF_DEVICE_FUNC T operator()(T y, T dy) const {
return y == static_cast<T>(0) ? static_cast<T>(0) : dy;
}

const T lambd;
};

template<typename T>
struct HardtanhFunctor {
OF_DEVICE_FUNC explicit HardtanhFunctor(float min_val, float max_val)
@@ -370,6 +390,40 @@ struct SoftShrinkGradFunctor {
[](user_op::KernelComputeContext* ctx) { return HardsigmoidGradFunctor<dtype>(); }, "dx", \
"x", "dy");

#define REGISTER_HARDSHRINK_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("hardshrink") \
.SetCreateFn([]() { \
return user_op::NewOpKernel< \
UnaryElemwiseXpuKernel<device, HardShrinkFunctor<dtype>, dtype, dtype>>( \
[](user_op::KernelComputeContext* ctx) { \
return HardShrinkFunctor<dtype>(ctx->Attr<double>("lambd")); \
}, \
"out", "in"); \
}) \
.SetIsMatchedHob((user_op::HobDeviceType() == device) \
&& (user_op::HobDataType("in", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn([](const user_op::InferContext&, \
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); \
return Maybe<void>::Ok(); \
}); \
REGISTER_USER_KERNEL("hardshrink_grad") \
.SetCreateFn([]() { \
return user_op::NewOpKernel< \
BinaryElemwiseXpuKernel<device, HardShrinkGradFunctor<dtype>, dtype, dtype, dtype>>( \
[](user_op::KernelComputeContext* ctx) { \
return HardShrinkGradFunctor<dtype>(ctx->Attr<double>("lambd")); \
}, \
"dx", "y", "dy"); \
}) \
.SetIsMatchedHob((user_op::HobDeviceType() == device) \
&& (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn([](const user_op::InferContext&, \
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "dy", 0, true)); \
return Maybe<void>::Ok(); \
});

#define REGISTER_HARDTANH_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("hardtanh") \
.SetCreateFn([]() { \
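
For reference, the two elementwise functors registered above reduce to a few lines of array code: the forward zeroes x on the closed interval [-lambd, lambd] and is the identity elsewhere, and the gradient masks dy wherever the saved output is zero. A NumPy sketch of that contract (illustrative only; the function names are hypothetical):

import numpy as np

def hardshrink(x, lambd=0.5):
    # Mirrors HardShrinkFunctor: zero inside the closed interval, identity outside.
    return np.where((x <= lambd) & (x >= -lambd), 0.0, x)

def hardshrink_grad(y, dy):
    # Mirrors HardShrinkGradFunctor: dy passes through only where the output is nonzero.
    return np.where(y == 0, 0.0, dy)

x = np.array([-1.0, -0.5, -0.3, 0.0, 0.5, 2.0])
print(hardshrink(x))                                    # [-1.  0.  0.  0.  0.  2.]
print(hardshrink_grad(hardshrink(x), np.ones_like(x)))  # [ 1.  0.  0.  0.  0.  1.]
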