Skip to content

Commit

Permalink
Merge pull request #81 from tsugumi-sys/fix/failed-tests-device-error
Browse files Browse the repository at this point in the history
fix failed tests on gpu device
  • Loading branch information
tsugumi-sys authored Jan 15, 2024
2 parents a6a8f28 + 78ac713 commit c2d533a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/self_attention_memory_convlstm/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def test_SAMConvLSTM():
},
}

model = SAMConvLSTM(**model_params)
model = SAMConvLSTM(**model_params).to(DEVICE)
output = model(torch.rand((2, 1, 3, 8, 8), dtype=torch.float, device=DEVICE))
assert output.size() == (2, 1, 3, 8, 8)
5 changes: 3 additions & 2 deletions tests/test_model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import torch

from core.constants import DEVICE
from tests.test_model.model import TestModel


Expand All @@ -14,6 +15,6 @@ def test_TestModel_successfully_build(
return_sequences: bool, expected_output_size: Tuple
):
# return_sequence
model = TestModel(return_sequences)
output = model(torch.rand((5, 6, 3, 16, 16), dtype=torch.float))
model = TestModel(return_sequences).to(DEVICE)
output = model(torch.rand((5, 6, 3, 16, 16), dtype=torch.float, device=DEVICE))
assert output.size() == expected_output_size

0 comments on commit c2d533a

Please sign in to comment.