Skip to content

Commit

Permalink
Change test_values_to_numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes E. M. Mosig committed Jan 26, 2021
1 parent 977d662 commit e57575d
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions tests/utils/tensorflow/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@
import tensorflow as tf
import numpy as np
import rasa.utils.tensorflow.numpy
import json
from typing import Optional, Dict, Any


class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)


@pytest.mark.parametrize(
"value, expected_result",
[
Expand All @@ -32,6 +24,8 @@ def test_values_to_numpy(
type(value) for value in sorted(actual_result.values())
]
assert actual_result_value_types == expected_result_value_types
assert json.dumps(actual_result, sort_keys=True, cls=NumpyEncoder) == json.dumps(
expected_result, sort_keys=True, cls=NumpyEncoder
)
for key, value in actual_result.items():
if isinstance(expected_result.get(key), np.ndarray):
np.testing.assert_equal(value, expected_result.get(key))
else:
assert value == expected_result.get(key)

0 comments on commit e57575d

Please sign in to comment.