Skip to content

Commit

Permalink
Add addcdiv (#8581)
Browse files Browse the repository at this point in the history
* add addcdiv

* fix tensor_functions

* fix inplace

* add test number

* rename consistent to global
  • Loading branch information
zhongshsh authored Jul 23, 2022
1 parent cf27cde commit b542e15
Show file tree
Hide file tree
Showing 13 changed files with 234 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ Pointwise Ops
arccos
arccosh
add
addcdiv
addcmul
asin
asinh
Expand Down
2 changes: 2 additions & 0 deletions docs/source/tensor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ Tensor class reference
Tensor.acosh
Tensor.add
Tensor.add_
Tensor.addcdiv
Tensor.addcdiv_
Tensor.addcmul
Tensor.addcmul_
Tensor.addmm
Expand Down
4 changes: 4 additions & 0 deletions oneflow/api/python/framework/tensor_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ DIRECT_PASS_FUNC(PyTensorObject_amin, functional::amin)
DIRECT_PASS_FUNC(PyTensorObject_amax, functional::amax)
DIRECT_PASS_FUNC(PyTensorObject_addcmul, functional::addcmul)
DIRECT_PASS_FUNC(PyTensorObject_addcmul_, functional::addcmul_)
DIRECT_PASS_FUNC(PyTensorObject_addcdiv, functional::addcdiv)
DIRECT_PASS_FUNC(PyTensorObject_addcdiv_, functional::addcdiv_)
DIRECT_PASS_FUNC(PyTensorObject_clip, functional::clip)
DIRECT_PASS_FUNC(PyTensorObject_clip_, functional::clip_)
DIRECT_PASS_FUNC(PyTensorObject_clamp, functional::clamp)
Expand Down Expand Up @@ -812,6 +814,8 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"diagonal", (PyCFunction)PyTensorObject_diagonal, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcmul", (PyCFunction)PyTensorObject_addcmul, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcmul_", (PyCFunction)PyTensorObject_addcmul_, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcdiv", (PyCFunction)PyTensorObject_addcdiv, METH_VARARGS | METH_KEYWORDS, NULL},
{"addcdiv_", (PyCFunction)PyTensorObject_addcdiv_, METH_VARARGS | METH_KEYWORDS, NULL},
{"matmul", (PyCFunction)PyTensorObject_matmul, METH_VARARGS | METH_KEYWORDS, NULL},
{"int", PyTensorObject_int, METH_NOARGS, NULL},
{"long", PyTensorObject_long, METH_NOARGS, NULL},
Expand Down
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@
signature: "Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => InplaceAddcmul"
bind_python: true

- name: "addcdiv"
signature: "Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => AddCDiv"
bind_python: true

- name: "addcdiv_"
signature: "Tensor (Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1) => InplaceAddCDiv"
bind_python: true

- name: "div"
signature:
[
Expand Down
26 changes: 26 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2989,6 +2989,30 @@ class EinSumFunctor {
}
};

class AddCDivFunctor {
public:
AddCDivFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& tensor1,
const std::shared_ptr<one::Tensor>& tensor2, const Scalar& value) const {
return JUST(Add(input, JUST(ScalarMul(JUST(Div(tensor1, tensor2)), value, false)), 1, false));
}
};

class InplaceAddCDivFunctor {
public:
InplaceAddCDivFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& tensor1,
const std::shared_ptr<one::Tensor>& tensor2, const Scalar& value) const {
JUST(CheckInplaceValid(input));
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
JUST(VectorAt(*outputs, 0)) = input;
JUST(Add(input, JUST(ScalarMul(JUST(Div(tensor1, tensor2)), value, false)), 1, true));
return JUST(VectorAt(*outputs, 0));
}
};

} // namespace impl

using namespace impl;
Expand All @@ -2999,6 +3023,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<ScalarSubFunctor, ScalarSub2Functor>("ScalarSub");
m.add_functor<ScalarMulFunctor, ScalarMul2Functor>("ScalarMul");
m.add_functor<InplaceScalarMulFunctor>("InplaceScalarMul");
m.add_functor<AddCDivFunctor>("AddCDiv");
m.add_functor<InplaceAddCDivFunctor>("InplaceAddCDiv");
m.add_functor<ScalarDivFunctor, ScalarDiv2Functor>("ScalarDiv");
m.add_functor<InplaceScalarDivFunctor>("InplaceScalarDiv");
m.add_functor<ScalarPowFunctor>("ScalarPow");
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def is_deprecated(func_or_class):
from oneflow._C import diag
from oneflow._C import log1p
from oneflow._C import add
from oneflow._C import addcdiv
from oneflow._C import div, div_
from oneflow._C import addcmul
from oneflow._C import floor, floor_
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/framework/docstr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,6 @@
from .amin import *
from .deconv import *
from .logical_ops import *
from .addcdiv import *
from .hann_window import *
from .convolution import *
58 changes: 58 additions & 0 deletions python/oneflow/framework/docstr/addcdiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from oneflow.framework.docstr.utils import add_docstr

add_docstr(
oneflow.addcdiv,
r"""
addcdiv(input, tensor1, tensor2, *, value=1) -> Tensor
This function is equivalent to PyTorch’s addcdiv function.
The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.addcdiv.html.
Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`,
multiply the result by the scalar :attr:`value` and add it to :attr:`input`.
.. math::
\text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i}
The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be
`broadcastable`.
For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be
a real number, otherwise an integer.
Args:
input (Tensor): the tensor to be added
tensor1 (Tensor): the numerator tensor
tensor2 (Tensor): the denominator tensor
Keyword args:
value (Number, optional): multiplier for :math:`\text{{tensor1}} / \text{{tensor2}}`
Example::
>>> import oneflow as flow
>>> input = flow.tensor([ 0.3810, 1.2774, -0.2972, -0.3719])
>>> tensor1 = flow.tensor([0.8032, 0.2930, -0.8113, -0.2308])
>>> tensor2 = flow.tensor([[0.5], [1]])
>>> output = flow.addcdiv(input, tensor1, tensor2)
>>> output.shape
oneflow.Size([2, 4])
""",
)
14 changes: 14 additions & 0 deletions python/oneflow/framework/docstr/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,20 @@
""",
)

add_docstr(
oneflow.Tensor.addcdiv,
"""
See :func:`oneflow.addcdiv`
""",
)

add_docstr(
oneflow.Tensor.addcdiv_,
"""
In-place version of :func:`oneflow.Tensor.addcdiv`
""",
)

add_docstr(
oneflow.Tensor.dim,
"""
Expand Down
63 changes: 63 additions & 0 deletions python/oneflow/test/modules/test_addcdiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from oneflow.test_utils.automated_test_util import *
import oneflow as flow
import oneflow.unittest


@flow.unittest.skip_unless_1n1d()
class TestAddcdiv(flow.unittest.TestCase):
@autotest(n=5)
def test_addcdiv(test_case):
device = random_device()
ndim = random(2, 4).to(int).value()
shape = [random(2, 4) for i in range(ndim)]
input = random_tensor(ndim, *shape).to(device)
tensor1 = random_tensor(ndim, *shape).to(device)
tensor2 = random_tensor(ndim, *shape).to(device)
value = random(2, 4).to(int)
output = torch.addcdiv(input, tensor1, tensor2, value=value)
return output

@autotest(n=5)
def test_tensor_addcdiv(test_case):
device = random_device()
ndim = random(2, 4).to(int).value()
shape = [random(2, 4) for i in range(ndim)]
input = random_tensor(ndim, *shape).to(device)
tensor1 = random_tensor(ndim, *shape).to(device)
tensor2 = random_tensor(ndim, *shape).to(device)
value = random(2, 4).to(int)
output = input.addcdiv(tensor1, tensor2, value=value)
return output

@autotest(n=5)
def test_tensor_addcdiv_inplace(test_case):
device = random_device()
ndim = random(2, 4).to(int).value()
shape = [random(2, 4) for i in range(ndim)]
input = random_tensor(ndim, *shape).to(device)
input = input + 1.0
tensor1 = random_tensor(ndim, *shape).to(device)
tensor2 = random_tensor(ndim, *shape).to(device)
value = random(2, 4).to(int)
input.addcdiv_(tensor1, tensor2, value=value)
return input


if __name__ == "__main__":
unittest.main()
44 changes: 44 additions & 0 deletions python/oneflow/test/modules/test_global_addcdiv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@autotest(n=1, check_graph=False)
def _test_addcdiv(test_case, ndim, placement, sbp):
shape = [random(2, 4) * 8 for i in range(ndim)]
input = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)
tensor1 = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)
tensor2 = random_tensor(ndim, *shape).to_global(placement=placement, sbp=sbp)
value = random(2, 4).to(int)
output = torch.addcdiv(input, tensor1, tensor2, value=value)
return output


class TestModule(flow.unittest.TestCase):
@globaltest
def test_addcdiv(test_case):
ndim = random(2, 4).to(int).value()
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=ndim):
_test_addcdiv(test_case, ndim, placement, sbp)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from oneflow.test_utils.test_util import GenArgDict


def _test_consistent_full(test_case, shape, placement, sbp):
def _test_global_full(test_case, shape, placement, sbp):
x = flow.full(shape, 1.0, placement=placement, sbp=sbp)

test_case.assertEqual(x.shape, flow.Size(shape))
Expand All @@ -33,32 +33,32 @@ def _test_consistent_full(test_case, shape, placement, sbp):


def _test_graph_full(test_case, shape, placement, sbp):
class ConsistentFullGraph(flow.nn.Graph):
class GlobalFullGraph(flow.nn.Graph):
def __init__(self,):
super().__init__()

def build(self):
x = flow.full(shape, 1.0, placement=placement, sbp=sbp)
return x

model = ConsistentFullGraph()
model = GlobalFullGraph()
x = model()

test_case.assertEqual(x.shape, flow.Size(shape))
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)


class TestFullConsistent(flow.unittest.TestCase):
class TestFullGlobal(flow.unittest.TestCase):
@globaltest
def test_full_consistent(test_case):
def test_full_global(test_case):
shapes = [(8,), (8, 8,), (8, 8, 8)]
for shape in shapes:
for placement in all_placement():
for sbp in all_sbp(
placement, max_dim=len(shape), except_partial_sum=True
):
_test_consistent_full(test_case, shape, placement, sbp)
_test_global_full(test_case, shape, placement, sbp)

@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from oneflow.test_utils.test_util import GenArgDict


def _test_consistent_full_like(test_case, shape, placement, sbp):
def _test_global_full_like(test_case, shape, placement, sbp):
x_ = flow.randn(shape)
x = flow.full_like(x_, 1.0, placement=placement, sbp=sbp)

Expand All @@ -34,7 +34,7 @@ def _test_consistent_full_like(test_case, shape, placement, sbp):


def _test_graph_full_like(test_case, shape, placement, sbp):
class ConsistentFullLikeGraph(flow.nn.Graph):
class GlobalFullLikeGraph(flow.nn.Graph):
def __init__(self,):
super().__init__()

Expand All @@ -43,24 +43,24 @@ def build(self):
x = flow.full_like(x_, 1.0, placement=placement, sbp=sbp)
return x

model = ConsistentFullLikeGraph()
model = GlobalFullLikeGraph()
x = model()

test_case.assertEqual(x.shape, flow.Size(shape))
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)


class TestFillLikeConsistent(flow.unittest.TestCase):
class TestFillLikeGlobal(flow.unittest.TestCase):
@globaltest
def test_full_like_consistent(test_case):
def test_full_like_global(test_case):
shapes = [(8,), (8, 8,), (8, 8, 8)]
for shape in shapes:
for placement in all_placement():
for sbp in all_sbp(
placement, max_dim=len(shape), except_partial_sum=True
):
_test_consistent_full_like(test_case, shape, placement, sbp)
_test_global_full_like(test_case, shape, placement, sbp)

@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
Expand Down

0 comments on commit b542e15

Please sign in to comment.