Fix linalg vector norm backward bug #8015

Merged
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -348,6 +348,14 @@
  signature: "Tensor (Tensor input) => TransposeAllDimFunction"
  bind_python: True

- name: "not_equal_zero"
  signature: "Tensor (Tensor x) => NotEqualZero"
  bind_python: False

- name: "not_equal_zero_grad"
  signature: "Tensor (Tensor x, Tensor dy) => NotEqualZeroGrad"
  bind_python: False

- name: "reciprocal"
  signature: "Tensor (Tensor x) => Reciprocal"
  bind_python: True
2 changes: 1 addition & 1 deletion oneflow/core/functional/impl/math_functor.cpp
@@ -1099,7 +1099,7 @@ class VectorNormFunctor {
    if (ord.IsIntegral() || ord.IsFloatingPoint()) {
      double ord_val = JUST(ord.As<double>());
      if (ord_val == 0) {
-        res = JUST(ReduceSum(JUST(ScalarLogicalNotEqual(x, 0)), dim, keepdim));
+        res = JUST(ReduceSum(JUST(functional::NotEqualZero(x)), dim, keepdim));
      } else if (ord_val == INFINITY) {
        res = JUST(ReduceMax(JUST(Abs(x)), dim, keepdim));
      } else if (ord_val == -INFINITY) {
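For context, a minimal sketch of the user-facing path this hunk fixes, assuming oneflow mirrors the `torch.linalg.vector_norm` API exercised by the test at the bottom of this PR. Before the change, the ord=0 branch went through the logical `ScalarLogicalNotEqual` op, which has no backward function, so autograd through `vector_norm(..., ord=0)` failed:

```python
import oneflow as flow

x = flow.tensor([[1.0, 0.0, -2.0], [0.0, 3.0, 0.0]], requires_grad=True)

# ord=0 counts the nonzero entries along `dim`; it is now computed with the
# differentiable not_equal_zero op instead of a logical op.
n = flow.linalg.vector_norm(x, ord=0, dim=1)  # expected values: [2., 1.]

# With this PR the backward pass no longer raises; the gradient follows the
# not_equal_zero_grad rule dx = dy * (x != 0) defined further down.
n.sum().backward()
print(x.grad)
```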
4 changes: 3 additions & 1 deletion oneflow/core/functional/impl/unary_functor.cpp
@@ -67,7 +67,8 @@ namespace impl {
OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \
OF_PP_MAKE_TUPLE_SEQ("square", Square) \
OF_PP_MAKE_TUPLE_SEQ("tan", Tan) \
OF_PP_MAKE_TUPLE_SEQ("tanh", Tanh)
OF_PP_MAKE_TUPLE_SEQ("tanh", Tanh) \
OF_PP_MAKE_TUPLE_SEQ("not_equal_zero", NotEqualZero)

#define LOGICAL_FLOAT_UNARY_FUNC_SEQ OF_PP_MAKE_TUPLE_SEQ("logical_not", LogicalNot)

@@ -151,6 +152,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
  ADD_UNARY_FUNCTOR(Square, "Square");
  ADD_UNARY_FUNCTOR(Tan, "Tan");
  ADD_UNARY_FUNCTOR(Tanh, "Tanh");
+  ADD_UNARY_FUNCTOR(NotEqualZero, "NotEqualZero");
  m.add_functor<LogicalNotFunctor>("LogicalNot");
  m.add_functor<InplaceSinFunctor>("Sin_");
  m.add_functor<InplaceFloorFunctor>("Floor_");
31 changes: 29 additions & 2 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7632,8 +7632,8 @@ def OneFlow_TestUserOpAttrAutoTypeOp : OneFlow_BaseOp<"test_user_op_attr_auto_ty
#endif // GET_ONEFLOW_TEST_OP_DEFINITIONS

// Group: TRIGONOMETRIC
-// acos, acos_grad, acosh, acosh_grad, asin, asin_grad, asinh, asinh_grad, atan, atan2, atan2_x_grad, atan2_y_grad, atan_grad, atanh, atanh_grad, cos, cos_grad, cosh, cosh_grad, hardtanh, hardtanh_grad, sin, sin_grad, sinh, sinh_grad, tan, tan_grad, tanh, tanh_grad
-// Total: 29
+// acos, acos_grad, acosh, acosh_grad, asin, asin_grad, asinh, asinh_grad, atan, atan2, atan2_x_grad, atan2_y_grad, atan_grad, atanh, atanh_grad, cos, cos_grad, cosh, cosh_grad, hardtanh, hardtanh_grad, sin, sin_grad, sinh, sinh_grad, tan, tan_grad, tanh, tanh_grad, not_equal_zero, not_equal_zero_grad
+// Total: 31

#ifdef GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS

@@ -8040,6 +8040,33 @@ def OneFlow_TanhGradOp : OneFlow_BaseOp<"tanh_grad", [NoSideEffect, DeclareOpInt
  let has_data_type_infer_fn = 1;
}

def OneFlow_NotEqualZeroOp : OneFlow_BaseOp<"not_equal_zero", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$x
  );
  let output = (outs
    OneFlow_Tensor:$y
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

def OneFlow_NotEqualZeroGradOp : OneFlow_BaseOp<"not_equal_zero_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$x,
    OneFlow_Tensor:$dy
  );
  let output = (outs
    OneFlow_Tensor:$dx
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

#endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS

// Group: UNARY
23 changes: 23 additions & 0 deletions oneflow/user/kernels/math_unary_elementwise_func.h
@@ -153,6 +153,13 @@ struct AtanhFunctor<float> {
  }
};

template<>
struct NotEqualZeroFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; }
  // Review comment (Contributor): this should use static_cast<float>(0.0);
  // as written it compares a float against an int.

  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (x != 0); }
};

template<>
struct CeilFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); }
@@ -422,6 +429,13 @@ struct AtanhFunctor<double> {
  }
};

template<>
struct NotEqualZeroFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (x != 0); }
};

template<>
struct CeilFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); }
@@ -717,6 +731,15 @@ struct CeilFunctor<half> {
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};

template<>
struct NotEqualZeroFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return x != static_cast<half>(0.0); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, x != static_cast<half>(0.0));
  }
};

template<>
struct CosFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hcos(x); }
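To summarize what the three `NotEqualZeroFunctor` specializations above compute (this restates the kernel code, not a claim about the mathematical subgradient of the count):

$$\mathrm{Forward}(x) = \mathbb{1}[x \neq 0], \qquad \mathrm{Backward}(x, dy) = dy \cdot \mathbb{1}[x \neq 0]$$

Under the ord=0 branch of `VectorNormFunctor`, the forward pass therefore reduces to counting nonzero entries, $\lVert x \rVert_0 = \sum_i \mathbb{1}[x_i \neq 0]$.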
6 changes: 4 additions & 2 deletions oneflow/user/ops/math_unary_elementwise_seq.h
@@ -53,7 +53,8 @@ namespace oneflow {
OF_PP_MAKE_TUPLE_SEQ("sinh", Sinh) \
OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \
OF_PP_MAKE_TUPLE_SEQ("square", Square) \
OF_PP_MAKE_TUPLE_SEQ("tan", Tan)
OF_PP_MAKE_TUPLE_SEQ("tan", Tan) \
OF_PP_MAKE_TUPLE_SEQ("not_equal_zero", NotEqualZero)

#define MATH_UNARY_ELEMENTWISE_FUNC_SEQ_ODS \
  OF_PP_MAKE_TUPLE_SEQ("abs", Abs)          \

@@ -88,7 +89,8 @@ namespace oneflow {
OF_PP_MAKE_TUPLE_SEQ("sinh", Sinh) \
OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \
OF_PP_MAKE_TUPLE_SEQ("square", Square) \
OF_PP_MAKE_TUPLE_SEQ("tan", Tan)
OF_PP_MAKE_TUPLE_SEQ("tan", Tan) \
OF_PP_MAKE_TUPLE_SEQ("not_equal_zero", NotEqualZero)

} // namespace oneflow

8 changes: 6 additions & 2 deletions python/oneflow/nn/utils/clip_grad.py
@@ -120,7 +120,9 @@ def clip_grad_norm_(
                ),
                norm_type,
            )
-        if error_if_nonfinite and flow.logical_or(total_norm.isnan(), total_norm.isinf()):
+        if error_if_nonfinite and flow.logical_or(
+            total_norm.isnan(), total_norm.isinf()
+        ):
            raise RuntimeError(
                f"The total norm of order {norm_type} for gradients from "
                "`parameters` is non-finite, so it cannot be clipped. To disable "
@@ -149,7 +151,9 @@
                ),
                norm_type,
            )
-        if error_if_nonfinite and flow.logical_or(total_norm.isnan(), total_norm.isinf()):
+        if error_if_nonfinite and flow.logical_or(
+            total_norm.isnan(), total_norm.isinf()
+        ):
            raise RuntimeError(
                f"The total norm of order {norm_type} for gradients from "
                "`parameters` is non-finite, so it cannot be clipped. To disable "
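As a usage note for the reformatted guard, a small sketch assuming the standard `clip_grad_norm_` signature; the model and sizes here are hypothetical:

```python
import oneflow as flow
import oneflow.nn as nn

model = nn.Linear(4, 2)
loss = model(flow.randn(3, 4)).sum()
loss.backward()

# If the total gradient norm is NaN/Inf, the logical_or guard above raises
# RuntimeError; pass error_if_nonfinite=False to clip anyway.
total_norm = flow.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, norm_type=2.0, error_if_nonfinite=True
)
print(total_norm)
```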
5 changes: 3 additions & 2 deletions python/oneflow/test/modules/test_norm.py
@@ -304,14 +304,15 @@ def test_tuple_dim_norm_with_random_data(test_case):
        m = torch.linalg.norm(input, ord=ord, dim=dim, keepdim=keepdim)
        return m

-    @autotest(n=5, auto_backward=False, check_graph=True)
-    def test_ord_zero_with_random_data(test_case):
+    @autotest(n=5, auto_backward=False)
+    def test_vector_norm_only_zero_with_random_data(test_case):
        device = random_device()
        input = random_tensor(ndim=2).to(device)
        dim = oneof((-2, -1), (0, 1), (-1, 0))
        keepdim = random().to(bool)
        m = torch.linalg.vector_norm(input, ord=0, dim=dim, keepdim=keepdim)
        return m


if __name__ == "__main__":
    unittest.main()