
Commit 6ff1d6f

Merge branch 'main' into Communicator-not-properly-initialized-when-creating-new-DNDarrays-in-some-routines/1074-my-bug-fix

AsRaNi1 authored Jan 30, 2023
2 parents 3acb3e3 + 8597417
Showing 25 changed files with 401 additions and 122 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/changelog-updater.yml
@@ -22,6 +22,6 @@ jobs:
- name: Commit updated CHANGELOG
uses: stefanzweifel/git-auto-commit-action@v4
with:
-          branch: main
+          branch: release/1.2.x
commit_message: Update CHANGELOG
file_pattern: CHANGELOG.md
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -7,7 +7,7 @@ on:
jobs:
approved:
if: github.event.review.state == 'approved'
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
12 changes: 7 additions & 5 deletions .github/workflows/latest-pytorch-support.yml
@@ -10,8 +10,6 @@ env:
new_major: $(<.github/pytorch-release-versions/pytorch-latest.txt | cut -d'.' -f1)
new_minor: $(<.github/pytorch-release-versions/pytorch-latest.txt | cut -d'.' -f2)
new_patch: $(<.github/pytorch-release-versions/pytorch-latest.txt | cut -d'.' -f3)
-  new_setup_patch: $(($env.new_patch+1))
-  new_setup_pytorch: ("$env.new_major"."$env.new_minor"."$env.new_setup_patch")
permissions:
contents: write
issues: write
@@ -39,12 +37,16 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }}
ref: 'release/1.2.x'
+      - name: Increment patch
+        run: |
+          echo "new_setup_patch=$((${{ env.new_patch }}+1))" >> $GITHUB_ENV
+      - name: Define version string
+        run: |
+          echo "new_setup_pytorch=$("${{ env.new_major }}"."${{ env.new_minor }}"."${{ env.new_setup_patch }}")" >> $GITHUB_ENV
- name: Update setup.py
run: |
echo ${{ env.previous_pytorch }}
echo ${{ env.new_pytorch }}
echo ${{ env.new_setup_pytorch }}
sed -i '/torch>=/ s/'"${{ env.previous_pytorch }}"'/'"${{ env.new_setup_pytorch }}"'/g' setup.py
sed -i 's/'"${{ env.previous_pytorch }}"'/'"${{ env.new_pytorch }}"'/g' .github/pytorch-release-versions/pytorch-latest.txt
- name: Define env variable
run: |
echo "new=$(<.github/pytorch-release-versions/pytorch-latest.txt)" >> $GITHUB_ENV
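The two steps added above compute the bumped patch version at job run time and export it via `$GITHUB_ENV`; the deleted `env:` entries relied on `$(...)` command substitution, which GitHub Actions does not evaluate in workflow-level `env` values. A minimal Python sketch of the same version arithmetic (the file content shown is hypothetical):

```python
# Hypothetical content of .github/pytorch-release-versions/pytorch-latest.txt
new_pytorch = "1.13.1"

major, minor, patch = new_pytorch.split(".")
new_setup_patch = int(patch) + 1                          # "Increment patch" step
new_setup_pytorch = f"{major}.{minor}.{new_setup_patch}"  # "Define version string" step
print(new_setup_pytorch)  # 1.13.2 -> the upper bound substituted into setup.py's torch pin
```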
4 changes: 2 additions & 2 deletions .gitlab-ci.yml
@@ -1,5 +1,5 @@
test:
-  image: nvidia/cuda:11.6.2-runtime-ubuntu20.04
+  image: nvidia/cuda:11.7.1-runtime-ubuntu22.04
tags:
- cuda
- x86_64
@@ -9,7 +9,7 @@ test:
- DEBIAN_FRONTEND=noninteractive apt -y install libopenmpi-dev openmpi-bin openmpi-doc
- apt -y install libhdf5-openmpi-dev libpnetcdf-dev
- pip install pytest coverage
-    - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
+    - pip3 install torch torchvision torchaudio
- pip install .[hdf5,netcdf]
- COVERAGE_FILE=report/cov/coverage1 HEAT_TEST_USE_DEVICE=cpu mpirun --allow-run-as-root -n 1 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report1.xml heat/
- COVERAGE_FILE=report/cov/coverage2 HEAT_TEST_USE_DEVICE=gpu mpirun --allow-run-as-root -n 3 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report3.xml heat/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/pycqa/pydocstyle
-    rev: 6.2.3 # pick a git hash / tag to point to
+    rev: 6.3.0 # pick a git hash / tag to point to
hooks:
- id: pydocstyle
exclude: 'tests|benchmarks|examples|scripts|setup.py' #|heat/utils/data/mnist.py|heat/utils/data/_utils.py ?
19 changes: 19 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,22 @@
+# v1.2.1
+
+## Changes
+
+- #1048 Support PyTorch 1.13.0 on branch release/1.2.x (by @github-actions)
+
+## 🐛 Bug Fixes
+
+- #1038 Lanczos decomposition `linalg.solver.lanczos`: Support double precision, complex data types (by @ClaudiaComito)
+- #1034 `ht.array`, closed loophole allowing `DNDarray` construction with incompatible shapes of local arrays (by @Mystic-Slice)
+
+## Linear Algebra
+
+- #1038 Lanczos decomposition `linalg.solver.lanczos`: Support double precision, complex data types (by @ClaudiaComito)
+
+## 🧪 Testing
+
+- #1025 mirror repository on gitlab + ci (by @mtar)
+- #1014 fix: set cuda rng state on gpu tests for test_random.py (by @JuanPedroGHM)

# v1.2.0

42 changes: 28 additions & 14 deletions heat/core/communication.py
@@ -249,7 +249,11 @@ def counts_displs_shape(

@classmethod
def mpi_type_and_elements_of(
-        cls, obj: Union[DNDarray, torch.Tensor], counts: Tuple[int], displs: Tuple[int]
+        cls,
+        obj: Union[DNDarray, torch.Tensor],
+        counts: Tuple[int],
+        displs: Tuple[int],
+        is_contiguous: Optional[bool],
) -> Tuple[MPI.Datatype, Tuple[int, ...]]:
"""
Determines the MPI data type and number of respective elements for the given tensor (:class:`~heat.core.dndarray.DNDarray`
@@ -264,12 +268,18 @@ def mpi_type_and_elements_of(
Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
displs : Tuple[ints,...], optional
Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
+        is_contiguous: bool
+            Information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``.
+        # ToDo: The option to explicitly specify the counts and displacements to be sent still needs proper implementation
"""
mpi_type, elements = cls.__mpi_type_mappings[obj.dtype], torch.numel(obj)

-        # simple case, continuous memory can be transmitted as is
-        if obj.is_contiguous():
+        # simple case, contiguous memory can be transmitted as is
+        if is_contiguous is None:
+            # determine local contiguity
+            is_contiguous = obj.is_contiguous()
+
+        if is_contiguous:
if counts is None:
return mpi_type, elements
else:
@@ -282,7 +292,7 @@
),
)

-        # non-continuous memory, e.g. after a transpose, has to be packed in derived MPI types
+        # non-contiguous memory, e.g. after a transpose, has to be packed in derived MPI types
elements = obj.shape[0]
shape = obj.shape[1:]
strides = [1] * len(shape)
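For context on the new `is_contiguous` parameter: a single rank cannot always infer the contiguity of the distributed object from its local chunk, so callers may now pass the flag explicitly; only when it is `None` does the code fall back to the local `torch.Tensor.is_contiguous()` check. A minimal PyTorch-only sketch of that local check (independent of Heat):

```python
import torch

x = torch.arange(6).reshape(2, 3)
y = x.T  # transpose shares storage but permutes the strides

print(x.is_contiguous())  # True  -> transmitted as-is
print(y.is_contiguous())  # False -> packed into a derived MPI vector type
```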
@@ -314,7 +324,11 @@ def as_mpi_memory(cls, obj) -> MPI.memory:

@classmethod
def as_buffer(
-        cls, obj: torch.Tensor, counts: Tuple[int] = None, displs: Tuple[int] = None
+        cls,
+        obj: torch.Tensor,
+        counts: Tuple[int] = None,
+        displs: Tuple[int] = None,
+        is_contiguous: Optional[bool] = None,
) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]:
"""
Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type.
@@ -327,14 +341,16 @@
Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
displs : Tuple[int,...], optional
Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
+        is_contiguous: bool, optional
+            Optional information on global contiguity of the memory-distributed object.
"""
squ = False
if not obj.is_contiguous() and obj.ndim == 1:
# this makes the math work below this function.
obj.unsqueeze_(-1)
squ = True
-        mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs)

+        mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous)
mpi_mem = cls.as_mpi_memory(obj)
if squ:
# the squeeze happens in the mpi_type_and_elements_of function in the case of a
@@ -1046,7 +1062,6 @@ def __allgather_like(
type(sendbuf)
)
)

# unpack the receive buffer
if isinstance(recvbuf, tuple):
recvbuf, recv_counts, recv_displs = recvbuf
@@ -1062,17 +1077,18 @@

# keep a reference to the original buffer object
original_recvbuf = recvbuf

+        sbuf_is_contiguous, rbuf_is_contiguous = None, None
# permute the send_axis order so that the split send_axis is the first to be transmitted
if axis != 0:
send_axis_permutation = list(range(sendbuf.ndimension()))
send_axis_permutation[0], send_axis_permutation[axis] = axis, 0
sendbuf = sendbuf.permute(*send_axis_permutation)
+            sbuf_is_contiguous = False

if axis != 0:
recv_axis_permutation = list(range(recvbuf.ndimension()))
recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0
recvbuf = recvbuf.permute(*recv_axis_permutation)
+            rbuf_is_contiguous = False
else:
recv_axis_permutation = None
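The permutation above is why the flags are pre-set to `False`: moving the split axis to the front reorders strides without copying, so the buffer is known to be non-contiguous before any local check runs. A small stand-alone sketch of the same permutation (shape and split axis are hypothetical):

```python
import torch

recvbuf = torch.zeros(4, 5, 6)
axis = 2  # hypothetical split axis
recv_axis_permutation = list(range(recvbuf.ndimension()))
recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0
recvbuf = recvbuf.permute(*recv_axis_permutation)

print(recvbuf.shape)            # torch.Size([6, 5, 4])
print(recvbuf.is_contiguous())  # False, matching rbuf_is_contiguous = False
```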

@@ -1083,20 +1099,18 @@
if sendbuf is MPI.IN_PLACE or not isinstance(sendbuf, torch.Tensor):
mpi_sendbuf = sbuf
else:
-            mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs)
+            mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs, sbuf_is_contiguous)
if send_counts is not None:
mpi_sendbuf[1] = mpi_sendbuf[1][0][self.rank]

if recvbuf is MPI.IN_PLACE or not isinstance(recvbuf, torch.Tensor):
mpi_recvbuf = rbuf
else:
-            mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs)
+            mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs, rbuf_is_contiguous)
if recv_counts is None:
mpi_recvbuf[1] //= self.size

# perform the scatter operation
exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs)

return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation

def Allgather(
@@ -1269,7 +1283,7 @@ def __alltoall_like(
# keep a reference to the original buffer object
original_recvbuf = recvbuf

-        # Simple case, continuous buffers can be transmitted as is
+        # Simple case, contiguous buffers can be transmitted as is
if send_axis < 2 and recv_axis < 2:
send_axis_permutation = list(range(recvbuf.ndimension()))
recv_axis_permutation = list(range(recvbuf.ndimension()))
18 changes: 17 additions & 1 deletion heat/core/devices.py
@@ -6,7 +6,7 @@

import torch

-from typing import Optional, Union
+from typing import Any, Optional, Union

from . import communication

@@ -74,6 +74,22 @@ def __str__(self) -> str:
"""
return "{}:{}".format(self.device_type, self.device_id)

+    def __eq__(self, other: Any) -> bool:
+        """
+        Overloads the `==` operator for local equality checks.
+
+        Parameters
+        ----------
+        other : Any
+            The object to compare with
+        """
+        if isinstance(other, Device):
+            return self.device_type == other.device_type and self.device_id == other.device_id
+        elif isinstance(other, torch.device):
+            return self.device_type == other.type and self.device_id == other.index
+        else:
+            return NotImplemented


# create a CPU device singleton
cpu = Device("cpu", 0, "cpu")
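A short usage sketch of the new `__eq__` (assuming Heat imports as `ht`, with the `cpu` singleton defined above):

```python
import heat as ht
import torch

print(ht.cpu == ht.cpu)                  # True: device_type and device_id match
print(ht.cpu == torch.device("cpu", 0))  # True: compared against torch's type and index
print(ht.cpu == "cpu")                   # False: NotImplemented defers to Python's default
```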
3 changes: 3 additions & 0 deletions heat/core/dndarray.py
@@ -1268,8 +1268,11 @@ def resplit_(self, axis: int = None):
axis = sanitize_axis(self.shape, axis)

# early out for unchanged content
+        if self.comm.size == 1:
+            self.__split = axis
if axis == self.split:
return self

if axis is None:
gathered = torch.empty(
self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
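With this early out, a single-process `resplit_` only rewrites the split metadata instead of gathering and redistributing data. A usage sketch (assuming Heat imports as `ht`; shapes are arbitrary):

```python
import heat as ht

a = ht.ones((4, 4), split=0)
a.resplit_(1)   # on one process this now just sets the split attribute
print(a.split)  # 1
```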
30 changes: 18 additions & 12 deletions heat/core/linalg/basics.py
@@ -510,26 +510,32 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
if b.dtype != c_type:
b = c_type(b, device=b.device)

+    # early out for single-process setup, torch matmul
+    if a.comm.size == 1:
+        ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
+        if gpu_int_flag:
+            ret = og_type(ret, device=a.device)
+        return ret

if a.split is None and b.split is None: # matmul from torch
if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
# if either of A or B is a vector
ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device, comm=a.comm)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
-        else:
-            a.resplit_(0)
-            slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
-            hold = a.larray @ b.larray
-
-            c = factories.zeros(
-                (a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device, comm=a.comm
-            )
-            c.larray[slice_0.start : slice_0.stop, :] += hold
-            c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
-            if gpu_int_flag:
-                c = og_type(c, device=a.device)
-            return c
+        a.resplit_(0)
+        slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
+        hold = a.larray @ b.larray
+
+        c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device, comm=a.comm)
+        c.larray[slice_0.start : slice_0.stop, :] += hold
+        c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
+        if gpu_int_flag:
+            c = og_type(c, device=a.device)
+        return c


# if they are vectors they need to be expanded to be the proper dimensions
vector_flag = False # flag to run squeeze at the end of the function
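The new early out routes single-process calls directly to `torch.matmul`. A usage sketch (assuming Heat imports as `ht`; shapes are arbitrary):

```python
import heat as ht

a = ht.ones((4, 3))
b = ht.ones((3, 5))
c = ht.matmul(a, b)      # on one process: a plain local torch.matmul, no communication
print(c.shape, c.split)  # (4, 5) None
```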
26 changes: 23 additions & 3 deletions heat/core/linalg/tests/test_basics.py
@@ -237,14 +237,14 @@ def test_inv(self):
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

# distributed
a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=0)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
self.assertEqual(ainv.device, a.device)
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

+        ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=1)
a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=1)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
@@ -281,14 +281,15 @@ def test_inv(self):
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

# pivoting row change
-        ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double) / 3.0
+        ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=0) / 3.0
a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=0)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
self.assertEqual(ainv.device, a.device)
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

+        ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=1) / 3.0
a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=1)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
@@ -365,9 +366,28 @@ def test_matmul(self):
self.assertEqual(ret00.shape, (n, k))
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, None)
-        self.assertEqual(a.split, 0)
+        if a.comm.size > 1:
+            self.assertEqual(a.split, 0)
self.assertEqual(b.split, None)

+        # splits 0 None on 1 process
+        if a.comm.size == 1:
+            a = ht.ones((n, m), split=0)
+            b = ht.ones((j, k), split=None)
+            a[0] = ht.arange(1, m + 1)
+            a[:, -1] = ht.arange(1, n + 1)
+            b[0] = ht.arange(1, k + 1)
+            b[:, 0] = ht.arange(1, j + 1)
+            ret00 = ht.matmul(a, b, allow_resplit=True)
+
+            self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
+            self.assertIsInstance(ret00, ht.DNDarray)
+            self.assertEqual(ret00.shape, (n, k))
+            self.assertEqual(ret00.dtype, ht.float)
+            self.assertEqual(ret00.split, None)
+            self.assertEqual(a.split, 0)
+            self.assertEqual(b.split, None)

+        if a.comm.size > 1:
# splits 00
a = ht.ones((n, m), split=0, dtype=ht.float64)
(diffs for the remaining changed files are not rendered here)
