Skip to content

Commit

Permalink
added tests to cover get_data_cfg function and StopLossNanTrainingHoo…
Browse files Browse the repository at this point in the history
…k after_train_iter method
  • Loading branch information
saltykox committed Apr 5, 2022
1 parent 6a7df1c commit c99aebd
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def cluster_anchors(config: Config, dataset: DatasetEntity, model: BaseDetector)
return config, model


@check_input_parameters_type()
def get_data_cfg(config: Config, subset: str = 'train') -> Config:
data_cfg = config.data[subset]
while 'dataset' in data_cfg:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
cluster_anchors,
config_from_string,
config_to_string,
get_data_cfg,
is_epoch_based_runner,
patch_adaptive_repeat_dataset,
patch_config,
Expand Down Expand Up @@ -416,3 +417,33 @@ def test_cluster_anchors_input_params_validation(self):
unexpected_values=unexpected_values,
class_or_function=cluster_anchors,
)

@e2e_pytest_unit
def test_get_data_cfg_input_params_validation(self):
"""
<b>Description:</b>
Check "get_data_cfg" function input parameters validation
<b>Input data:</b>
"get_data_cfg" function input parameters with unexpected type
<b>Expected results:</b>
Test passes if ValueError exception is raised when unexpected type object is specified as input parameter for
"get_data_cfg" function
"""
correct_values_dict = {
"config": Config(),
}
unexpected_int = 1
unexpected_values = [
# Unexpected integer is specified as "config" parameter
("config", unexpected_int),
# Unexpected integer is specified as "subset" parameter
("subset", unexpected_int),
]

check_value_error_exception_raised(
correct_parameters=correct_values_dict,
unexpected_values=unexpected_values,
class_or_function=get_data_cfg,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FixedMomentumUpdaterHook,
OTELoggerHook,
OTEProgressHook,
StopLossNanTrainingHook,
ReduceLROnPlateauLrUpdaterHook,
)
from mmcv.runner import EpochBasedRunner
Expand Down Expand Up @@ -532,3 +533,22 @@ def test_reduce_lr_hook_before_run_params_validation(self):
hook = self.hook()
with pytest.raises(ValueError):
hook.before_run(runner="unexpected string") # type: ignore


class TestStopLossNanTrainingHook:
@e2e_pytest_unit
def test_stop_loss_nan_train_hook_after_train_iter_params_validation(self):
"""
<b>Description:</b>
Check StopLossNanTrainingHook object "after_train_iter" method input parameters validation
<b>Input data:</b>
StopLossNanTrainingHook object, "runner" non-BaseRunner type object
<b>Expected results:</b>
Test passes if ValueError exception is raised when unexpected type object is specified as
input parameter for "after_train_iter" method
"""
hook = StopLossNanTrainingHook()
with pytest.raises(ValueError):
hook.after_train_iter(runner="unexpected string") # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def model():
)

@e2e_pytest_unit
def test_train_task_input_params_validation(self):
def test_train_task_train_input_params_validation(self):
"""
<b>Description:</b>
Check OTEDetectionTrainingTask object "train" method input parameters validation
Expand Down

0 comments on commit c99aebd

Please sign in to comment.