Use typed params for sconvlstm
Commit a92e930 (1 parent: 58c1314)
tsugumi-sys committed Jan 9, 2024

1 changed file with 15 additions and 11 deletions:
tests/self_attention_convlstm/test_model.py
@@ -1,18 +1,22 @@
 import torch
 
-from core.constants import DEVICE
-from self_attention_convlstm.model import SAConvLSTM
+from core.constants import DEVICE, WeightsInitializer
+from self_attention_convlstm.model import SAConvLSTM, SAConvLSTMParams
 
 
 def test_ConvLSTM():
-    model = SAConvLSTM(
-        attention_hidden_dims=1,
-        in_channels=1,
-        out_channels=1,
-        kernel_size=3,
-        padding=1,
-        activation="relu",
-        frame_size=(8, 8),
-    )
+    model_params: SAConvLSTMParams = {
+        "attention_hidden_dims": 1,
+        "convlstm_params": {
+            "in_channels": 1,
+            "out_channels": 1,
+            "kernel_size": 3,
+            "padding": 1,
+            "activation": "relu",
+            "frame_size": (8, 8),
+            "weights_initializer": WeightsInitializer.He,
+        },
+    }
+    model = SAConvLSTM(**model_params)
     output = model(torch.rand((2, 1, 3, 8, 8), dtype=torch.float, device=DEVICE))
     assert output.size() == (2, 1, 3, 8, 8)
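
For context, the SAConvLSTMParams type and the nested convlstm params referenced above are defined elsewhere in the repository, not in this diff. The sketch below is a hypothetical reconstruction of what those typed-params definitions might look like, assuming they are TypedDicts; the field names come from the diff, while the exact field types, the ConvLSTMParams name, and the WeightsInitializer members beyond He are assumptions.

# Hypothetical sketch (not part of the commit): a possible shape for the typed
# params used in the test above. Field names follow the diff; the concrete
# types, the ConvLSTMParams name, and the enum values are assumed.
from enum import Enum
from typing import Tuple, TypedDict


class WeightsInitializer(str, Enum):
    # Only the He member is visible in the diff; the string value is assumed.
    He = "he"


class ConvLSTMParams(TypedDict):
    in_channels: int
    out_channels: int
    kernel_size: int
    padding: int
    activation: str
    frame_size: Tuple[int, int]
    weights_initializer: WeightsInitializer


class SAConvLSTMParams(TypedDict):
    attention_hidden_dims: int
    convlstm_params: ConvLSTMParams

With definitions along these lines, a static type checker such as mypy can verify the dictionary literal in the test against the expected keys and value types, which is the point of switching the test from a flat list of keyword arguments to typed params.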
