Skip to content

Commit

Permalink
Fix eq none bug (#7938)
Browse files Browse the repository at this point in the history
* fix reduce_sum scalar check bug

* fix equal none bug

* refine

* auto format by CI

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 3, 2022
1 parent b8547c6 commit ded23a5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,12 @@ def _meta_repr(self):


def _eq(self, other):
return flow._C.equal(self, other)
if self is None and other is None:
return True
elif self is None or other is None:
return False
else:
return flow._C.equal(self, other)


def _ne(self, other):
Expand Down
9 changes: 9 additions & 0 deletions python/oneflow/test/tensor/test_tensor_part_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,15 @@ def test_tensor_int(test_case):
y = int(x)
test_case.assertTrue(np.array_equal(y, 2))

def test_none_equal(test_case):
xt = flow.randn(10)
yt = flow.randn(10)
z = None in [xt, yt]
test_case.assertTrue(np.array_equal(z, False))
zt = None
z = None in [xt, yt, zt]
test_case.assertTrue(np.array_equal(z, True))


if __name__ == "__main__":
unittest.main()

0 comments on commit ded23a5

Please sign in to comment.