Skip to content

Commit

Permalink
Fix bug + Add Tests + Enhance docstrings (shape_equal) (#751)
Browse files Browse the repository at this point in the history
* Enhance shape_equal Functionality and Tests

* python -m Black

* add KerasTensor.shape

* fix IndexError

* remove print

* remove raise ValueError

* back to initial code

* fix docstrings

* E501 line too long

* axis: A list or tuple of integers

* fix examples

* fix bug axis = [axis]

* fix foramt

* add test Using axis=1

* fix .shape bug

* add Example shape_equal with axis=[1,2]

* fix add() Example

* fix absolute Examples output

* Fix bug+Add Tests+Enhance docstrings (shape_equal)

* Fix bug+Add Tests+Enhance docstrings (shape_equal)

* Fix bug+Add Tests+Enhance docstrings shape_equal
  • Loading branch information
Faisal-Alsrheed authored Aug 20, 2023
1 parent 7adbaf7 commit db4f8ae
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 8 deletions.
46 changes: 38 additions & 8 deletions keras_core/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,50 @@ def shape_equal(shape1, shape2, axis=None, allow_none=True):
"""Check if two shapes are equal.
Args:
shape1: A tuple or list of integers.
shape2: A tuple or list of integers.
axis: int or list/tuple of ints, defaults to `None`. If specified, the
shape check will ignore the axes specified by `axis`.
allow_none: bool, defaults to `True`. If `True`, `None` in the shape
will match any value.
shape1: A list or tuple of integers for first shape to be compared.
shape2: A list or tuple of integers for second shape to be compared.
axis: An integer, list, or tuple of integers (optional):
Axes to ignore during comparison. Default is `None`.
allow_none (bool, optional): If `True`, allows `None` in a shape
to match any value in the corresponding position of the other shape.
Default is `True`.
Returns:
bool: `True` if shapes are considered equal based on the criteria,
`False` otherwise.
Examples:
>>> shape_equal((32, 64, 128), (32, 64, 128))
True
>>> shape_equal((32, 64, 128), (32, 64, 127))
False
>>> shape_equal((32, 64, None), (32, 64, 128), allow_none=True)
True
>>> shape_equal((32, 64, None), (32, 64, 128), allow_none=False)
False
>>> shape_equal((32, 64, 128), (32, 63, 128), axis=1)
True
>>> shape_equal((32, 64, 128), (32, 63, 127), axis=(1, 2))
True
>>> shape_equal((32, 64, 128), (32, 63, 127), axis=[1,2])
True
>>> shape_equal((32, 64), (32, 64, 128))
False
"""
if len(shape1) != len(shape2):
return False

shape1 = list(shape1)
shape2 = list(shape2)

if axis is not None:
if isinstance(axis, int):
axis = [axis]
for ax in axis:
shape1[ax] = -1
shape2[ax] = -1

if allow_none:
for i in range(len(shape1)):
if shape1[i] is None:
Expand Down Expand Up @@ -246,9 +275,10 @@ def absolute(x):
An array containing the absolute value of each element in `x`.
Example:
>>> x = keras_core.ops.convert_to_tensor([-1.2, 1.2])
>>> keras_core.ops.absolute(x)
array([1.2 1.2], shape=(2,), dtype=float32)
array([1.2, 1.2], dtype=float32)
"""
if any_symbolic_tensors((x,)):
return Absolute().symbolic_call(x)
Expand Down Expand Up @@ -291,7 +321,7 @@ def add(x1, x2):
>>> x1 = keras_core.ops.convert_to_tensor([1, 4])
>>> x2 = keras_core.ops.convert_to_tensor([5, 6])
>>> keras_core.ops.add(x1, x2)
array([ 6 10], shape=(2,), dtype=int32)
array([6, 10], dtype=int32)
`keras_core.ops.add` also broadcasts shapes:
>>> x1 = keras_core.ops.convert_to_tensor(
Expand Down
26 changes: 26 additions & 0 deletions keras_core/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,32 @@ def test_xor(self):
y = KerasTensor((2, None))
self.assertEqual(knp.logical_xor(x, y).shape, (2, 3))

def test_shape_equal_basic_equality(self):
x = KerasTensor([3, 4]).shape
y = KerasTensor([3, 4]).shape
self.assertTrue(knp.shape_equal(x, y))
y = KerasTensor([3, 5]).shape
self.assertFalse(knp.shape_equal(x, y))

def test_shape_equal_allow_none(self):
x = KerasTensor([3, 4, None]).shape
y = KerasTensor([3, 4, 5]).shape
self.assertTrue(knp.shape_equal(x, y, allow_none=True))
self.assertFalse(knp.shape_equal(x, y, allow_none=False))

def test_shape_equal_different_shape_lengths(self):
x = KerasTensor([3, 4]).shape
y = KerasTensor([3, 4, 5]).shape
self.assertFalse(knp.shape_equal(x, y))

def test_shape_equal_ignore_axes(self):
x = KerasTensor([3, 4, 5]).shape
y = KerasTensor([3, 6, 5]).shape
self.assertTrue(knp.shape_equal(x, y, axis=1))
y = KerasTensor([3, 6, 7]).shape
self.assertTrue(knp.shape_equal(x, y, axis=(1, 2)))
self.assertFalse(knp.shape_equal(x, y, axis=1))


class NumpyTwoInputOpsStaticShapeTest(testing.TestCase):
def test_add(self):
Expand Down

0 comments on commit db4f8ae

Please sign in to comment.