diff --git a/convlstm/model.py b/convlstm/model.py index 9571b84..43714ad 100644 --- a/convlstm/model.py +++ b/convlstm/model.py @@ -14,7 +14,7 @@ class ConvLSTMParams(TypedDict): padding: Union[int, Tuple, str] activation: str frame_size: Tuple[int, int] - weights_initializer: NotRequired[str] + weights_initializer: NotRequired[WeightsInitializer] class ConvLSTM(nn.Module): diff --git a/convlstm/seq2seq.py b/convlstm/seq2seq.py index 256afb6..7c53f65 100644 --- a/convlstm/seq2seq.py +++ b/convlstm/seq2seq.py @@ -1,15 +1,15 @@ -from typing import NotRequired, Optional, Tuple, TypedDict, Union +from typing import NotRequired, TypedDict import torch from torch import nn from convlstm.model import ConvLSTM, ConvLSTMParams -from core.constants import WeightsInitializer class Seq2SeqParams(TypedDict): - num_layers: int input_seq_length: int + num_layers: int + num_kernels: int return_sequences: NotRequired[bool] convlstm_params: ConvLSTMParams @@ -19,41 +19,32 @@ class Seq2Seq(nn.Module): def __init__( self, - num_channels: int, - kernel_size: Union[int, Tuple], - num_kernels: int, - padding: Union[int, Tuple, str], - activation: str, - frame_size: Tuple, - num_layers: int, input_seq_length: int, - out_channels: Optional[int] = None, - weights_initializer: WeightsInitializer = WeightsInitializer.Zeros, + num_layers: int, + num_kernels: int, + convlstm_params: ConvLSTMParams, return_sequences: bool = False, ) -> None: """ Args: - num_channels (int): Number of input channels. - kernel_size (int): kernel size. + input_seq_length (int): Number of input frames. + num_layers (int): Number of ConvLSTM layers. num_kernels (int): Number of kernels. - padding (Union[str, Tuple]): 'same', 'valid' or (int, int) - activation (str): the name of activation function. - frame_size (Tuple): height and width. - num_layers (int): the number of layers. + return_sequences (int): If True, the model predict the next frames that is the same length of inputs. If False, the model predicts only one next frame. """ super().__init__() - self.num_channels = num_channels - self.kernel_size = kernel_size - self.num_kernels = num_kernels - self.padding = padding - self.activation = activation - self.frame_size = frame_size - self.num_layers = num_layers self.input_seq_length = input_seq_length - self.out_channels = out_channels if out_channels is not None else num_channels - self.weights_initializer = weights_initializer + self.num_layers = num_layers + self.num_kernels = num_kernels self.return_sequences = return_sequences + self.in_channels = convlstm_params["in_channels"] + self.kernel_size = convlstm_params["kernel_size"] + self.padding = convlstm_params["padding"] + self.activation = convlstm_params["activation"] + self.frame_size = convlstm_params["frame_size"] + self.out_channels = convlstm_params["out_channels"] + self.weights_initializer = convlstm_params["weights_initializer"] self.sequential = nn.Sequential() @@ -61,19 +52,19 @@ def __init__( self.sequential.add_module( "convlstm1", ConvLSTM( - in_channels=num_channels, - out_channels=num_kernels, - kernel_size=kernel_size, - padding=padding, - activation=activation, - frame_size=frame_size, - weights_initializer=weights_initializer, + in_channels=self.in_channels, + out_channels=self.num_kernels, + kernel_size=self.kernel_size, + padding=self.padding, + activation=self.activation, + frame_size=self.frame_size, + weights_initializer=self.weights_initializer, ), ) self.sequential.add_module( "layernorm1", - nn.LayerNorm([num_kernels, self.input_seq_length, *self.frame_size]), + nn.LayerNorm([self.num_kernels, self.input_seq_length, *self.frame_size]), ) # Add the rest of the layers @@ -81,19 +72,21 @@ def __init__( self.sequential.add_module( f"convlstm{layer_idx}", ConvLSTM( - in_channels=num_kernels, - out_channels=num_kernels, - kernel_size=kernel_size, - padding=padding, - activation=activation, - frame_size=frame_size, - weights_initializer=weights_initializer, + in_channels=self.num_kernels, + out_channels=self.num_kernels, + kernel_size=self.kernel_size, + padding=self.padding, + activation=self.activation, + frame_size=self.frame_size, + weights_initializer=self.weights_initializer, ), ) self.sequential.add_module( f"layernorm{layer_idx}", - nn.LayerNorm([num_kernels, self.input_seq_length, *self.frame_size]), + nn.LayerNorm( + [self.num_kernels, self.input_seq_length, *self.frame_size] + ), ) self.sequential.add_module( diff --git a/tests/convlstm/test_model.py b/tests/convlstm/test_model.py index eaee23e..279dcdd 100644 --- a/tests/convlstm/test_model.py +++ b/tests/convlstm/test_model.py @@ -1,7 +1,7 @@ import torch from convlstm.model import ConvLSTM, ConvLSTMParams -from core.constants import DEVICE +from core.constants import DEVICE, WeightsInitializer def test_ConvLSTM(): @@ -12,6 +12,7 @@ def test_ConvLSTM(): "padding": 1, "activation": "relu", "frame_size": (8, 8), + "weights_initializer": WeightsInitializer.He, } model = ConvLSTM(**model_params) output = model(torch.rand((2, 1, 3, 8, 8), dtype=torch.float, device=DEVICE)) diff --git a/tests/convlstm/test_seq2seq.py b/tests/convlstm/test_seq2seq.py index a651a58..e41f8aa 100644 --- a/tests/convlstm/test_seq2seq.py +++ b/tests/convlstm/test_seq2seq.py @@ -3,8 +3,8 @@ import pytest import torch -from convlstm.seq2seq import Seq2Seq -from core.constants import DEVICE +from convlstm.seq2seq import Seq2Seq, Seq2SeqParams +from core.constants import DEVICE, WeightsInitializer @pytest.mark.parametrize( @@ -12,20 +12,21 @@ [(True, (2, 1, 2, 8, 8)), (False, (2, 1, 1, 8, 8))], ) def test_seq2seq(return_sequences: bool, expected_output_size: Tuple): - model = ( - Seq2Seq( - num_channels=1, - kernel_size=3, - num_kernels=4, - padding="same", - activation="relu", - frame_size=(8, 8), - num_layers=2, - input_seq_length=2, - return_sequences=return_sequences, - ) - .to(DEVICE) - .to(torch.float) - ) + model_params: Seq2SeqParams = { + "input_seq_length": 2, + "num_layers": 2, + "num_kernels": 4, + "return_sequences": return_sequences, + "convlstm_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_size": 3, + "padding": 1, + "activation": "relu", + "frame_size": (8, 8), + "weights_initializer": WeightsInitializer.He, + }, + } + model = Seq2Seq(**model_params).to(DEVICE).to(torch.float) output = model(torch.rand((2, 1, 2, 8, 8), dtype=torch.float, device=DEVICE)) assert output.size() == expected_output_size