Skip to content

Commit

Permalink
Standardized Testing Module (tensorflow#2233)
Browse files Browse the repository at this point in the history
* refactored discover classes method to the utils.tests.standardized_testing.

refactored standardized test for optimizers to import from the new module.

* refactored to remove useless extra file

* fixed errors

* renamed file
  • Loading branch information
hyang0129 authored Nov 25, 2020
1 parent f0ddf04 commit 7552adf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
17 changes: 4 additions & 13 deletions tensorflow_addons/optimizers/tests/standard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tensorflow as tf

from tensorflow_addons import optimizers
import inspect
from tensorflow_addons.utils.test_utils import discover_classes

class_exceptions = [
"MultiOptimizer", # is wrapper
Expand All @@ -32,18 +32,9 @@
]


def discover_classes(module, parent):

classes = [
class_info[1]
for class_info in inspect.getmembers(module, inspect.isclass)
if issubclass(class_info[1], parent) and not class_info[0] in class_exceptions
]

return classes


classes_to_test = discover_classes(optimizers, tf.keras.optimizers.Optimizer)
classes_to_test = discover_classes(
optimizers, tf.keras.optimizers.Optimizer, class_exceptions
)


@pytest.mark.parametrize("optimizer", classes_to_test)
Expand Down
21 changes: 21 additions & 0 deletions tensorflow_addons/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import os
import random
import inspect

import numpy as np
import pytest
Expand Down Expand Up @@ -271,3 +272,23 @@ def assert_allclose_according_to_type(
atol = max(atol, bfloat16_atol)

np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)


def discover_classes(module, parent, class_exceptions):
"""
Args:
module: a module in which to search for classes that inherit from the parent class
parent: the parent class that identifies classes in the module that should be tested
class_exceptions: a list of specific classes that should be excluded when discovering classes in a module
Returns:
a list of classes for testing using pytest for parameterized tests
"""

classes = [
class_info[1]
for class_info in inspect.getmembers(module, inspect.isclass)
if issubclass(class_info[1], parent) and not class_info[0] in class_exceptions
]

return classes

0 comments on commit 7552adf

Please sign in to comment.