Skip to content

Commit

Permalink
Merge pull request #11 from icecube/1d_layers
Browse files Browse the repository at this point in the history
adding 1d layers
  • Loading branch information
mhuen authored Apr 10, 2024
2 parents c1994d1 + e650c2a commit bcf3304
Show file tree
Hide file tree
Showing 3 changed files with 325 additions and 11 deletions.
205 changes: 204 additions & 1 deletion tfscripts/conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Conv functions for tfscripts:
convolution helper functions,
locally connected 2d and 3d convolutions [tf.Modules],
locally connected 1d, 2d, and 3d convolutions [tf.Modules],
dynamic 2d and 3d convolution,
local trafo 2d and 3d,
wrapper: trafo on patch 2d and 3d
Expand Down Expand Up @@ -225,6 +225,209 @@ def get_conv_slice(position, input_length, filter_size, stride, dilation=1):
return conv_slice, (padding_left, padding_right)


class LocallyConnected1d(tf.Module):
"""Like conv1d, but doesn't share weights.
"""

def __init__(self,
input_shape,
num_outputs,
filter_size,
kernel=None,
strides= [1],
padding='SAME',
dilation_rate=None,
float_precision=FLOAT_PRECISION,
name=None):
"""Initialize object
Parameters
----------
input_shape : TensorShape, or list of int
The shape of the inputs.
num_outputs : int
Number of output channels
filter_size : list of int of size 1
[filter x size]
kernel : tf.Tensor, optional
Optionally, the weights to be used as the kernel can be provided.
If a kernel is provided, a list of variables 'var_list' must also
be provided.
If None, new kernel weights are created.
strides : list of int
A list of ints that has length = 1. 1-D tensor of length 1.
The stride of the sliding window for each dimension of input.
padding : str
A string from: "SAME", "VALID".
The type of padding algorithm to use.
dilation_rate : None or list of int of length 1
[dilattion in x]
defines dilattion rate to be used
float_precision : tf.dtype, optional
The tensorflow dtype describing the float precision to use.
name : None, optional
The name of the tensorflow module.
Deleted Parameters
------------------
input_data : tf.Tensor
Input data.
"""
super(LocallyConnected1d, self).__init__(name=name)

if dilation_rate is None:
dilation_rate = [1]

# ------------------
# get shapes
# ------------------
if isinstance(input_shape, tf.TensorShape):
input_shape = input_shape.as_list()

# sanity checks
msg = 'Filter size must be of shape [x], but is {!r}'
assert len(filter_size) == 1, msg.format(filter_size)

msg = 'Filter sizes must be greater than 0, but are: {!r}'
assert np.prod(filter_size) > 0, msg.format(filter_size)

msg = 'Shape is expected to be of length 3, but is {!r}'
assert len(input_shape) == 3, msg.format(input_shape)

# calculate output shape
output_shape = np.empty(3, dtype=int)
for i in range(1):
output_shape[i+1] = conv_output_length(
input_length=input_shape[i + 1],
filter_size=filter_size[i],
padding=padding,
stride=strides[i],
dilation=dilation_rate[i])
output_shape[0] = -1
output_shape[2] = num_outputs

num_inputs = input_shape[2]

kernel_shape = (np.prod(output_shape[1:-1]),
np.prod(filter_size) * num_inputs,
num_outputs)

# ------------------
# Create Kernel
# ------------------
# fast shortcut
if kernel is None:
if list(filter_size) == [1]:
kernel = new_locally_connected_weights(
shape=input_shape[1:] + [num_outputs],
shared_axes=[0],
float_precision=float_precision)

else:
kernel = new_locally_connected_weights(
shape=kernel_shape,
shared_axes=[0],
float_precision=float_precision)

self.output_shape = output_shape
self.num_outputs = num_outputs
self.num_inputs = num_inputs
self.filter_size = filter_size
self.strides = strides
self.padding = padding
self.dilation_rate = dilation_rate
self.float_precision = float_precision
self.kernel = kernel

def __call__(self, inputs):
"""Apply 1d Locally Connected Module.
Parameters
----------
inputs : tf.Tensor
Input tensor.
Returns
-------
tf.Tensor
The output tensor.
"""

input_shape = inputs.get_shape().as_list()

# ------------------
# 1x1 convolution
# ------------------
# fast shortcut
if list(self.filter_size) == [1]:
output = tf.reduce_sum(
input_tensor=tf.expand_dims(inputs, axis=3) * self.kernel, axis=2)
return output

# ------------------
# get slices
# ------------------
start_indices = [get_start_index(input_length=input_shape[i + 1],
filter_size=self.filter_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i])
for i in range(1)]

input_patches = []
# ---------------------------
# loop over all x positions
# ---------------------------
for x in range(start_indices[0], input_shape[1], self.strides[0]):

# get slice for patch along x-axis
slice_x, padding_x = get_conv_slice(
x,
input_length=input_shape[1],
filter_size=self.filter_size[0],
stride=self.strides[0],
dilation=self.dilation_rate[0])

if self.padding == 'VALID' and padding_x != (0):
# skip this x position, since it does not provide
# a valid patch for padding 'VALID'
continue

# ------------------------------------------
# Get input patch at filter position x
# ------------------------------------------
input_patch = inputs[:, slice_x, :]

if self.padding == 'SAME':
# pad with zeros
paddings = [(0, 0), padding_x, (0, 0)]
if paddings != [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]:
input_patch = tf.pad(tensor=input_patch,
paddings=paddings,
mode='CONSTANT',
)

# reshape
input_patch = tf.reshape(
input_patch,
[-1, 1, np.prod(self.filter_size)*self.num_inputs, 1])

# append to list
input_patches.append(input_patch)
# ------------------------------------------

# concat input patches
input_patches = tf.concat(input_patches, axis=1)

# ------------------
# perform convolution
# ------------------
output = input_patches * self.kernel
output = tf.reduce_sum(input_tensor=output, axis=2)
output = tf.reshape(output, self.output_shape)
return output


class LocallyConnected2d(tf.Module):
"""Like conv2d, but doesn't share weights.
"""
Expand Down
59 changes: 49 additions & 10 deletions tfscripts/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def flatten_hex_layer(hex_layer):
class ConvNdLayer(tf.Module):
"""TF Module for creating a new nD Convolutional Layer
2 <= n <=4 are supported.
1 <= n <=4 are supported.
For n == 3 (3 spatial dimensions x, y, and z):
input: (n+2)-dim tensor of shape [batch, x, y, z, num_input_channels]
output: (n+2)-dim tensor of shape [batch, x_p, y_p, z_p, num_filters]
Expand Down Expand Up @@ -321,9 +321,18 @@ def __init__(self,
pooling_ksize = [1, 2, 2, 1]
if strides is None:
strides = [1, 1, 1, 1]

elif num_dims == 1:
# 1D convolution
if pooling_strides is None:
pooling_strides = [1, 2, 1]
if pooling_ksize is None:
pooling_ksize = [1, 2, 1]
if strides is None:
strides = [1, 1, 1]

else:
msg = 'Currently only 2D, 3D or 4D supported {!r}'
msg = 'Currently only 1D, 2D, 3D or 4D supported {!r}'
raise ValueError(msg.format(input_shape))

# make sure inferred dimension matches filter_size
Expand All @@ -348,7 +357,7 @@ def __init__(self,
biases = new_biases(length=num_filters,
float_precision=float_precision)

if num_dims == 2 or num_dims == 3:
if num_dims == 1 or num_dims == 2 or num_dims == 3:
# create a temp function with all parameters set
def temp_func(inputs):
return tf.nn.convolution(input=inputs,
Expand All @@ -373,7 +382,10 @@ def temp_func(inputs):
# Hexagonal convolution
# ---------------------
elif method.lower() == 'hex_convolution':
if num_dims == 2 or num_dims == 3:
if num_dims == 1:
raise NotImplementedError(
'1D hex_convolution not implemented')
elif num_dims == 2 or num_dims == 3:
self.conv_layer = hx.ConvHex(
input_shape=input_shape,
filter_size=filter_size,
Expand Down Expand Up @@ -414,7 +426,17 @@ def temp_func(inputs):
# -------------------
elif method.lower() == 'locally_connected':

if num_dims == 2:
if num_dims == 1:
self.conv_layer = conv.LocallyConnected1d(
input_shape=input_shape,
num_outputs=num_filters,
filter_size=filter_size,
kernel=weights,
strides=strides[1:-1],
padding=padding,
dilation_rate=dilation_rate,
float_precision=float_precision)
elif num_dims == 2:
self.conv_layer = conv.LocallyConnected2d(
input_shape=input_shape,
num_outputs=num_filters,
Expand Down Expand Up @@ -479,7 +501,7 @@ def temp_func(inputs):

assert weights is not None

if num_dims == 2 or num_dims == 3:
if num_dims == 1 or num_dims == 2 or num_dims == 3:
# create a temp function with all parameters set
def temp_func(inputs):
return conv.dynamic_conv(inputs=inputs,
Expand Down Expand Up @@ -647,7 +669,14 @@ def _apply_pooling(self, layer):
layer : tf.Tensor
The layer on which to apply pooling
"""
if self.num_dims == 2:
if self.num_dims == 1:
layer = pooling.pool1d(layer=layer,
ksize=self.pooling_ksize,
strides=self.pooling_strides,
padding=self.pooling_padding,
pooling_type=self.pooling_type,
)
elif self.num_dims == 2:
layer = pooling.pool2d(layer=layer,
ksize=self.pooling_ksize,
strides=self.pooling_strides,
Expand Down Expand Up @@ -678,7 +707,7 @@ def _apply_pooling(self, layer):
raise NotImplementedError("Pooling type not supported: "
"{!r}".format(self.pooling_type))
else:
raise NotImplementedError('Only supported 2d, 3d, 4d!')
raise NotImplementedError('Only supported 1d, 2d, 3d, 4d!')

return layer

Expand Down Expand Up @@ -1258,7 +1287,7 @@ def __call__(self, inputs, is_training, keep_prob=None):


class ConvNdLayers(tf.Module):
"""TF Module for creating new conv2d, conv3d, and conv4d layers.
"""TF Module for creating new conv1d, conv2d, conv3d, and conv4d layers.
"""

def __init__(self,
Expand Down Expand Up @@ -1455,6 +1484,7 @@ def __init__(self,
pooling_ksize_list = [1, 2, 2, 2, 2, 1]
if strides_list is None:
strides_list = [1, 1, 1, 1, 1, 1]

elif num_dims == 5:
# 3D convolution
if pooling_strides_list is None:
Expand All @@ -1472,9 +1502,18 @@ def __init__(self,
pooling_ksize_list = [1, 2, 2, 1]
if strides_list is None:
strides_list = [1, 1, 1, 1]

elif num_dims == 3:
# 1D convolution
if pooling_strides_list is None:
pooling_strides_list = [1, 2, 1]
if pooling_ksize_list is None:
pooling_ksize_list = [1, 2, 1]
if strides_list is None:
strides_list = [1, 1, 1]

else:
msg = 'Currently only 2D, 3D, or 4D supported {!r}'
msg = 'Currently only 1D, 2D, 3D, or 4D supported {!r}'
raise ValueError(msg.format(input_shape))

num_layers = len(num_filters_list)
Expand Down
Loading

0 comments on commit bcf3304

Please sign in to comment.