Skip to content

Commit

Permalink
Model: Adding label_seq_length for ConvLSTM
Browse files Browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed May 23, 2024
1 parent 1fb3b34 commit fb1d511
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 2 deletions.
19 changes: 17 additions & 2 deletions convlstm/seq2seq.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from typing import NotRequired, TypedDict
import logging
from typing import NotRequired, Optional, TypedDict

import torch
from torch import nn

from convlstm.model import ConvLSTM, ConvLSTMParams

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Seq2SeqParams(TypedDict):
input_seq_length: int
label_seq_length: NotRequired[Optional[int]]
num_layers: int
num_kernels: int
return_sequences: NotRequired[bool]
Expand All @@ -23,18 +28,25 @@ def __init__(
num_layers: int,
num_kernels: int,
convlstm_params: ConvLSTMParams,
label_seq_length: Optional[int] = None,
return_sequences: bool = False,
) -> None:
"""
Args:
input_seq_length (int): Number of input frames.
label_seq_length (Optional[int]): Number of label frames.
num_layers (int): Number of ConvLSTM layers.
num_kernels (int): Number of kernels.
return_sequences (int): If True, the model predict the next frames that is the same length of inputs. If False, the model predicts only one next frame.
return_sequences (int): If True, the model predict the next frames that is the same length of inputs. If False, the model predicts only one next frame or the frames given by `label_seq_length`.
"""
super().__init__()
self.input_seq_length = input_seq_length
self.label_seq_length = label_seq_length
if label_seq_length is not None and return_sequences is True:
logger.warning(
"the `label_seq_length` is ignored because `return_sequences` is set to True."
)
self.num_layers = num_layers
self.num_kernels = num_kernels
self.return_sequences = return_sequences
Expand Down Expand Up @@ -108,4 +120,7 @@ def forward(self, X: torch.Tensor):
if self.return_sequences is True:
return output

if self.label_seq_length:
return output[:, :, : self.label_seq_length, ...]

return output[:, :, -1:, ...]
67 changes: 67 additions & 0 deletions tests/convlstm/test_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,70 @@ def test_seq2seq(return_sequences: bool, expected_output_size: Tuple):
model = Seq2Seq(**model_params).to(DEVICE).to(torch.float)
output = model(torch.rand((2, 1, 2, 8, 8), dtype=torch.float, device=DEVICE))
assert output.size() == expected_output_size


def test_seq2seq_label_seq_length():
# test if `label_seq_length` is less than the number of frames of the given datasaet.
label_seq_length = 2
model_params: Seq2SeqParams = {
"input_seq_length": 4,
"label_seq_length": label_seq_length,
"num_layers": 2,
"num_kernels": 4,
"convlstm_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_size": 3,
"padding": 1,
"activation": "relu",
"frame_size": (8, 8),
"weights_initializer": WeightsInitializer.He,
},
}
model = Seq2Seq(**model_params).to(DEVICE).to(torch.float)
output = model(torch.rand((2, 1, 4, 8, 8), dtype=torch.float, device=DEVICE))
assert output.size() == (2, 1, label_seq_length, 8, 8)

# test if `label_seq_length` is more than the number of frames of the given dataset.
label_seq_length = 5
model_params: Seq2SeqParams = {
"input_seq_length": 4,
"label_seq_length": label_seq_length,
"num_layers": 2,
"num_kernels": 4,
"convlstm_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_size": 3,
"padding": 1,
"activation": "relu",
"frame_size": (8, 8),
"weights_initializer": WeightsInitializer.He,
},
}
model = Seq2Seq(**model_params).to(DEVICE).to(torch.float)
output = model(torch.rand((2, 1, 4, 8, 8), dtype=torch.float, device=DEVICE))
# the output has the same frames as the given dataset.
assert output.size() == (2, 1, 4, 8, 8)

# test if both `label_seq_length` and `return_sequences` are given.
model_params: Seq2SeqParams = {
"input_seq_length": 4,
"label_seq_length": 3,
"num_layers": 2,
"num_kernels": 4,
"return_sequences": True,
"convlstm_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_size": 3,
"padding": 1,
"activation": "relu",
"frame_size": (8, 8),
"weights_initializer": WeightsInitializer.He,
},
}
model = Seq2Seq(**model_params).to(DEVICE).to(torch.float)
output = model(torch.rand((2, 1, 4, 8, 8), dtype=torch.float, device=DEVICE))
# the priority of `return_sequences` is higher than that of `label_seq_length`.
assert output.size() == (2, 1, 4, 8, 8)

0 comments on commit fb1d511

Please sign in to comment.