diff --git a/images/mLSTM_5.png b/images/mLSTM_5.png index 1f0d9df..8767e28 100644 Binary files a/images/mLSTM_5.png and b/images/mLSTM_5.png differ diff --git a/images/sLSTM_5.png b/images/sLSTM_5.png index ea6d87e..6c76caa 100644 Binary files a/images/sLSTM_5.png and b/images/sLSTM_5.png differ diff --git a/src/utils/test_mLSTM.py b/src/utils/test_mLSTM.py index 87e4023..8b8b569 100644 --- a/src/utils/test_mLSTM.py +++ b/src/utils/test_mLSTM.py @@ -19,7 +19,7 @@ # [batch_size, seq_len, hidden_size/features] model = mLSTM(input_size=input_size, hidden_size=hidden_size, mem_dim=mem_dim) -optimizer = optim.Adam(model.parameters(), lr=0.001) +optimizer = optim.Adam(model.parameters(), lr=0.01) criterion = nn.MSELoss() for epoch in range(500): diff --git a/src/utils/test_sLSTM.py b/src/utils/test_sLSTM.py index da0494e..b33bb98 100644 --- a/src/utils/test_sLSTM.py +++ b/src/utils/test_sLSTM.py @@ -7,17 +7,17 @@ from src.utils.sine_wave import generate_sine_wave input_size = 5 -hidden_size = 10 +hidden_size = 5 mem_dim = 5 seq_len = 100 -num_sequences = 1 +num_sequences = 2 data = generate_sine_wave( seq_len=seq_len, num_sequences=num_sequences, input_size=input_size ) model = sLSTM(input_size=input_size, hidden_size=hidden_size) -optimizer = optim.Adam(model.parameters(), lr=0.001) +optimizer = optim.Adam(model.parameters(), lr=0.01) criterion = nn.MSELoss() for epoch in range(500):