Fix norm (#8629)
* fix norm

* add doc

* add bool &

* update math_functor.cpp

* add note
zhongshsh authored Jul 15, 2022
1 parent 8f01ed9 commit 0f3ebdc
Showing 6 changed files with 45 additions and 7 deletions.
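Taken together, the change adds an internal for_norm flag: oneflow.norm calls with a plain numeric order and no dim now take a flattened sum(|x|**p)**(1/p) reduction (matching torch.norm, which the new autotest compares against) instead of being dispatched to the linalg-style matrix norm, which does not define such orders. A minimal illustration, assuming the torch-style API shown in the diffs below:

    import oneflow as flow

    x = flow.randn(3, 4)
    # With dim=None and an integer order |p| >= 2, oneflow.norm now computes
    # (|x| ** p).sum() ** (1 / p) over all elements rather than delegating the
    # 2-D input to the matrix-norm path.
    print(flow.norm(x, p=3))
    print(x.norm(3))  # same entry point via Tensor.norm, documented in this commit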
2 changes: 1 addition & 1 deletion oneflow/core/functional/functional_api.yaml
@@ -1718,7 +1718,7 @@
- name: "norm"
signature:
[
"Tensor (Tensor input, Scalar ord=None, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => Norm",
"Tensor (Tensor input, Scalar ord=None, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None, Bool for_norm=False) => Norm",
"Tensor (Tensor input, String ord, Int32List dim=None, Bool keepdim=False, *, DataType dtype=None) => Norm",
"Tensor (Tensor input, Scalar ord=None, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm",
"Tensor (Tensor input, String ord, Scalar dim, Bool keepdim=False, *, DataType dtype=None) => ScalarNorm",
19 changes: 16 additions & 3 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -1426,7 +1426,8 @@ class NormFunctor {
NormFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Optional<Scalar>& ord,
const Optional<std::vector<int32_t>>& input_dim, const bool& keepdim,
-                          const Optional<Symbol<DType>>& dtype) const {
+                          const Optional<Symbol<DType>>& dtype, const bool& for_norm) const {
+    // If for_norm is true, this functor is being called from oneflow.norm.
std::shared_ptr<one::Tensor> res;
if (dtype) {
Symbol<DType> dtype_val = JUST(dtype);
@@ -1444,8 +1445,9 @@
}
}
Scalar ord_sca;
+    bool ord_type = false;
if (ord.has_value()) {
-      auto ord_type = (*JUST(ord)).IsIntegral();
+      ord_type = (*JUST(ord)).IsIntegral();
if (ord_type) {
ord_sca = Scalar((*JUST(ord)).As<double>());
} else {
@@ -1475,6 +1477,17 @@
if (ord.has_value()) {
CHECK_OR_RETURN(x->ndim() <= 2)
<< "linalg.norm(): input must be 1-D or 2-D when dim is None and ord is not None";
+      if (ord_type) {
+        const double ord_double = (*JUST(ord)).As<double>();
+        if (for_norm && (ord_double >= 2 || ord_double <= -2)) {
+          const int32_t num_axes = x->shape()->NumAxes();
+          std::vector<int32_t> axes_vec(num_axes);
+          std::iota(axes_vec.begin(), axes_vec.end(), 0);
+          return ScalarPow(JUST(ReduceSum(JUST(ScalarPow(JUST(Abs(x)), ord_sca, false)), axes_vec,
+                                          /*keepdims=*/false)),
+                           1 / ord_double, false);
+        }
+      }
if (x->ndim() == 1) {
res = JUST(VectorNorm(x, ord_sca, input_dim, keepdim, dtype));
} else {
@@ -1545,7 +1558,7 @@ class ScalarNormFunctor {
}
if (input_dim.IsIntegral()) {
std::vector<int32_t> dim(1, input_dim.As<int>());
-      return functional::Norm(x, ord, dim, keepdim, dtype);
+      return functional::Norm(x, ord, dim, keepdim, dtype, /*for_norm=*/false);
} else {
UNIMPLEMENTED_THEN_RETURN() << "linalg_norm(): only supports int dim.";
}
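For readers who prefer not to parse the nested JUST(...) calls, a NumPy transliteration of the new for_norm branch above; this is a sketch of the math only (the function name is mine), not OneFlow's implementation:

    import numpy as np

    def flattened_p_norm(x: np.ndarray, p: float) -> float:
        # std::iota fills axes_vec with 0..ndim-1, i.e. reduce over every axis.
        axes = tuple(range(x.ndim))
        powered = np.abs(x) ** p           # ScalarPow(Abs(x), ord_sca, false)
        summed = powered.sum(axis=axes)    # ReduceSum(..., axes_vec, keepdims=false)
        return float(summed ** (1.0 / p))  # ScalarPow(..., 1 / ord_double, false)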
9 changes: 9 additions & 0 deletions python/oneflow/framework/docstr/tensor.py
@@ -1128,6 +1128,15 @@
""",
)

+add_docstr(
+    oneflow.Tensor.norm,
+    """
+    norm(p="fro", dim=None, keepdim=False, dtype=None) -> Tensor
+
+    See :func:`oneflow.norm`.
+    """,
+)

add_docstr(
oneflow.Tensor.numpy,
"""
4 changes: 3 additions & 1 deletion python/oneflow/framework/tensor.py
@@ -81,7 +81,9 @@ def _cuda(self, device: Union[int, str, flow.device] = None):


def _norm(self, p=None, dim=None, keepdim=False, dtype=None):
-    return flow._C.norm(self, p, dim, keepdim, dtype=dtype)
+    if isinstance(p, str) or dim is not None:
+        return flow._C.norm(self, p, dim, keepdim, dtype=dtype)
+    return flow._C.norm(self, p, dim, keepdim, dtype=dtype, for_norm=True)


def is_nonzero(input):
6 changes: 5 additions & 1 deletion python/oneflow/nn/modules/norm.py
@@ -105,4 +105,8 @@ def norm(input, p="fro", dim=None, keepdim=False, dtype=None):
>>> flow.norm(d[0, :, :]), flow.norm(d[1, :, :])
(tensor(3.7417, dtype=oneflow.float32), tensor(11.2250, dtype=oneflow.float32))
"""
-    return flow._C.norm(input=input, ord=p, dim=dim, keepdim=keepdim, dtype=dtype)
+    if isinstance(p, str) or dim is not None:
+        return flow._C.norm(input=input, ord=p, dim=dim, keepdim=keepdim, dtype=dtype)
+    return flow._C.norm(
+        input=input, ord=p, dim=dim, keepdim=keepdim, dtype=dtype, for_norm=True
+    )
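Both Python wrappers now share one dispatch rule; restated as a standalone predicate (a hypothetical helper, not part of the commit):

    def uses_for_norm_path(p, dim) -> bool:
        # String orders ("fro", "nuc") and explicit dims keep the original
        # linalg-style dispatch; a numeric (or default) order with dim=None
        # passes for_norm=True so NormFunctor can take the flattened branch.
        return not isinstance(p, str) and dim is None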
12 changes: 11 additions & 1 deletion python/oneflow/test/modules/test_norm.py
@@ -308,7 +308,6 @@ def test_no_dim_two_shape_norm_with_random_data(test_case):
def test_tuple_dim_norm_with_random_data(test_case):
device = random_device()
input = random_tensor(ndim=2).to(device)
-        k = random(low=-2, high=1).to(int)
dim = oneof((-2, -1), (0, 1), (-1, 0))
ord = oneof(float("inf"), float("-inf"), "fro", 1, -1, None)
keepdim = random().to(bool)
@@ -324,6 +323,17 @@ def test_vector_norm_only_zero_with_random_data(test_case):
m = torch.linalg.vector_norm(input, ord=0, dim=dim, keepdim=keepdim)
return m

+    @autotest(n=5)
+    def test_ord_random_data(test_case):
+        device = random_device()
+        ndim = random(1, 3).to(int)
+        input = random_tensor(ndim).to(device)
+        p1 = random(-5, -1).to(int).value()
+        p2 = random(2, 6).to(int).value()
+        m = input.norm(p1)
+        n = input.norm(p2)
+        return m, n


if __name__ == "__main__":
unittest.main()
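The new autotest draws integer orders from [-5, -1) and [2, 6) and compares input.norm(p) against PyTorch. The invariant it exercises can be spot-checked by hand; a NumPy-only sketch (tensor creation and tolerances assumed):

    import numpy as np

    x = np.random.randn(2, 5)
    for p in (-3, 4):
        # After this commit, flow.tensor(x).norm(p) should match this flattened
        # p-norm (up to float tolerance) for 1-D and 2-D inputs alike.
        print(p, (np.abs(x) ** p).sum() ** (1.0 / p))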
