Skip to content

Commit

Permalink
adding convlstm examples
Browse files Browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed Jan 15, 2024
1 parent 444575d commit 17145ba
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,4 @@ poetry.lock

MovingMNIST/
tmp/
*.pt
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,26 @@

## Examples

```bash
python -m examples.moving_mnist_convlstm

```

## Directories

## `convlstm/`
### `convlstm/`

ConvLSTM implementation based on [Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://paperswithcode.com/paper/convolutional-lstm-network-a-machine-learning).

## `self_attention_convlstm/`
### `self_attention_convlstm/`

Self Attention ConvLSTM implementation based on [Self-Attention ConvLSTM for Spatiotemporal Prediction](https://ojs.aaai.org/index.php/AAAI/article/view/6819/6673).

## `self_attention_memory_convlstm/`
### `self_attention_memory_convlstm/`

Self-Attention ConvLSTM (with memory module) implementation based on [Self-Attention ConvLSTM for Spatiotemporal Prediction](https://ojs.aaai.org/index.php/AAAI/article/view/6819/6673).

## Visualized Attention Maps
### Visualized Attention Maps

![sa-convlstm](fig/sa-convlstm.png)

Expand Down
Empty file added examples/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions examples/moving_mnist_convlstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from torch import nn
from torch.optim import Adam

from convlstm.seq2seq import Seq2Seq, Seq2SeqParams
from core.constants import WeightsInitializer
from data_loaders.moving_mnist import MovingMNISTDataLoaders
from pipelines.experimenter import Experimenter
from pipelines.trainer import TrainingParams
from pipelines.utils.early_stopping import EarlyStopping


def main():
    """Train a ConvLSTM Seq2Seq model on the MovingMNIST dataset.

    Builds the model, the training configuration, and the data loaders,
    then hands everything to ``Experimenter`` which runs the actual
    train / validate / test pipeline and writes artifacts (checkpoints,
    metrics) under ``artifact_dir``.
    """
    ###
    # Common Params
    ###
    artifact_dir = "./tmp"  # checkpoints and metrics are written here
    input_seq_length = 10  # number of input frames fed to the model
    train_batch_size = 32
    validation_batch_size = 16  # fixed typo: was `validation_bath_size`
    ###
    # Setup Pipeline
    ###
    model_params: Seq2SeqParams = {
        "input_seq_length": input_seq_length,
        "num_layers": 2,
        "num_kernels": 64,
        # Predict only the final frame, not the whole output sequence.
        "return_sequences": False,
        "convlstm_params": {
            "in_channels": 1,  # MovingMNIST frames are grayscale
            "out_channels": 1,
            "kernel_size": (3, 3),
            "padding": "same",
            "activation": "relu",
            "frame_size": (64, 64),  # MovingMNIST frame resolution
            "weights_initializer": WeightsInitializer.He,
        },
    }

    model = Seq2Seq(**model_params)

    training_params: TrainingParams = {
        "epochs": 1,  # single epoch — this is a quick-start example
        # Sum-reduced BCE over pixels; L1 is tracked as an accuracy proxy.
        "loss_criterion": nn.BCELoss(reduction="sum"),
        "accuracy_criterion": nn.L1Loss(),
        "optimizer": Adam(model.parameters(), lr=1e-4),
        "early_stopping": EarlyStopping(
            patience=30,
            verbose=True,
            delta=0.0001,
        ),
        "metrics_filename": "metrics.csv",
    }

    print("Loading dataset ...")
    data_loaders = MovingMNISTDataLoaders(
        train_batch_size=train_batch_size,
        validation_batch_size=validation_batch_size,
        input_frames=model_params["input_seq_length"],
        label_frames=1,  # predict a single future frame
        # NOTE(review): tiny validation/test splits keep the example fast;
        # presumably intentional for a demo — confirm before reuse.
        split_ratios=[0.7, 0.299, 0.001],
    )

    experimenter = Experimenter(artifact_dir, data_loaders, model, training_params)
    experimenter.run()


if __name__ == "__main__":
    main()

0 comments on commit 17145ba

Please sign in to comment.