Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
rossant committed Mar 8, 2020
1 parent 6edf536 commit 85873b5
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions pykilosort/cptools.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,27 +138,27 @@ def pad(fcn_convolve):
@wraps(fcn_convolve)
def function_wrapper(x, b, axis=0, **kwargs):
# add the padding to the array
assert axis == 0 # to simplify things for now
xsize = x.shape[axis]
if 'pad' in kwargs and kwargs['pad']:
npad = b.shape[axis] // 2
padd = cp.take(x, cp.arange(npad), axis=axis) * 0
padd = cp.zeros_like(x[:npad])
if kwargs['pad'] == 'zeros':
x = cp.concatenate((padd, x, padd), axis=axis)
if kwargs['pad'] == 'constant':
x = cp.concatenate((padd * 0 + cp.mean(x[:npad]), x, padd + cp.mean(x[-npad:])),
axis=axis)
x = cp.concatenate(
(padd + cp.mean(x[:npad]), x, padd + cp.mean(x[-npad:])), axis=axis)
if kwargs['pad'] == 'flip':
pad_in = cp.flip(cp.take(x, cp.arange(1, npad + 1), axis=axis), axis=axis)
pad_out = cp.flip(cp.take(x, cp.arange(xsize - npad - 1, xsize - 1),
axis=axis), axis=axis)
pad_in = cp.flip(x[1:npad + 1], axis=axis)
pad_out = cp.flip(x[xsize - npad - 1:xsize - 1], axis=axis)
x = cp.concatenate((pad_in, x, pad_out), axis=axis)
# run the convolution
y = fcn_convolve(x, b, **kwargs)
# remove padding from both arrays (necessary for x ?)
if 'pad' in kwargs and kwargs['pad']:
# remove the padding
y = cp.take(y, cp.arange(npad, x.shape[axis] - npad), axis=axis)
x = cp.take(x, cp.arange(npad, x.shape[axis] - npad), axis=axis)
y = y[npad:x.shape[axis] - npad]
x = x[npad:x.shape[axis] - npad]
assert xsize == x.shape[axis]
assert xsize == y.shape[axis]
return y
Expand Down

0 comments on commit 85873b5

Please sign in to comment.