Commit
Use typed params for seq2seq
tsugumi-sys committed Jan 9, 2024
1 parent 1f7c457 commit 1841ca9
Showing 4 changed files with 57 additions and 62 deletions.
convlstm/model.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@ class ConvLSTMParams(TypedDict):
     padding: Union[int, Tuple, str]
     activation: str
     frame_size: Tuple[int, int]
-    weights_initializer: NotRequired[str]
+    weights_initializer: NotRequired[WeightsInitializer]
 
 
 class ConvLSTM(nn.Module):
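The tightened annotation assumes WeightsInitializer is an enum defined in core/constants.py; only the members WeightsInitializer.Zeros and WeightsInitializer.He are confirmed by usages elsewhere in this commit. A minimal sketch of what that definition might look like (the string values and the str-Enum base are assumptions):

    # Hypothetical sketch of core/constants.py. Only the Zeros and He
    # members are attested in this commit; values are illustrative.
    from enum import Enum

    class WeightsInitializer(str, Enum):
        Zeros = "zeros"
        He = "he"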
convlstm/seq2seq.py (79 changes: 36 additions & 43 deletions)

@@ -1,15 +1,15 @@
-from typing import NotRequired, Optional, Tuple, TypedDict, Union
+from typing import NotRequired, TypedDict
 
 import torch
 from torch import nn
 
 from convlstm.model import ConvLSTM, ConvLSTMParams
 from core.constants import WeightsInitializer
 
 
 class Seq2SeqParams(TypedDict):
-    num_layers: int
     input_seq_length: int
+    num_layers: int
     num_kernels: int
     return_sequences: NotRequired[bool]
     convlstm_params: ConvLSTMParams
@@ -19,81 +19,74 @@ class Seq2Seq(nn.Module):
 
     def __init__(
         self,
-        num_channels: int,
-        kernel_size: Union[int, Tuple],
-        num_kernels: int,
-        padding: Union[int, Tuple, str],
-        activation: str,
-        frame_size: Tuple,
-        num_layers: int,
         input_seq_length: int,
-        out_channels: Optional[int] = None,
-        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
+        num_layers: int,
+        num_kernels: int,
+        convlstm_params: ConvLSTMParams,
         return_sequences: bool = False,
     ) -> None:
         """
         Args:
-            num_channels (int): Number of input channels.
-            kernel_size (int): kernel size.
+            input_seq_length (int): Number of input frames.
+            num_layers (int): Number of ConvLSTM layers.
             num_kernels (int): Number of kernels.
-            padding (Union[str, Tuple]): 'same', 'valid' or (int, int)
-            activation (str): the name of activation function.
-            frame_size (Tuple): height and width.
-            num_layers (int): the number of layers.
             return_sequences (bool): If True, the model predicts next frames of the same length as the input sequence. If False, the model predicts only the single next frame.
         """
         super().__init__()
-        self.num_channels = num_channels
-        self.kernel_size = kernel_size
-        self.num_kernels = num_kernels
-        self.padding = padding
-        self.activation = activation
-        self.frame_size = frame_size
-        self.num_layers = num_layers
         self.input_seq_length = input_seq_length
-        self.out_channels = out_channels if out_channels is not None else num_channels
-        self.weights_initializer = weights_initializer
+        self.num_layers = num_layers
+        self.num_kernels = num_kernels
         self.return_sequences = return_sequences
+        self.in_channels = convlstm_params["in_channels"]
+        self.kernel_size = convlstm_params["kernel_size"]
+        self.padding = convlstm_params["padding"]
+        self.activation = convlstm_params["activation"]
+        self.frame_size = convlstm_params["frame_size"]
+        self.out_channels = convlstm_params["out_channels"]
+        self.weights_initializer = convlstm_params["weights_initializer"]
 
         self.sequential = nn.Sequential()
 
         # Add first layer (Different in_channels than the rest)
         self.sequential.add_module(
             "convlstm1",
             ConvLSTM(
-                in_channels=num_channels,
-                out_channels=num_kernels,
-                kernel_size=kernel_size,
-                padding=padding,
-                activation=activation,
-                frame_size=frame_size,
-                weights_initializer=weights_initializer,
+                in_channels=self.in_channels,
+                out_channels=self.num_kernels,
+                kernel_size=self.kernel_size,
+                padding=self.padding,
+                activation=self.activation,
+                frame_size=self.frame_size,
+                weights_initializer=self.weights_initializer,
             ),
         )
 
         self.sequential.add_module(
             "layernorm1",
-            nn.LayerNorm([num_kernels, self.input_seq_length, *self.frame_size]),
+            nn.LayerNorm([self.num_kernels, self.input_seq_length, *self.frame_size]),
         )
 
         # Add the rest of the layers
         for layer_idx in range(2, num_layers + 1):
             self.sequential.add_module(
                 f"convlstm{layer_idx}",
                 ConvLSTM(
-                    in_channels=num_kernels,
-                    out_channels=num_kernels,
-                    kernel_size=kernel_size,
-                    padding=padding,
-                    activation=activation,
-                    frame_size=frame_size,
-                    weights_initializer=weights_initializer,
+                    in_channels=self.num_kernels,
+                    out_channels=self.num_kernels,
+                    kernel_size=self.kernel_size,
+                    padding=self.padding,
+                    activation=self.activation,
+                    frame_size=self.frame_size,
+                    weights_initializer=self.weights_initializer,
                 ),
             )
 
             self.sequential.add_module(
                 f"layernorm{layer_idx}",
-                nn.LayerNorm([num_kernels, self.input_seq_length, *self.frame_size]),
+                nn.LayerNorm(
+                    [self.num_kernels, self.input_seq_length, *self.frame_size]
+                ),
             )
 
         self.sequential.add_module(
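One consequence of the typed params worth noting: weights_initializer is declared NotRequired in ConvLSTMParams, yet the new __init__ reads it with a direct subscript, which raises KeyError when a caller omits the key. A hedged sketch of a more tolerant variant; the fallback value mirrors the old WeightsInitializer.Zeros keyword default and is an assumption, not part of the commit:

    # As in this commit (inside Seq2Seq.__init__): raises KeyError if the
    # caller omits the NotRequired key.
    self.weights_initializer = convlstm_params["weights_initializer"]

    # Tolerant alternative (not in the commit): fall back to the previous
    # keyword default, WeightsInitializer.Zeros.
    self.weights_initializer = convlstm_params.get(
        "weights_initializer", WeightsInitializer.Zeros
    )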
tests/convlstm/test_model.py (3 changes: 2 additions & 1 deletion)

@@ -1,7 +1,7 @@
 import torch
 
 from convlstm.model import ConvLSTM, ConvLSTMParams
-from core.constants import DEVICE
+from core.constants import DEVICE, WeightsInitializer
 
 
 def test_ConvLSTM():
@@ -12,6 +12,7 @@ def test_ConvLSTM():
         "padding": 1,
         "activation": "relu",
         "frame_size": (8, 8),
+        "weights_initializer": WeightsInitializer.He,
     }
     model = ConvLSTM(**model_params)
     output = model(torch.rand((2, 1, 3, 8, 8), dtype=torch.float, device=DEVICE))
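The hunk above shows only the tail of the params dict; a complete ConvLSTMParams literal would look roughly like the sketch below. The in_channels, out_channels, and kernel_size values are inferred from the nested params in the seq2seq test that follows and are illustrative, not taken from the hidden lines of this file:

    from convlstm.model import ConvLSTM, ConvLSTMParams
    from core.constants import WeightsInitializer

    # Illustrative full params dict; head values are assumptions.
    model_params: ConvLSTMParams = {
        "in_channels": 1,
        "out_channels": 1,
        "kernel_size": 3,
        "padding": 1,
        "activation": "relu",
        "frame_size": (8, 8),
        "weights_initializer": WeightsInitializer.He,
    }
    model = ConvLSTM(**model_params)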
tests/convlstm/test_seq2seq.py (35 changes: 18 additions & 17 deletions)

@@ -3,29 +3,30 @@
 import pytest
 import torch
 
-from convlstm.seq2seq import Seq2Seq
-from core.constants import DEVICE
+from convlstm.seq2seq import Seq2Seq, Seq2SeqParams
+from core.constants import DEVICE, WeightsInitializer
 
 
 @pytest.mark.parametrize(
     "return_sequences, expected_output_size",
     [(True, (2, 1, 2, 8, 8)), (False, (2, 1, 1, 8, 8))],
 )
 def test_seq2seq(return_sequences: bool, expected_output_size: Tuple):
-    model = (
-        Seq2Seq(
-            num_channels=1,
-            kernel_size=3,
-            num_kernels=4,
-            padding="same",
-            activation="relu",
-            frame_size=(8, 8),
-            num_layers=2,
-            input_seq_length=2,
-            return_sequences=return_sequences,
-        )
-        .to(DEVICE)
-        .to(torch.float)
-    )
+    model_params: Seq2SeqParams = {
+        "input_seq_length": 2,
+        "num_layers": 2,
+        "num_kernels": 4,
+        "return_sequences": return_sequences,
+        "convlstm_params": {
+            "in_channels": 1,
+            "out_channels": 1,
+            "kernel_size": 3,
+            "padding": 1,
+            "activation": "relu",
+            "frame_size": (8, 8),
+            "weights_initializer": WeightsInitializer.He,
+        },
+    }
+    model = Seq2Seq(**model_params).to(DEVICE).to(torch.float)
     output = model(torch.rand((2, 1, 2, 8, 8), dtype=torch.float, device=DEVICE))
     assert output.size() == expected_output_size
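The parametrization above encodes the return_sequences contract from the Seq2Seq docstring: True preserves the input's time dimension (2 frames in, 2 frames out), False yields a single predicted frame. A minimal standalone usage sketch of the new dict-based construction; the values mirror the test, and device handling is elided:

    import torch

    from convlstm.seq2seq import Seq2Seq, Seq2SeqParams
    from core.constants import WeightsInitializer

    params: Seq2SeqParams = {
        "input_seq_length": 2,
        "num_layers": 2,
        "num_kernels": 4,
        "convlstm_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_size": 3,
            "padding": 1,
            "activation": "relu",
            "frame_size": (8, 8),
            "weights_initializer": WeightsInitializer.He,
        },
    }
    # return_sequences is NotRequired and defaults to False in __init__.
    model = Seq2Seq(**params).to(torch.float)
    x = torch.rand((2, 1, 2, 8, 8), dtype=torch.float)
    assert model(x).size() == (2, 1, 1, 8, 8)  # single next frame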
