Skip to content

Commit

Permalink
adding property methods to ivy Shape Class (#15091)
Browse files Browse the repository at this point in the history
Co-authored-by: ivy-branch <ivy.branch@lets-unify.ai>
  • Loading branch information
soma2000-lang and ivy-branch authored Jun 29, 2023
1 parent 741fc7f commit 4e8f544
Show file tree
Hide file tree
Showing 5 changed files with 977 additions and 3 deletions.
108 changes: 108 additions & 0 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,114 @@ def __dir__(self):
def shape(self):
return self._shape

@property
def value(self):
return self._value

def concatenate(self, other):
if self._shape is None or other.dims is None:
raise ValueError("Unknown Shape")
else:
return Shape(self.dims + other.dims)

def index(self, index):
assert isinstance(self._shape, Shape)
if self._shape.rank is None:
return Shape(None)
else:
return self._shape[index]

@property
def shape(self):
return self._shap

def as_dimension(self):
if isinstance(self._shape, Shape):
return self._shape
else:
return Shape(self._shape)

def is_compatible_with(self, other):
return self._shape is None or other.value is None or self._shape == other.value

@property
def rank(self):
"""Returns the rank of this shape, or None if it is unspecified."""
if self._shape is not None:
return len(self._shape)
return None

def assert_same_rank(self, other):
other = Shape(other)
if self.rank != other.rank:
raise ValueError("Shapes %s and %s must have the same rank" % (self, other))

def assert_has_rank(self, rank):
if self.rank not in (None, rank):
raise ValueError("Shape %s must have rank %d" % (self, rank))

def unknown_shape(rank=None, **kwargs):
if rank is None and "ndims" in kwargs:
rank = kwargs.pop("ndims")
if kwargs:
raise TypeError("Unknown argument: %s" % kwargs)
if rank is None:
return Shape(None)
else:
return Shape([Shape(None)] * rank)

def with_rank(self, rank):
try:
return self.merge_with(unknown_shape(rank=rank))
except ValueError:
raise ValueError("Shape %s must have rank %d" % (self, rank))

def with_rank_at_least(self, rank):
if self.rank is not None and self.rank < rank:
raise ValueError("Shape %s must have rank at least %d" % (self, rank))
else:
return self

def with_rank_at_most(self, rank):
if self.rank is not None and self.rank > rank:
raise ValueError("Shape %s must have rank at most %d" % (self, rank))
else:
return self

def as_shape(shape):
if isinstance(shape, Shape):
return shape
else:
return Shape(shape)

@property
def dims(self):
if self._shape is None:
return None
# return [as_dimension(d) for d in self._shape]

@property
def ndims(self):
"""Deprecated accessor for `rank`."""
return self.rank

@property
def is_fully_defined(self):
return self._shape is not None and all(
shape is not None for shape in self._shape
)

property

def num_elements(self):
if not self.is_fully_defined():
return None

@property
def assert_is_fully_defined(self):
if not self.is_fully_defined():
raise ValueError("Shape %s is not fully defined" % self)

def as_list(self):
if self._shape is None:
raise ivy.utils.exceptions.IvyException(
Expand Down
3 changes: 2 additions & 1 deletion ivy/functional/frontends/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
def poisson(x, name=None):
return ivy.poisson(x, shape=None, device=None, dtype=None, seed=None, out=None)


def randn(shape, dtype=None, name=None):
if dtype not in ["float32", "float64"]:
raise ivy.exceptions.IvyError(
Expand All @@ -47,3 +47,4 @@ def uniform_(x, min=-1.0, max=1.0, seed=0, name=None):
return ivy.random_uniform(
low=min, high=max, shape=x.shape, dtype=x.dtype, seed=seed
)

6 changes: 4 additions & 2 deletions ivy/functional/frontends/paddle/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,10 @@ def multiply(self, y, name=None):
def isfinite(self, name=None):
return ivy.isfinite(self._ivy_array)

def all(self, axis=None, keepdim=False, name=None):
return ivy.all(self.ivy_array, axis=axis, keepdims=keepdim)
@with_supported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle")
def all(self, axis=None, keepdim=False, dtype=None, name=None):
return ivy.all(self.ivy_array, axis=axis, keepdims=keepdim, dtype=dtype)


@with_supported_dtypes({"2.5.0 and below": ("float16", "bfloat16")}, "paddle")
def allclose(self, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def test_paddle_randn(
test_values=False,
shape=shape,
dtype=dtype,

)


Expand Down Expand Up @@ -162,3 +163,4 @@ def test_paddle_uniform_(
max=max,
seed=seed,
)

Loading

0 comments on commit 4e8f544

Please sign in to comment.