From decce5635a519117db3f9266e07cc7bfdf4b6bb7 Mon Sep 17 00:00:00 2001 From: Yusha Arif <101613943+YushaArif99@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:19:44 +0000 Subject: [PATCH] (feat)(torch frontends): added the frontend functions for `torch.Tensor.bernoulli_` and `torch.Tensor.numel` --- ivy/functional/frontends/torch/random_sampling.py | 4 ++-- ivy/functional/frontends/torch/tensor.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/torch/random_sampling.py b/ivy/functional/frontends/torch/random_sampling.py index 77f1c96f49230..88dacdb40b0ee 100644 --- a/ivy/functional/frontends/torch/random_sampling.py +++ b/ivy/functional/frontends/torch/random_sampling.py @@ -13,9 +13,9 @@ "torch", ) @to_ivy_arrays_and_back -def bernoulli(input, *, generator=None, out=None): +def bernoulli(input, p, *, generator=None, out=None): seed = generator.initial_seed() if generator is not None else None - return ivy.bernoulli(input, seed=seed, out=out) + return ivy.bernoulli(p, logits=input, seed=seed, out=out) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 22195e6f3008d..fbaa3dc37d9c4 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1201,8 +1201,19 @@ def dot(self, tensor): return torch_frontend.dot(self, tensor) @with_supported_dtypes({"2.1.2 and below": ("float32", "float64")}, "torch") - def bernoulli(self, *, generator=None, out=None): - return torch_frontend.bernoulli(self._ivy_array, generator=generator, out=out) + def bernoulli(self, p, *, generator=None, out=None): + return torch_frontend.bernoulli( + self._ivy_array, p, generator=generator, out=out + ) + + @with_supported_dtypes({"2.1.2 and below": ("float32", "float64")}, "torch") + def bernoulli_(self, p, *, generator=None, out=None): + self.ivy_array = self.bernoulli(p, generator=generator, out=out).ivy_array + return self + + def numel(self): + shape = self.shape + return int(ivy.astype(ivy.prod(shape), ivy.int64)) # Special Methods # # -------------------#