Signal processing: fully distributed 1D convolution (#983)
* first commit

* started distributed kernel support

* fixed communication between processes

* storing values from all calculated signals through communication

* pads weights when kernel is uneven

* flipped input kernel dndarray

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* calculating correct convolution across all processes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added example and minor changes

* minor change

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixing pre-commit hooks

* swap a and v when v is larger

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* used bcast

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added tests for distributed kernels

* Accumulate filtered_signal in 1D within first loop

* Fix split axis when signal is distributed

* avoid empty local chunk condition

* pre-commit auto fixes

* added test for large random signal and removed earlier implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added support and tests for all modes

* added example and refactored

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added support for scalars and corrected halo_size

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reformatted

* corrected halo_size

* cast t_v to float on cuda

* error message on unbalanced weights

* error message on unbalanced weight

* resolved device issue

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Claudia Comito <c.comito@fz-juelich.de>
Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
4 people authored Jan 24, 2023
1 parent 799abef commit 62410e7
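
In short: ht.convolve now accepts a memory-distributed filter v in addition to a distributed signal a. A brief usage sketch, assuming a 3-process run; the calls and printed results are copied from the docstring examples added in heat/core/signal.py below:

>>> import heat as ht
>>> a = ht.ones(10, split=0)
>>> v = ht.arange(3, split=0).astype(ht.float)
>>> ht.convolve(a, v, mode='valid')
DNDarray([3., 3., 3., 3., 3., 3., 3., 3.])
>>> ht.convolve(a, v)
DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0)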
Showing 2 changed files with 158 additions and 63 deletions.
151 changes: 104 additions & 47 deletions heat/core/signal.py
@@ -1,28 +1,28 @@
"""Provides a collection of signal-processing operations"""

import torch
from typing import Union, Tuple, Sequence
import numpy as np

from .communication import MPI
from .dndarray import DNDarray
from .types import promote_types
from .manipulations import pad
from .factories import array
from .manipulations import pad, flip
from .factories import array, zeros
import torch.nn.functional as fc

__all__ = ["convolve"]


def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
"""
Returns the discrete, linear convolution of two one-dimensional `DNDarray`s.
Returns the discrete, linear convolution of two one-dimensional `DNDarray`s or scalars.
Parameters
----------
a : DNDarray
One-dimensional signal `DNDarray` of shape (N,)
v : DNDarray
One-dimensional filter weight `DNDarray` of shape (M,).
a : DNDarray or scalar
One-dimensional signal `DNDarray` of shape (N,) or scalar.
v : DNDarray or scalar
One-dimensional filter weight `DNDarray` of shape (M,) or scalar.
mode : str
Can be 'full', 'valid', or 'same'. Default is 'full'.
'full':
@@ -40,15 +40,6 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
overlap completely. Values outside the signal boundary have no
effect.
Notes
-----
Contrary to the original `numpy.convolve`, this function does not
swap the input arrays if the second one is larger than the first one.
This is because `a`, the signal, might be memory-distributed,
whereas the filter `v` is assumed to be non-distributed,
i.e. a copy of `v` will reside on each process.
Examples
--------
Note how the convolution operator flips the second array
@@ -62,7 +53,27 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
DNDarray([1., 3., 3., 3., 3.])
>>> ht.convolve(a, v, mode='valid')
DNDarray([3., 3., 3.])
>>> a = ht.ones(10, split = 0)
>>> v = ht.arange(3, split = 0).astype(ht.float)
>>> ht.convolve(a, v, mode='valid')
DNDarray([3., 3., 3., 3., 3., 3., 3., 3.])
[0/3] DNDarray([3., 3., 3.])
[1/3] DNDarray([3., 3., 3.])
[2/3] DNDarray([3., 3.])
>>> a = ht.ones(10, split = 0)
>>> v = ht.arange(3, split = 0)
>>> ht.convolve(a, v)
DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0)
[0/3] DNDarray([0., 1., 3., 3.])
[1/3] DNDarray([3., 3., 3., 3.])
[2/3] DNDarray([3., 3., 3., 2.])
"""
if np.isscalar(a):
a = array([a])
if np.isscalar(v):
v = array([v])
if not isinstance(a, DNDarray):
try:
a = array(a)
@@ -77,24 +88,25 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
a = a.astype(promoted_type)
v = v.astype(promoted_type)

if v.is_distributed():
raise TypeError("Distributed filter weights are not supported")
if len(a.shape) != 1 or len(v.shape) != 1:
raise ValueError("Only 1-dimensional input DNDarrays are allowed")
if a.shape[0] <= v.shape[0]:
raise ValueError("Filter size must not be greater than or equal to signal size")
if mode == "same" and v.shape[0] % 2 == 0:
raise ValueError("Mode 'same' cannot be used with even-sized kernel")
if not v.is_balanced():
raise ValueError("Only balanced kernel weights are allowed")

if v.shape[0] > a.shape[0]:
a, v = v, a

# compute halo size
halo_size = v.shape[0] // 2
halo_size = torch.max(v.lshape_map[:, 0]).item() // 2

# pad DNDarray with zeros according to mode
if mode == "full":
pad_size = v.shape[0] - 1
gshape = v.shape[0] + a.shape[0] - 1
elif mode == "same":
pad_size = halo_size
pad_size = v.shape[0] // 2
gshape = a.shape[0]
elif mode == "valid":
pad_size = 0
@@ -105,44 +117,89 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
a = pad(a, pad_size, "constant", 0)

if a.is_distributed():
if (v.shape[0] > a.lshape_map[:, 0]).any():
raise ValueError("Filter weight is larger than the local chunks of signal")
if (v.lshape_map[:, 0] > a.lshape_map[:, 0]).any():
raise ValueError(
"Local chunk of filter weight is larger than the local chunks of signal"
)
# fetch halos and store them in a.halo_next/a.halo_prev
a.get_halo(halo_size)
# apply halos to local array
signal = a.array_with_halos
else:
signal = a.larray

# flip filter for convolution as Pytorch conv1d computes correlations
v = flip(v, [0])
if v.larray.shape != v.lshape_map[0]:
# pads weights if input kernel is uneven
target = torch.zeros(v.lshape_map[0][0], dtype=v.larray.dtype, device=v.larray.device)
pad_size = v.lshape_map[0][0] - v.larray.shape[0]
target[pad_size:] = v.larray
weight = target
else:
weight = v.larray

t_v = weight # stores temporary weight

# make signal and filter weight 3D for Pytorch conv1d function
signal = signal.reshape(1, 1, signal.shape[0])

# flip filter for convolution as Pytorch conv1d computes correlations
weight = v.larray.flip(dims=(0,))
weight = weight.reshape(1, 1, weight.shape[0])

# cast to float if on GPU
if signal.is_cuda:
float_type = promote_types(signal.dtype, torch.float32).torch_type()
signal = signal.to(float_type)
weight = weight.to(float_type)
t_v = t_v.to(float_type)

# apply torch convolution operator
signal_filtered = fc.conv1d(signal, weight)

# unpack 3D result into 1D
signal_filtered = signal_filtered[0, 0, :]

# if kernel shape along split axis is even we need to get rid of duplicated values
if a.comm.rank != 0 and v.shape[0] % 2 == 0:
signal_filtered = signal_filtered[1:]

return DNDarray(
signal_filtered.contiguous(),
(gshape,),
signal_filtered.dtype,
a.split,
a.device,
a.comm,
balanced=False,
).astype(a.dtype.torch_type())
if v.is_distributed():
size = v.comm.size

for r in range(size):
rec_v = v.comm.bcast(t_v, root=r)
t_v1 = rec_v.reshape(1, 1, rec_v.shape[0])
local_signal_filtered = fc.conv1d(signal, t_v1)
# unpack 3D result into 1D
local_signal_filtered = local_signal_filtered[0, 0, :]

if a.comm.rank != 0 and v.lshape_map[0][0] % 2 == 0:
local_signal_filtered = local_signal_filtered[1:]

# accumulate filtered signal on the fly
global_signal_filtered = array(
local_signal_filtered, is_split=0, device=a.device, comm=a.comm
)
if r == 0:
# initialize signal_filtered, starting point of slice
signal_filtered = zeros(
gshape, dtype=a.dtype, split=a.split, device=a.device, comm=a.comm
)
start_idx = 0

# accumulate relevant slice of filtered signal
# note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op()
signal_filtered += global_signal_filtered[start_idx : start_idx + gshape]
if r != size - 1:
start_idx += v.lshape_map[r + 1][0].item()
return signal_filtered

else:
# apply torch convolution operator
signal_filtered = fc.conv1d(signal, weight)

# unpack 3D result into 1D
signal_filtered = signal_filtered[0, 0, :]

# if kernel shape along split axis is even we need to get rid of duplicated values
if a.comm.rank != 0 and v.shape[0] % 2 == 0:
signal_filtered = signal_filtered[1:]

return DNDarray(
signal_filtered.contiguous(),
(gshape,),
signal_filtered.dtype,
a.split,
a.device,
a.comm,
balanced=False,
).astype(a.dtype.torch_type())
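
For context on the v.is_distributed() branch above: it relies on convolution being linear in the filter, so the full result is the sum of convolutions with each broadcast kernel chunk, each shifted by that chunk's offset within the kernel (the start_idx bookkeeping). A minimal single-process NumPy sketch of that decomposition, with assumed array sizes; the halo exchange and the distributed accumulation via __binary_op are not modeled here:

import numpy as np

a = np.arange(16, dtype=float)        # signal
v = np.array([1.0, 2.0, 3.0, 4.0])    # filter weights, imagined as split over 3 processes
chunks = np.array_split(v, 3)         # stand-ins for the per-process kernel pieces

out = np.zeros(len(a) + len(v) - 1)   # 'full' mode output has length N + M - 1
offset = 0                            # start index of the current chunk inside v
for chunk in chunks:
    partial = np.convolve(a, chunk, mode="full")
    out[offset : offset + len(partial)] += partial  # place the partial result at the chunk's offset
    offset += len(chunk)

assert np.allclose(out, np.convolve(a, v, mode="full"))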
70 changes: 54 additions & 16 deletions heat/core/tests/test_signal.py
@@ -20,59 +20,94 @@ def test_convolve(self):
[0, 1, 3, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 42, 29, 15]
).astype(ht.int)

signal = ht.arange(0, 16, split=0).astype(ht.int)
dis_signal = ht.arange(0, 16, split=0).astype(ht.int)
signal = ht.arange(0, 16).astype(ht.int)
full_ones = ht.ones(7, split=0).astype(ht.int)
kernel_odd = ht.ones(3).astype(ht.int)
kernel_even = [1, 1, 1, 1]
dis_kernel_odd = ht.ones(3, split=0).astype(ht.int)
dis_kernel_even = ht.ones(4, split=0).astype(ht.int)

with self.assertRaises(TypeError):
signal_wrong_type = [0, 1, 2, "tre", 4, "five", 6, "ʻehiku", 8, 9, 10]
ht.convolve(signal_wrong_type, kernel_odd, mode="full")
with self.assertRaises(TypeError):
filter_wrong_type = [1, 1, "pizza", "pineapple"]
ht.convolve(signal, filter_wrong_type, mode="full")
ht.convolve(dis_signal, filter_wrong_type, mode="full")
with self.assertRaises(ValueError):
ht.convolve(signal, kernel_odd, mode="invalid")
ht.convolve(dis_signal, kernel_odd, mode="invalid")
with self.assertRaises(ValueError):
s = signal.reshape((2, -1))
s = dis_signal.reshape((2, -1))
ht.convolve(s, kernel_odd)
with self.assertRaises(ValueError):
k = ht.eye(3)
ht.convolve(signal, k)
with self.assertRaises(ValueError):
ht.convolve(kernel_even, full_even)
ht.convolve(dis_signal, k)
with self.assertRaises(ValueError):
ht.convolve(signal, kernel_even, mode="same")
ht.convolve(dis_signal, kernel_even, mode="same")
if self.comm.size > 1:
with self.assertRaises(TypeError):
k = ht.ones(4, split=0).astype(ht.int)
ht.convolve(signal, k)
if self.comm.size >= 5:
with self.assertRaises(ValueError):
ht.convolve(signal, kernel_even, mode="valid")
ht.convolve(full_ones, kernel_even, mode="valid")
with self.assertRaises(ValueError):
ht.convolve(kernel_even, full_ones, mode="valid")
if self.comm.size > 5:
with self.assertRaises(ValueError):
ht.convolve(dis_signal, kernel_even)

# test modes, avoid kernel larger than signal chunk
if self.comm.size <= 3:
modes = ["full", "same", "valid"]
for i, mode in enumerate(modes):
# odd kernel size
conv = ht.convolve(signal, kernel_odd, mode=mode)
conv = ht.convolve(dis_signal, kernel_odd, mode=mode)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i], gathered))

conv = ht.convolve(dis_signal, dis_kernel_odd, mode=mode)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i], gathered))

conv = ht.convolve(signal, dis_kernel_odd, mode=mode)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i], gathered))

# different data types
conv = ht.convolve(signal.astype(ht.float), kernel_odd)
conv = ht.convolve(dis_signal.astype(ht.float), kernel_odd)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd.astype(ht.float), gathered))

conv = ht.convolve(dis_signal.astype(ht.float), dis_kernel_odd)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd.astype(ht.float), gathered))

conv = ht.convolve(signal.astype(ht.float), dis_kernel_odd)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd.astype(ht.float), gathered))

# even kernel size
# skip mode 'same' for even kernels
if mode != "same":
conv = ht.convolve(signal, kernel_even, mode=mode)
conv = ht.convolve(dis_signal, kernel_even, mode=mode)
dis_conv = ht.convolve(dis_signal, dis_kernel_even, mode=mode)
gathered = manipulations.resplit(conv, axis=None)
dis_gathered = manipulations.resplit(dis_conv, axis=None)

if mode == "full":
self.assertTrue(ht.equal(full_even, gathered))
self.assertTrue(ht.equal(full_even, dis_gathered))
else:
self.assertTrue(ht.equal(full_even[3:-3], gathered))
self.assertTrue(ht.equal(full_even[3:-3], dis_gathered))

# distributed large signal and kernel
np.random.seed(12)
np_a = np.random.randint(1000, size=4418)
np_b = np.random.randint(1000, size=1543)
np_conv = np.convolve(np_a, np_b, mode=mode)

a = ht.array(np_a, split=0, dtype=ht.int32)
b = ht.array(np_b, split=0, dtype=ht.int32)
conv = ht.convolve(a, b, mode=mode)
self.assert_array_equal(conv, np_conv)

# test edge cases
# non-distributed signal, size-1 kernel
@@ -81,3 +116,6 @@
kernel = ht.ones(1).astype(ht.int)
conv = ht.convolve(alt_signal, kernel)
self.assertTrue(ht.equal(signal, conv))

conv = ht.convolve(1, 5)
self.assertTrue(ht.equal(ht.array([5]), conv))
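
Note that the assertions above branch on the number of MPI processes: the mode loop runs only on up to 3 processes, while some error checks require at least 5 or more than 5. Running the suite with several process counts is therefore assumed, e.g. (assuming pytest and an MPI launcher are available) mpirun -np 5 python -m pytest heat/core/tests/test_signal.py.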
