Skip to content

Commit

Permalink
Add binary functor exception (#8161)
Browse files Browse the repository at this point in the history
* fix eager free tensor bug when in job

* add binary functor exception

* add mul and mul_ exception

* add div exception

* refine

* format

* fix ci error

* fix ci error

* fix ci error

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
BBuf and mergify[bot] authored May 15, 2022
1 parent 4787170 commit 5e6451c
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 5 deletions.
2 changes: 2 additions & 0 deletions cmake/caches/cn/fast/cuda-75-clang.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
set(CMAKE_C_COMPILER "clang" CACHE STRING "")
set(WITH_MLIR YES CACHE BOOL "")
set(WITH_MLIR_CUDA_CODEGEN YES CACHE BOOL "")
set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "")
set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "")
set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "")
Expand Down
12 changes: 9 additions & 3 deletions oneflow/core/functional/impl/binary_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,17 @@ class AddFunctor {
return Error::RuntimeError()
<< "For integral input tensors, argument alpha must not be a floating point number.";
}

bool input_static_zeros = IsStaticZerosTensor(input);
if (input_static_zeros || IsStaticZerosTensor(other)) {
CHECK_OR_RETURN(JUST(input->device()) == JUST(other->device()));
CHECK_OR_RETURN(*input->shape() == *other->shape());
CHECK_OR_RETURN(input->dtype() == other->dtype());
CHECK_OR_RETURN(JUST(input->device()) == JUST(other->device()))
<< Error::RuntimeError()
<< "Expected all tensors to be on the same device, but found at least two devices, "
<< JUST(input->device())->ToString() << " and " << JUST(other->device())->ToString()
<< "!";
CHECK_OR_RETURN(*input->shape() == *other->shape())
<< Error::RuntimeError() << "The size of tensor a " << input->shape()->ToString()
<< " must match the size of tensor b " << other->shape();
if (input_static_zeros) {
if ((alpha.IsIntegral() && alpha.Value<int64_t>() == 1)
|| (alpha.IsFloatingPoint()
Expand Down
5 changes: 3 additions & 2 deletions oneflow/core/functional/impl/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Maybe<void> CheckInplaceValid(const std::shared_ptr<Tensor>& x) {
Maybe<void> CheckInplaceCastValid(const std::shared_ptr<Tensor>& x,
const std::shared_ptr<Tensor>& x_cast) {
CHECK_OR_RETURN(*x->dtype() == *x_cast->dtype())
<< "RuntimeError: result type " << x_cast->dtype()->name()
<< Error::RuntimeError() << "result type " << x_cast->dtype()->name()
<< " can't be cast to the desired output type " << x->dtype()->name();
return Maybe<void>::Ok();
}
Expand All @@ -90,7 +90,8 @@ bool IsShapeCanExpandTo(const Shape& shape, const Shape& expand_shape) {

Maybe<void> CheckShapeCanExpandTo(const Shape& shape, const Shape& expand_shape) {
CHECK_OR_RETURN(IsShapeCanExpandTo(shape, expand_shape))
<< "Can not expand shape " << shape.ToString() << " to " << expand_shape.ToString();
<< Error::RuntimeError() << "Can not expand shape " << shape.ToString() << " to "
<< expand_shape.ToString();
return Maybe<void>::Ok();
}

Expand Down
83 changes: 83 additions & 0 deletions python/oneflow/test/exceptions/test_binary_functor_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

import os
import numpy as np
import time
import oneflow as flow
import oneflow.unittest


class TestBinaryFunctorError(flow.unittest.TestCase):
def test_add_inplace_runtime_error(test_case):
with test_case.assertRaises(RuntimeError) as context:
x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)
y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)
x.add_(y)
test_case.assertTrue(
"a leaf Tensor that requires grad is being used in an in-place operation"
in str(context.exception)
)

def test_add_broad_cast_runtime_error(test_case):
with test_case.assertRaises(RuntimeError) as context:
x = flow.ones((2, 3))
y = flow.ones((2, 4))
x.add_(y)
test_case.assertTrue(
"Can not expand shape (2,4) to (2,3)" in str(context.exception)
)

with test_case.assertRaises(RuntimeError) as context:
x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)
y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)
x.mul_(y)
test_case.assertTrue(
"a leaf Tensor that requires grad is being used in an in-place operation"
in str(context.exception)
)

with test_case.assertRaises(RuntimeError) as context:
x = flow.ones((2, 3))
y = flow.ones((2, 4))
x.mul_(y)
test_case.assertTrue(
"Can not expand shape (2,4) to (2,3)" in str(context.exception)
)

def test_div_inplace_runtime_error(test_case):
with test_case.assertRaises(RuntimeError) as context:
x = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)
y = flow.ones((4, 4), dtype=flow.float32, requires_grad=True)
x.div_(y)
test_case.assertTrue(
"a leaf Tensor that requires grad is being used in an in-place operation"
in str(context.exception)
)

with test_case.assertRaises(RuntimeError) as context:
x = flow.ones((2, 3))
y = flow.ones((2, 4))
x.div_(y)
test_case.assertTrue(
"Can not expand shape (2,4) to (2,3)" in str(context.exception)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5e6451c

Please sign in to comment.