From 17145ba62996367f60c9bc6d1d84a668c78810c6 Mon Sep 17 00:00:00 2001
From: Akira
Date: Tue, 16 Jan 2024 00:08:38 +0900
Subject: [PATCH] Add ConvLSTM examples

---
 .gitignore                        |  1 +
 README.md                         | 13 ++++--
 examples/__init__.py              |  0
 examples/moving_mnist_convlstm.py | 68 +++++++++++++++++++++++++++++++
 4 files changed, 78 insertions(+), 4 deletions(-)
 create mode 100644 examples/__init__.py
 create mode 100644 examples/moving_mnist_convlstm.py

diff --git a/.gitignore b/.gitignore
index 781d569..553baaa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@ poetry.lock
 
 MovingMNIST/
 tmp/
+*.pt
diff --git a/README.md b/README.md
index ae20d35..e0868b1 100644
--- a/README.md
+++ b/README.md
@@ -7,21 +7,26 @@
 ## Examples
 
+```bash
+python -m examples.moving_mnist_convlstm
+
+```
+
 ## Directories
 
-## `convlstm/`
+### `convlstm/`
 
 ConvLSTM implementation based on [Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://paperswithcode.com/paper/convolutional-lstm-network-a-machine-learning).
 
-## `self_attention_convlstm/`
+### `self_attention_convlstm/`
 
 Self Attention ConvLSTM implementation based on [Self-Attention ConvLSTM for Spatiotemporal Prediction](https://ojs.aaai.org/index.php/AAAI/article/view/6819/6673).
 
-## `self_attention_memory_convlstm/`
+### `self_attention_memory_convlstm/`
 
 Self-Attention ConvLSTM (with memory module) implementation based on [Self-Attention ConvLSTM for Spatiotemporal Prediction](https://ojs.aaai.org/index.php/AAAI/article/view/6819/6673).
 
-## Visualized Attention Maps
+### Visualized Attention Maps
 
 ![sa-convlstm](fig/sa-convlstm.png)
 
 
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/moving_mnist_convlstm.py b/examples/moving_mnist_convlstm.py
new file mode 100644
index 0000000..b9e9c13
--- /dev/null
+++ b/examples/moving_mnist_convlstm.py
@@ -0,0 +1,68 @@
+from torch import nn
+from torch.optim import Adam
+
+from convlstm.seq2seq import Seq2Seq, Seq2SeqParams
+from core.constants import WeightsInitializer
+from data_loaders.moving_mnist import MovingMNISTDataLoaders
+from pipelines.experimenter import Experimenter
+from pipelines.trainer import TrainingParams
+from pipelines.utils.early_stopping import EarlyStopping
+
+
+def main():
+    ###
+    # Common Params
+    ###
+    artifact_dir = "./tmp"  # training outputs such as metrics.csv are written here
+    input_seq_length = 10
+    train_batch_size = 32
+    validation_batch_size = 16
+    ###
+    # Setup Pipeline
+    ###
+    model_params: Seq2SeqParams = {
+        "input_seq_length": input_seq_length,
+        "num_layers": 2,
+        "num_kernels": 64,
+        "return_sequences": False,
+        "convlstm_params": {
+            "in_channels": 1,
+            "out_channels": 1,
+            "kernel_size": (3, 3),
+            "padding": "same",
+            "activation": "relu",
+            "frame_size": (64, 64),
+            "weights_initializer": WeightsInitializer.He,
+        },
+    }
+
+    model = Seq2Seq(**model_params)
+
+    training_params: TrainingParams = {
+        "epochs": 1,
+        "loss_criterion": nn.BCELoss(reduction="sum"),
+        "accuracy_criterion": nn.L1Loss(),
+        "optimizer": Adam(model.parameters(), lr=1e-4),
+        "early_stopping": EarlyStopping(
+            patience=30,
+            verbose=True,
+            delta=0.0001,
+        ),
+        "metrics_filename": "metrics.csv",
+    }
+
+    print("Loading dataset ...")
+    data_loaders = MovingMNISTDataLoaders(
+        train_batch_size=train_batch_size,
+        validation_batch_size=validation_batch_size,
+        input_frames=model_params["input_seq_length"],
+        label_frames=1,
+        split_ratios=[0.7, 0.299, 0.001],  # train / validation / test fractions
+    )
+
+    experimenter = Experimenter(artifact_dir, data_loaders, model, training_params)
+    experimenter.run()
+
+
+if __name__ == "__main__":
+    main()
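
As a quick sanity check after this patch, the trained weights could be loaded back for a single prediction. Below is a minimal sketch, assuming the run writes its weights to `./tmp/model.pt` (the new `*.pt` gitignore entry suggests checkpoints land under `artifact_dir`) and that `Seq2Seq` consumes tensors shaped `(batch, channels, time, height, width)`; neither detail is guaranteed by the `Experimenter` API.

```python
# Minimal inference sketch. Assumptions (not verified against this repo):
# the run saves a state dict to ./tmp/model.pt, and Seq2Seq consumes
# tensors shaped (batch, channels, time, height, width).
import torch

from convlstm.seq2seq import Seq2Seq
from core.constants import WeightsInitializer

# Constructor arguments mirror model_params from the example script,
# so the loaded state dict should match layer-for-layer.
model = Seq2Seq(
    input_seq_length=10,
    num_layers=2,
    num_kernels=64,
    return_sequences=False,
    convlstm_params={
        "in_channels": 1,
        "out_channels": 1,
        "kernel_size": (3, 3),
        "padding": "same",
        "activation": "relu",
        "frame_size": (64, 64),
        "weights_initializer": WeightsInitializer.He,
    },
)
model.load_state_dict(torch.load("./tmp/model.pt"))
model.eval()

with torch.no_grad():
    frames = torch.rand(1, 1, 10, 64, 64)  # stand-in for 10 Moving MNIST frames
    next_frame = model(frames)  # expected (1, 1, 64, 64) since return_sequences=False
print(next_frame.shape)
```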