Skip to content

Commit

Permalink
[ENH] Added possibility for pooling strides in TimeCNN (#2485)
Browse files Browse the repository at this point in the history
* Added strides_pooling in CNNNetwork

* Added strides_pooling in CNNClassifier

* Added strides_pooling in docstrings

* Automatic `pre-commit` fixes

* Line Limit

---------

Co-authored-by: kavya-r30 <kavya-r30@users.noreply.github.com>
  • Loading branch information
kavya-r30 and kavya-r30 authored Jan 15, 2025
1 parent 850e3cd commit 4d436d6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
6 changes: 6 additions & 0 deletions aeon/classification/deep_learning/_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class TimeCNNClassifier(BaseDeepClassifier):
strides : int or list of int, default = 1
The strides of kernels in the convolution and max pooling layers, if not a
list, the same strides are used for all layers.
strides_pooling : int or list of int, default = None
    Strides for the pooling layers. If None, defaults to ``avg_pool_size``
    (matching Keras pooling behavior, where ``strides=None`` falls back to
    the pool size). If not a list, the same strides are used for all
    pooling layers.
dilation_rate : int or list of int, default = 1
The dilation rate of the convolution layers, if not a list, the same dilation
rate is used all over the network.
Expand Down Expand Up @@ -125,6 +128,7 @@ def __init__(
activation="sigmoid",
padding="valid",
strides=1,
strides_pooling=None,
dilation_rate=1,
n_epochs=2000,
batch_size=16,
Expand All @@ -148,6 +152,7 @@ def __init__(
self.n_filters = n_filters
self.padding = padding
self.strides = strides
self.strides_pooling = strides_pooling
self.dilation_rate = dilation_rate
self.avg_pool_size = avg_pool_size
self.activation = activation
Expand Down Expand Up @@ -182,6 +187,7 @@ def __init__(
activation=self.activation,
padding=self.padding,
strides=self.strides,
strides_pooling=self.strides_pooling,
dilation_rate=self.dilation_rate,
use_bias=self.use_bias,
)
Expand Down
25 changes: 22 additions & 3 deletions aeon/networks/_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class TimeCNNNetwork(BaseDeepLearningNetwork):
strides : int or list of int, default = 1
The strides of kernels in the convolution and max pooling layers, if not a list,
the same strides are used for all layers.
strides_pooling : int or list of int, default = None
    Strides for the pooling layers. If None, defaults to ``avg_pool_size``
    (matching Keras pooling behavior, where ``strides=None`` falls back to
    the pool size). If not a list, the same strides are used for all
    pooling layers.
dilation_rate : int or list of int, default = 1
The dilation rate of the convolution layers, if not a list, the same dilation
rate is used all over the network.
Expand Down Expand Up @@ -65,6 +68,7 @@ def __init__(
activation="sigmoid",
padding="valid",
strides=1,
strides_pooling=None,
dilation_rate=1,
use_bias=True,
):
Expand All @@ -75,6 +79,7 @@ def __init__(
self.activation = activation
self.padding = padding
self.strides = strides
self.strides_pooling = strides_pooling
self.dilation_rate = dilation_rate
self.use_bias = use_bias

Expand Down Expand Up @@ -131,6 +136,19 @@ def build_network(self, input_shape, **kwargs):
else:
self._avg_pool_size = [self.avg_pool_size] * self.n_layers

if self.strides_pooling is None:
self._strides_pooling = self._avg_pool_size
elif isinstance(self.strides_pooling, list):
if len(self.strides_pooling) != self.n_layers:
raise ValueError(
f"Number of strides for pooling {len(self.strides_pooling)}"
f" should be the same as number of layers but is"
f" not: {self.n_layers}"
)
self._strides_pooling = self.strides_pooling
else:
self._strides_pooling = [self.strides_pooling] * self.n_layers

if isinstance(self.activation, list):
if len(self.activation) != self.n_layers:
raise ValueError(
Expand Down Expand Up @@ -204,9 +222,10 @@ def build_network(self, input_shape, **kwargs):
use_bias=self._use_bias[i],
)(x)

conv = tf.keras.layers.AveragePooling1D(pool_size=self._avg_pool_size[i])(
conv
)
conv = tf.keras.layers.AveragePooling1D(
pool_size=self._avg_pool_size[i],
strides=self._strides_pooling[i],
)(conv)

x = conv

Expand Down

0 comments on commit 4d436d6

Please sign in to comment.