Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for pytorch 1.6.0 #653

Merged
merged 22 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1477b63
update statics unittests argmax argmin
mtar Aug 7, 2020
0498773
increase torch version requirement
mtar Aug 7, 2020
81079de
sort indices in mpi_argmax, mpi_argmin
mtar Aug 12, 2020
c62f357
black formatting
mtar Aug 12, 2020
f2a6ffc
skip test part for argmax, argmin on gpu
mtar Aug 13, 2020
abdd11a
Merge branch 'master' into bug/argmin_argmax_unittests
coquelin77 Aug 13, 2020
4ff1db8
set argmin/argmax split to None if not distributed
coquelin77 Aug 13, 2020
a99108d
Add device to newly created tensors in get_halo
mtar Aug 13, 2020
ba5dc42
above threshold printing sends data, NOT buffer
coquelin77 Aug 13, 2020
3e8bcdd
removed old prints
coquelin77 Aug 13, 2020
02d75b1
Merge branch 'bug/argmin_argmax_unittests' of https://github.com/helm…
coquelin77 Aug 13, 2020
a4484e1
changelog update + black
coquelin77 Aug 13, 2020
76dcd4c
rerunning tests
coquelin77 Aug 13, 2020
270de05
removed spectral tests for >4 processes, NaNs + infs causing failures
coquelin77 Aug 13, 2020
61b593c
version check ensures < 0.5.0
coquelin77 Aug 13, 2020
2aa5822
removed future annotations from dndarray
coquelin77 Aug 13, 2020
611cc44
added '.to()' call in tiling to handle the transfer to GPU
coquelin77 Aug 14, 2020
e2a9a43
gpu fixes
coquelin77 Aug 14, 2020
2dfdbc6
removed tempfile and temp folder, h5 and nc tests use 'os.getcwd()' i…
coquelin77 Aug 14, 2020
d83e8a3
more missing devices in dndarray tests (halo)
coquelin77 Aug 14, 2020
9fdfe8f
more more device call missed
coquelin77 Aug 14, 2020
cfb588b
device corrections...
coquelin77 Aug 14, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
- [#635](https://github.com/helmholtz-analytics/heat/pull/635) `DNDarray.__getitem__` balances and resplits the given key to None if the key is a DNDarray
- [#638](https://github.com/helmholtz-analytics/heat/pull/638) Fix: arange returns float32 with single input of type float & update skipped device tests
- [#639](https://github.com/helmholtz-analytics/heat/pull/639) Bugfix: balanced array in demo_knn, changed behaviour of knn
- [#653](https://github.com/helmholtz-analytics/heat/pull/653) Update unittests argmax & argmin + force index order in mpi_argmax & mpi_argmin.

# v0.4.0

Expand Down
20 changes: 16 additions & 4 deletions heat/core/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,8 +1184,14 @@ def mpi_argmax(a, b, _):
rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))

# extract the values and minimal indices from the buffers (first half are values, second are indices)
values = torch.stack((lhs.chunk(2)[0], rhs.chunk(2)[0]), dim=1)
indices = torch.stack((lhs.chunk(2)[1], rhs.chunk(2)[1]), dim=1)
idx_l, idx_r = lhs.chunk(2)[1], rhs.chunk(2)[1]

if idx_l[0] < idx_r[0]:
values = torch.stack((lhs.chunk(2)[0], rhs.chunk(2)[0]), dim=1)
indices = torch.stack((idx_l, idx_r), dim=1)
else:
values = torch.stack((rhs.chunk(2)[0], lhs.chunk(2)[0]), dim=1)
indices = torch.stack((idx_r, idx_l), dim=1)

# determine the minimum value and select the indices accordingly
max, max_indices = torch.max(values, dim=1)
Expand All @@ -1201,8 +1207,14 @@ def mpi_argmin(a, b, _):
lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))
# extract the values and minimal indices from the buffers (first half are values, second are indices)
values = torch.stack((lhs.chunk(2)[0], rhs.chunk(2)[0]), dim=1)
indices = torch.stack((lhs.chunk(2)[1], rhs.chunk(2)[1]), dim=1)
idx_l, idx_r = lhs.chunk(2)[1], rhs.chunk(2)[1]

if idx_l[0] < idx_r[0]:
values = torch.stack((lhs.chunk(2)[0], rhs.chunk(2)[0]), dim=1)
indices = torch.stack((idx_l, idx_r), dim=1)
else:
values = torch.stack((rhs.chunk(2)[0], lhs.chunk(2)[0]), dim=1)
indices = torch.stack((idx_r, idx_l), dim=1)

# determine the minimum value and select the indices accordingly
min, min_indices = torch.min(values, dim=1)
Expand Down
18 changes: 10 additions & 8 deletions heat/core/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_argmax(self):
data = ht.tril(ht.ones((size, size), split=0), k=-1)

result = ht.argmax(data, axis=0)
expected = torch.tensor(np.argmax(data.numpy(), axis=0))
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(result.dtype, ht.int64)
self.assertEqual(result._DNDarray__array.dtype, torch.int64)
Expand All @@ -73,24 +74,24 @@ def test_argmax(self):
self.assertEqual(result.split, None)
# skip test on gpu; argmax works different
if not (torch.cuda.is_available() and result.device == ht.gpu):
self.assertTrue((result._DNDarray__array != 0).all())
self.assertTrue((result._DNDarray__array == expected).all())

# 2D split tensor, across the axis, output tensor
size = ht.MPI_WORLD.size * 2
data = ht.tril(ht.ones((size, size), split=0), k=-1)

output = ht.empty((size,))
result = ht.argmax(data, axis=0, out=output)

expected = torch.tensor(np.argmax(data.numpy(), axis=0))
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(output.dtype, ht.int64)
self.assertEqual(output._DNDarray__array.dtype, torch.int64)
self.assertEqual(output.shape, (size,))
self.assertEqual(output.lshape, (size,))
self.assertEqual(output.split, None)
# skip test on gpu; argmax works different
if not (torch.cuda.is_available() and output.device == ht.gpu):
self.assertTrue((output._DNDarray__array != 0).all())
if not (torch.cuda.is_available() and result.device == ht.gpu):
self.assertTrue((output._DNDarray__array == expected).all())

# check exceptions
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -155,6 +156,7 @@ def test_argmin(self):
data = ht.triu(ht.ones((size, size), split=0), k=1)

result = ht.argmin(data, axis=0)
expected = torch.tensor(np.argmin(data.numpy(), axis=0))
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(result.dtype, ht.int64)
self.assertEqual(result._DNDarray__array.dtype, torch.int64)
Expand All @@ -163,24 +165,24 @@ def test_argmin(self):
self.assertEqual(result.split, None)
# skip test on gpu; argmin works different
if not (torch.cuda.is_available() and result.device == ht.gpu):
self.assertTrue((result._DNDarray__array != 0).all())
self.assertTrue((result._DNDarray__array == expected).all())

# 2D split tensor, across the axis, output tensor
size = ht.MPI_WORLD.size * 2
data = ht.triu(ht.ones((size, size), split=0), k=1)

output = ht.empty((size,))
result = ht.argmin(data, axis=0, out=output)

expected = torch.tensor(np.argmin(data.numpy(), axis=0))
self.assertIsInstance(result, ht.DNDarray)
self.assertEqual(output.dtype, ht.int64)
self.assertEqual(output._DNDarray__array.dtype, torch.int64)
self.assertEqual(output.shape, (size,))
self.assertEqual(output.lshape, (size,))
self.assertEqual(output.split, None)
# skip test on gpu; argmin works different
if not (torch.cuda.is_available() and output.device == ht.gpu):
self.assertTrue((output._DNDarray__array != 0).all())
if not (torch.cuda.is_available() and result.device == ht.gpu):
self.assertTrue((output._DNDarray__array == expected).all())

# check exceptions
with self.assertRaises(TypeError):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
],
install_requires=["mpi4py>=3.0.0", "numpy>=1.13.0", "torch>=1.5.0", "scipy>=0.14.0"],
install_requires=["mpi4py>=3.0.0", "numpy>=1.13.0", "torch>=1.6.0", "scipy>=0.14.0"],
extras_require={
"hdf5": ["h5py>=2.8.0"],
"netcdf": ["netCDF4>=1.4.0,<=1.5.2"],
Expand Down