From 7a4e9a3f82787c38c938413f2de09cf610e2168d Mon Sep 17 00:00:00 2001 From: tsugumi-sys Date: Tue, 26 Mar 2024 12:18:42 +0900 Subject: [PATCH] addign saconvlstm example --- README.md | 7 ++ ...ng_mnist_self_attention_memory_convlstm.py | 69 +++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 examples/moving_mnist_self_attention_memory_convlstm.py diff --git a/README.md b/README.md index 6be09e0..0c9265a 100644 --- a/README.md +++ b/README.md @@ -7,10 +7,17 @@ ## Examples +### ConvLSTM + ```bash python -m examples.moving_mnist_convlstm ``` +### Self-Attention ConvLSTM + +```bash +python -m examples.moving_mnist_self_attention_memory_convlstm +``` ## Directories ### `convlstm/` diff --git a/examples/moving_mnist_self_attention_memory_convlstm.py b/examples/moving_mnist_self_attention_memory_convlstm.py new file mode 100644 index 0000000..f5eb8fc --- /dev/null +++ b/examples/moving_mnist_self_attention_memory_convlstm.py @@ -0,0 +1,69 @@ +from torch import nn +from torch.optim import Adam + +from core.constants import WeightsInitializer +from data_loaders.moving_mnist import MovingMNISTDataLoaders +from pipelines.experimenter import Experimenter +from pipelines.trainer import TrainingParams +from pipelines.utils.early_stopping import EarlyStopping +from self_attention_memory_convlstm.seq2seq import SAMSeq2Seq, SAMSeq2SeqParams + + +def main(): + ### + # Common Params + ### + artifact_dir = "./tmp" + input_seq_length = 10 + train_batch_size = 32 + validation_bath_size = 16 + ### + # Setup Pipeline + ### + model_params: SAMSeq2SeqParams = { + "attention_hidden_dims": 2, + "input_seq_length": input_seq_length, + "num_layers": 2, + "num_kernels": 64, + "return_sequences": False, + "convlstm_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_size": (3, 3), + "padding": "same", + "activation": "relu", + "frame_size": (64, 64), + "weights_initializer": WeightsInitializer.He, + }, + } + + model = SAMSeq2Seq(**model_params) + + training_params: TrainingParams = { + "epochs": 1, + "loss_criterion": nn.BCELoss(reduction="sum"), + "accuracy_criterion": nn.L1Loss(), + "optimizer": Adam(model.parameters(), lr=1e-4), + "early_stopping": EarlyStopping( + patience=30, + verbose=True, + delta=0.0001, + ), + "metrics_filename": "metrics.csv", + } + + print("Loading dataset ...") + data_loaders = MovingMNISTDataLoaders( + train_batch_size=train_batch_size, + validation_batch_size=validation_bath_size, + input_frames=model_params["input_seq_length"], + label_frames=1, + split_ratios=[0.7, 0.299, 0.001], + ) + + experimenter = Experimenter(artifact_dir, data_loaders, model, training_params) + experimenter.run() + + +if __name__ == "__main__": + main()