Support batch 1-d convolution in ht.signal.convolve #1515

Merged — 18 commits, Jul 4, 2024
Changes from 11 commits
135 changes: 118 additions & 17 deletions heat/core/signal.py
@@ -16,13 +16,16 @@
def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
"""
Returns the discrete, linear convolution of two one-dimensional ``DNDarray``s or scalars.
If the input ``DNDarray``s have more than one dimension, batch-convolution along the last dimension will be attempted. See below for details.

Parameters
----------
a : DNDarray or scalar
One-dimensional signal `DNDarray` of shape (N,), or scalar. If ``a`` is more than 1D, it will be treated as a batch of 1D signals.
Distribution along the batch dimension is required for distributed batch processing. See examples for details.
v : DNDarray or scalar
One-dimensional filter weight `DNDarray` of shape (M,), or scalar. If ``v`` is more than 1D, it will be treated as a batch of 1D filter weights.
The batch dimension(s) of ``v`` must match the batch dimension(s) of ``a``.
mode : str
Can be 'full', 'valid', or 'same'. Default is 'full'.
'full':
@@ -69,6 +72,34 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
[0/3] DNDarray([0., 1., 3., 3.])
[1/3] DNDarray([3., 3., 3., 3.])
[2/3] DNDarray([3., 3., 3., 2.])

>>> a = ht.arange(50, dtype=ht.float64, split=0)
>>> a = a.reshape(10, 5) # 10 signals of length 5
>>> v = ht.arange(3)
>>> ht.convolve(a, v) # batch processing: 10 signals convolved with filter v
DNDarray([[ 0., 0., 1., 4., 7., 10., 8.],
[ 0., 5., 16., 19., 22., 25., 18.],
[ 0., 10., 31., 34., 37., 40., 28.],
[ 0., 15., 46., 49., 52., 55., 38.],
[ 0., 20., 61., 64., 67., 70., 48.],
[ 0., 25., 76., 79., 82., 85., 58.],
[ 0., 30., 91., 94., 97., 100., 68.],
[ 0., 35., 106., 109., 112., 115., 78.],
[ 0., 40., 121., 124., 127., 130., 88.],
[ 0., 45., 136., 139., 142., 145., 98.]], dtype=ht.float64, device=cpu:0, split=0)

>>> v = ht.random.randint(0, 3, (10, 3), split=0) # 10 filters of length 3
>>> ht.convolve(a, v) # batch processing: 10 signals convolved with 10 filters
DNDarray([[ 0., 0., 2., 4., 6., 8., 0.],
[ 5., 6., 7., 8., 9., 0., 0.],
[ 20., 42., 56., 61., 66., 41., 14.],
[ 0., 15., 16., 17., 18., 19., 0.],
[ 20., 61., 64., 67., 70., 48., 0.],
[ 50., 52., 104., 108., 112., 56., 58.],
[ 0., 30., 61., 63., 65., 67., 34.],
[ 35., 106., 109., 112., 115., 78., 0.],
[ 0., 40., 81., 83., 85., 87., 44.],
[ 0., 0., 45., 46., 47., 48., 49.]], dtype=ht.float64, device=cpu:0, split=0)
"""
if np.isscalar(a):
a = array([a])
@@ -88,34 +119,104 @@
a = a.astype(promoted_type)
v = v.astype(promoted_type)

# check if the filter is longer than the signal and swap them if necessary
if v.shape[-1] > a.shape[-1]:
a, v = v, a

batch_processing = False
if a.ndim > 1:
# batch processing requires 1D filter OR matching batch dimensions for signal and filter
batch_dims = a.shape[:-1]
# verify that the filter shape is consistent with the signal
if v.ndim > 1:
if v.shape[:-1] != batch_dims:
raise ValueError(
f"Batch dimensions of signal and filter must match. Signal: {a.shape}, Filter: {v.shape}"
)
if a.is_distributed():
if a.split == a.ndim - 1:
raise ValueError(
"Please distribute the signal along the batch dimension, not the signal dimension. For in-place redistribution use `a.resplit_(axis=0)`"
)
if v.is_distributed():
if v.ndim == 1:
# gather filter to all ranks
v.resplit_(axis=None)
else:
v.resplit_(axis=a.split)
batch_processing = True
Review discussion on this line:

Member:
Why is batch_processing = True hard coded?

Contributor (author):
Hi @krajsek, do you mean it would be better to let the user set it as a keyword argument? If so, I agree.

Collaborator:
As far as I understand things, batch_processing is set to True here and, depending on some conditions, it is set to False later on; in my opinion this is fine, and I would not suggest introducing a kwarg, since the correct value of batch_processing is already uniquely determined by the other inputs.

Member:
OK, I agree; I do not advocate for a kwarg either. Hard-coded assignments like this always irritate me a little, but after reading the code again it makes sense. Maybe we should spend a comment explaining why it is hard-coded here.
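
In other words, the correct value of batch_processing is fully determined by the dimensionality of the inputs, as the thread concludes. A minimal sketch of that decision logic (the helper _is_batched is illustrative, not part of this diff):

def _is_batched(a_ndim: int, v_ndim: int) -> bool:
    # a >1-D signal always triggers batch mode (after the split/shape checks above)
    if a_ndim > 1:
        return True
    # a >1-D filter with a 1-D signal is rejected, mirroring the ValueError below
    if v_ndim > 1:
        raise ValueError("1-D convolution only supported for 1-dimensional signal and kernel")
    return False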

if not batch_processing and v.ndim > 1:
raise ValueError(
f"1-D convolution only supported for 1-dimensional signal and kernel. Signal: {a.shape}, Filter: {v.shape}"
)

if mode == "same" and v.shape[-1] % 2 == 0:
raise ValueError("Mode 'same' cannot be used with even-sized kernel")
if not v.is_balanced():
raise ValueError("Only balanced kernel weights are allowed")

# calculate pad size according to mode
if mode == "full":
pad_size = v.shape[-1] - 1
gshape = v.shape[-1] + a.shape[-1] - 1
elif mode == "same":
pad_size = v.shape[-1] // 2
gshape = a.shape[-1]
elif mode == "valid":
pad_size = 0
gshape = a.shape[-1] - v.shape[-1] + 1
else:
raise ValueError(f"Supported modes are 'full', 'valid', 'same', got {mode}")

if batch_processing:
# all operations are local torch operations, only the last dimension is convolved
local_a = a.larray
local_v = v.larray
# flip the filter, since PyTorch's conv1d computes correlation, not convolution
local_v = torch.flip(local_v, [-1])
local_batch_dims = tuple(local_a.shape[:-1])

# reshape signal and filter to 3D for PyTorch's conv1d function
# see https://pytorch.org/docs/stable/generated/torch.nn.functional.conv1d.html
local_a = local_a.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_a.device), dim=0).item(),
local_a.shape[-1],
)
channels = local_a.shape[0]
if v.ndim > 1:
local_v = local_v.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_v.device), dim=0).item(),
local_v.shape[-1],
)
local_v = local_v.unsqueeze(1)
else:
local_v = local_v.unsqueeze(0).unsqueeze(0).expand(local_a.shape[0], 1, -1)
# add batch dimension to signal
local_a = local_a.unsqueeze(0)

# promote to at least single-precision float if on GPU
if local_a.is_cuda:
float_type = torch.promote_types(local_a.dtype, torch.float32)
local_a = local_a.to(float_type)
local_v = local_v.to(float_type)

# apply torch convolution operator
local_convolved = fc.conv1d(local_a, local_v, padding=pad_size, groups=channels)

# unpack 3D result into original shape
local_convolved = local_convolved.squeeze(0)
local_convolved = local_convolved.reshape(local_batch_dims + (-1,))

# wrap result in DNDarray
convolved = array(local_convolved, is_split=a.split, device=a.device, comm=a.comm)
return convolved
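
The grouped conv1d call is what makes per-signal filters work: with groups equal to the number of channels, each channel is correlated with its own (flipped) filter only. A standalone PyTorch sketch of the same trick, checked row-wise against np.convolve (shapes and names are illustrative):

import numpy as np
import torch
import torch.nn.functional as fc

signals = torch.randn(10, 5)                     # 10 signals of length 5
filters = torch.randn(10, 3)                     # one length-3 filter per signal

weight = torch.flip(filters, [-1]).unsqueeze(1)  # (10, 1, 3); flipped because conv1d correlates
inp = signals.unsqueeze(0)                       # (1, 10, 5): batch=1, channels=10
# padding = M - 1 = 2 yields the 'full' convolution of length N + M - 1 = 7
out = fc.conv1d(inp, weight, padding=2, groups=10).squeeze(0)

for i in range(10):
    ref = np.convolve(signals[i].numpy(), filters[i].numpy(), mode="full")
    assert np.allclose(out[i].numpy(), ref, atol=1e-5)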

# pad signal with zeros
a = pad(a, pad_size, "constant", 0)

# compute halo size: half the largest local kernel chunk, i.e. how many boundary
# elements each rank needs from its neighbours to convolve its local chunk exactly
halo_size = torch.max(v.lshape_map[:, -1]).item() // 2

if a.is_distributed():
if (v.lshape_map[:, 0] > a.lshape_map[:, 0]).any():
raise ValueError(
33 changes: 30 additions & 3 deletions heat/core/tests/test_signal.py
@@ -36,9 +36,10 @@ def test_convolve(self):
ht.convolve(dis_signal, filter_wrong_type, mode="full")
with self.assertRaises(ValueError):
ht.convolve(dis_signal, kernel_odd, mode="invalid")
if dis_signal.comm.size > 1:
with self.assertRaises(ValueError):
s = dis_signal.reshape((2, -1)).resplit(axis=1)
ht.convolve(s, kernel_odd)
with self.assertRaises(ValueError):
k = ht.eye(3)
ht.convolve(dis_signal, k)
@@ -119,3 +120,29 @@ def test_convolve(self):

conv = ht.convolve(1, 5)
self.assertTrue(ht.equal(ht.array([5]), conv))

# test batched convolutions, distributed along the first axis
signal = ht.random.randn(1000, dtype=ht.float64)
batch_signal = ht.empty((10, 1000), dtype=ht.float64, split=0)
batch_signal.larray[:] = signal.larray
kernel = ht.random.randn(19, dtype=ht.float64)
batch_convolved = ht.convolve(batch_signal, kernel, mode="same")
self.assertTrue(ht.equal(ht.convolve(signal, kernel, mode="same"), batch_convolved[0]))

# distributed kernel
dis_kernel = ht.array(kernel, split=0)
batch_convolved = ht.convolve(batch_signal, dis_kernel)
self.assertTrue(ht.equal(ht.convolve(signal, kernel), batch_convolved[0]))
batch_kernel = ht.empty((10, 19), dtype=ht.float64, split=1)
batch_kernel.larray[:] = dis_kernel.larray
batch_convolved = ht.convolve(batch_signal, batch_kernel, mode="full")
self.assertTrue(ht.equal(ht.convolve(signal, kernel, mode="full"), batch_convolved[0]))

# test batch-convolve exceptions
batch_kernel_wrong_shape = ht.random.randn(3, 19, dtype=ht.float64)
with self.assertRaises(ValueError):
ht.convolve(batch_signal, batch_kernel_wrong_shape)
if kernel.comm.size > 1:
batch_signal_wrong_split = batch_signal.resplit(1)
with self.assertRaises(ValueError):
ht.convolve(batch_signal_wrong_split, kernel)
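
Taken together, the tests exercise the feature the way a user would call it. A minimal distributed usage sketch (run with e.g. mpirun -n 2 python example.py; names are illustrative):

import heat as ht

# a batch of 8 signals of length 100, distributed along the batch dimension
a = ht.random.randn(8, 100, split=0)

v = ht.random.randn(5)                  # one shared filter (odd length, so 'same' is allowed)
same = ht.convolve(a, v, mode="same")   # shape (8, 100), still split=0

vs = ht.random.randn(8, 5, split=0)     # or one filter per signal
full = ht.convolve(a, vs, mode="full")  # shape (8, 104) == (8, 100 + 5 - 1)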