diff --git a/tests/test_attention_network.py b/tests/test_attention_network.py index cf73c12..c70c0ce 100644 --- a/tests/test_attention_network.py +++ b/tests/test_attention_network.py @@ -40,6 +40,7 @@ def test_attention_network(): heads =2 ) result = network(input_data) + print('result', result) print("result", result.shape) assert result.shape[0] == 32 assert result.shape[1] == 128