Skip to content

Commit

Permalink
Support 0D for equal tensor with scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Feb 24, 2023
1 parent 31e465e commit 4b40763
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
25 changes: 23 additions & 2 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,18 @@ def test_shape(self):
self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([]))

def test_pow_factor(self):
def test_equal_scalar(self):
x = paddle.rand([])
out = paddle.equal(x, 2.0)
self.assertEqual(out.shape, [])
self.assertEqual(out, False)

x1 = paddle.full([], 2.0)
out1 = paddle.equal(x1, 2.0)
self.assertEqual(out1.shape, [])
self.assertEqual(out1, True)

def test_pow_scalar(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.pow(x, 2.0)
Expand Down Expand Up @@ -1837,7 +1848,17 @@ def test_flip(self):
self.assertEqual(res[3].shape, ())

@prog_scope()
def test_pow_factor(self):
def test_equal_scalar(self):
x = paddle.rand([])
out = paddle.equal(x, 2.0)

prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], False)

@prog_scope()
def test_pow_scalar(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.pow(x, 2.0)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def equal(x, y, name=None):
)
)
if not isinstance(y, Variable):
y = full(shape=[1], dtype=x.dtype, fill_value=y)
y = full(shape=[], dtype=x.dtype, fill_value=y)

if in_dygraph_mode():
return _C_ops.equal(x, y)
Expand Down

0 comments on commit 4b40763

Please sign in to comment.