Skip to content

Commit

Permalink
Use typed params for ConvLSTM
Browse files Browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed Jan 9, 2024
1 parent f74008e commit 721d630
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/convlstm/test_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import torch

from convlstm.model import ConvLSTM
from convlstm.model import ConvLSTM, ConvLSTMParams
from core.constants import DEVICE


def test_ConvLSTM():
model = ConvLSTM(
in_channels=1,
out_channels=1,
kernel_size=3,
padding=1,
activation="relu",
frame_size=(8, 8),
)
model_params: ConvLSTMParams = {
"in_channels": 1,
"out_channels": 1,
"kernel_size": 3,
"padding": 1,
"activation": "relu",
"frame_size": (8, 8),
}
model = ConvLSTM(**model_params)
output = model(torch.rand((2, 1, 3, 8, 8), dtype=torch.float, device=DEVICE))
assert output.size() == (2, 1, 3, 8, 8)

0 comments on commit 721d630

Please sign in to comment.