Skip to content

Commit

Permalink
Merge pull request #851 from helmholtz-analytics/enhancement/798-logi…
Browse files Browse the repository at this point in the history
…cal-dndarrray

Enhancement/798 logical dndarrray
  • Loading branch information
ClaudiaComito authored Aug 18, 2021
2 parents 50a743a + a89c2ca commit 86413ff
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions heat/core/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,29 @@ def isclose(
output_gshape = stride_tricks.broadcast_shape(t1.gshape, t2.gshape)
res = torch.empty(output_gshape, device=t1.device.torch_device).bool()
t1.comm.Allgather(_local_isclose, res)
result = factories.array(res, dtype=types.bool, device=t1.device, split=t1.split)
result = DNDarray(
res,
gshape=output_gshape,
dtype=types.bool,
split=t1.split,
device=t1.device,
comm=t1.comm,
balanced=t1.is_balanced,
)
else:
if _local_isclose.dim() == 0:
# both x and y are scalars, return a single boolean value
result = bool(factories.array(_local_isclose).item())
result = bool(_local_isclose.item())
else:
result = factories.array(_local_isclose, dtype=types.bool, device=t1.device)
result = DNDarray(
_local_isclose,
gshape=tuple(_local_isclose.shape),
dtype=types.bool,
split=None,
device=t1.device,
comm=t1.comm,
balanced=t1.is_balanced,
)

return result

Expand Down

0 comments on commit 86413ff

Please sign in to comment.